warmup prefill

remove model where pageattn is not used, set block table to None since it's not used

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
Wang, Yi A 2025-03-25 22:21:44 -07:00
parent 69773767c5
commit fd70ad703e
5 changed files with 117 additions and 115 deletions

View File

@ -92,7 +92,6 @@ try:
from text_generation_server.models.custom_modeling.flash_phi_modeling import ( from text_generation_server.models.custom_modeling.flash_phi_modeling import (
FlashPhiForCausalLM, FlashPhiForCausalLM,
) )
from text_generation_server.models.idefics_causal_lm import IdeficsCausalLM
from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLMBatch from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLMBatch
from text_generation_server.models.custom_modeling.flash_mllama import ( from text_generation_server.models.custom_modeling.flash_mllama import (
FlashMllamaForConditionalGeneration, FlashMllamaForConditionalGeneration,
@ -144,7 +143,6 @@ except ImportError as e:
if FLASH_ATTENTION: if FLASH_ATTENTION:
__all__.append(FlashCausalLM) __all__.append(FlashCausalLM)
__all__.append(IdeficsCausalLM)
class ModelType(enum.Enum): class ModelType(enum.Enum):
@ -301,12 +299,6 @@ class ModelType(enum.Enum):
"name": "Gptj", "name": "Gptj",
"url": "https://huggingface.co/EleutherAI/gpt-j-6b", "url": "https://huggingface.co/EleutherAI/gpt-j-6b",
} }
IDEFICS = {
"type": "idefics",
"name": "Idefics",
"url": "https://huggingface.co/HuggingFaceM4/idefics-9b",
"multimodal": True,
}
MLLAMA = { MLLAMA = {
"type": "mllama", "type": "mllama",
"name": "Mllama", "name": "Mllama",
@ -733,15 +725,6 @@ def get_model(
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids, lora_adapter_ids=lora_adapter_ids,
) )
elif model_type == IDEFICS:
return IdeficsCausalLM(
model_id,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif model_type == QWEN2_VL: elif model_type == QWEN2_VL:
return FlashVlmCausalLM( return FlashVlmCausalLM(
model_id=model_id, model_id=model_id,

View File

@ -69,6 +69,8 @@ from text_generation_server.utils.import_utils import (
import vllm_hpu_extension.environment as environment import vllm_hpu_extension.environment as environment
import habana_frameworks.torch as htorch import habana_frameworks.torch as htorch
import itertools
from vllm_hpu_extension.ops import batch2block, block2batch
tracer = trace.get_tracer(__name__) tracer = trace.get_tracer(__name__)
@ -86,6 +88,78 @@ def get_sliding_windows() -> int:
return SLIDING_WINDOW return SLIDING_WINDOW
def prepare_for_decode(
dtype, use_contiguous_pa, device, slot, block_tables, batch_size
):
# Prepare values if we need to continue decoding
# need for HPUPagedAttentionMetadata preparation
def flatten(in_list):
return list(itertools.chain(*in_list))
def gather_list(input, indices, v):
return [input[i] if i is not None else v for i in indices]
def pad_list(input, k, v):
input_len = len(input)
target_len = (input_len + k - 1) // k * k
padding = target_len - input_len
return input + [v] * padding
last_block_usage = slot % BLOCK_SIZE + 1
block_groups = [[i] * len(bt) for i, bt in enumerate(block_tables)]
block_usage = [
[BLOCK_SIZE] * (len(bt) - 1) + [lbu]
for bt, lbu in zip(block_tables, last_block_usage)
if bt
]
block_list = flatten(block_tables)
block_groups = flatten(block_groups)
block_usage = flatten(block_usage)
assert len(block_list) == len(block_groups)
assert len(block_list) == len(block_usage)
if use_contiguous_pa:
block_bucket_size = max(max(block_list) + 1, len(block_list))
# block_bucket_size = self.bucketing_ctx.get_padded_decode_num_blocks(
# block_bucket_size)
indices: List[Any]
indices = [None] * block_bucket_size
for i, bid in enumerate(block_list):
indices[bid] = i
block_list = gather_list(block_list, indices, 0)
block_groups = gather_list(block_groups, indices, -1)
block_usage = gather_list(block_usage, indices, 1)
else:
block_bucket_size = len(block_list)
block_list = pad_list(block_list, block_bucket_size, 0)
block_groups = pad_list(block_groups, block_bucket_size, -1)
block_usage = pad_list(block_usage, block_bucket_size, 1)
block_list = torch.tensor(block_list, dtype=torch.int, device=device)
block_groups = torch.tensor(block_groups, dtype=torch.int, device=device)
block_usage = torch.tensor(block_usage, dtype=dtype, device=device)
block_mapping = torch.nn.functional.one_hot(block_groups, num_classes=batch_size)
mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(0)
mask = mask >= block_usage.unsqueeze(-1)
attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)
ones = torch.ones(
(block_mapping.size(0),), device=device, dtype=block_mapping.dtype
)
sums = batch2block(block2batch(ones, block_mapping), block_mapping)
block_scales = torch.reciprocal(torch.maximum(ones, sums))
return trim_attn_metadata(
HPUPagedAttentionMetadata(
block_list=block_list,
block_groups=block_groups,
block_usage=block_usage,
block_mapping=block_mapping.to(dtype),
attn_bias=attn_bias,
block_scales=block_scales,
)
)
@dataclass @dataclass
class FlashCausalLMBatch(Batch): class FlashCausalLMBatch(Batch):
batch_id: int batch_id: int
@ -879,83 +953,18 @@ class FlashCausalLMBatch(Batch):
) )
def prepare_for_decode(self, dtype, use_contiguous_pa): def prepare_for_decode(self, dtype, use_contiguous_pa):
# Prepare values if we need to continue decoding
# need for HPUPagedAttentionMetadata preparation
import itertools
from vllm_hpu_extension.ops import batch2block, block2batch
def flatten(in_list):
return list(itertools.chain(*in_list))
def gather_list(input, indices, v):
return [input[i] if i is not None else v for i in indices]
def pad_list(input, k, v):
input_len = len(input)
target_len = (input_len + k - 1) // k * k
padding = target_len - input_len
return input + [v] * padding
device = self.block_tables_tensor.device
last_block_usage = self.slots[self.slot_indices] % BLOCK_SIZE + 1
block_num = self.cache_lengths_tensor // BLOCK_SIZE + 1 block_num = self.cache_lengths_tensor // BLOCK_SIZE + 1
block_tables = [] block_tables = []
for i, bt in enumerate(self.block_tables): for i, bt in enumerate(self.block_tables):
block_tables.append(bt[0 : block_num[i]]) block_tables.append(bt[0 : block_num[i]])
block_groups = [[i] * len(bt) for i, bt in enumerate(block_tables)]
block_usage = [
[BLOCK_SIZE] * (len(bt) - 1) + [lbu]
for bt, lbu in zip(block_tables, last_block_usage)
if bt
]
block_list = flatten(block_tables) self.hpu_attn_meta = prepare_for_decode(
block_groups = flatten(block_groups) dtype,
block_usage = flatten(block_usage) use_contiguous_pa,
batch = self.input_ids.size(0) self.block_tables_tensor.device,
self.slots[self.slot_indices],
assert len(block_list) == len(block_groups) block_tables,
assert len(block_list) == len(block_usage) self.input_ids.size(0),
if use_contiguous_pa:
block_bucket_size = max(max(block_list) + 1, len(block_list))
# block_bucket_size = self.bucketing_ctx.get_padded_decode_num_blocks(
# block_bucket_size)
indices: List[Any]
indices = [None] * block_bucket_size
for i, bid in enumerate(block_list):
indices[bid] = i
block_list = gather_list(block_list, indices, 0)
block_groups = gather_list(block_groups, indices, -1)
block_usage = gather_list(block_usage, indices, 1)
else:
block_bucket_size = len(block_list)
block_list = pad_list(block_list, block_bucket_size, 0)
block_groups = pad_list(block_groups, block_bucket_size, -1)
block_usage = pad_list(block_usage, block_bucket_size, 1)
block_list = torch.tensor(block_list, dtype=torch.int, device=device)
block_groups = torch.tensor(block_groups, dtype=torch.int, device=device)
block_usage = torch.tensor(block_usage, dtype=dtype, device=device)
block_mapping = torch.nn.functional.one_hot(block_groups, num_classes=batch)
mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(
0
)
mask = mask >= block_usage.unsqueeze(-1)
attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)
ones = torch.ones(
(block_mapping.size(0),), device=device, dtype=block_mapping.dtype
)
sums = batch2block(block2batch(ones, block_mapping), block_mapping)
block_scales = torch.reciprocal(torch.maximum(ones, sums))
self.hpu_attn_meta = trim_attn_metadata(
HPUPagedAttentionMetadata(
block_list=block_list,
block_groups=block_groups,
block_usage=block_usage,
block_mapping=block_mapping.to(dtype),
attn_bias=attn_bias,
block_scales=block_scales,
)
) )
def prepare_for_prefill(self): def prepare_for_prefill(self):
@ -1481,32 +1490,44 @@ class FlashCausalLM(Model):
self.kv_cache_dtype, self.kv_cache_dtype,
self.device, self.device,
) )
for bs in [1, 2, 4, 8]:
for seqlen in [32, 64, 128, 256, 512, 1024]:
self.warmup_prefill(seqlen, bs)
return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
def tunableop_warmup(self, seqlen: int, max_bt: int): def warmup_prefill(self, prompt_len: int, bs: int):
input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device) input_ids = torch.zeros(
position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device) prompt_len, dtype=torch.int64, device=self.device
slots = torch.arange(seqlen, dtype=torch.int64, device=self.device) ).repeat(bs)
position_ids = torch.arange(
# Dummy value, some models (starcoder2) don't accept `None`. prompt_len, dtype=torch.int32, device=self.device
input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device) ).repeat(bs)
cache_lengths_tensor = torch.zeros( max_bt = (prompt_len // BLOCK_SIZE + 1) * bs
seqlen, dtype=torch.int32, device=self.device
)
cu_seqlen_prefill = torch.tensor(
[0, seqlen], device=self.device, dtype=torch.int32
)
block_tables = torch.arange( block_tables = torch.arange(
max_bt, dtype=torch.int32, device=self.device max_bt, dtype=torch.int32, device=self.device
).repeat(seqlen) ).reshape(bs, -1)
block_tables = block_tables.reshape((seqlen, max_bt)) slot_acc = []
for i in range(bs):
slots = []
for b in block_tables[i]:
slots.extend(range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE))
slot_acc.extend(slots[:prompt_len])
slots = torch.tensor(slot_acc, dtype=torch.int64, device=self.device)
input_lengths = (
torch.ones(bs, dtype=torch.int32, device=self.device) * prompt_len
)
cache_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device)
cu_seqlen_prefill = torch.zeros(bs + 1, device=self.device, dtype=torch.int32)
torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:])
seqlen = Seqlen( seqlen = Seqlen(
input_lengths=input_lengths, input_lengths=input_lengths,
cache_lengths=cache_lengths_tensor, cache_lengths=cache_lengths_tensor,
cu_seqlen_q=cu_seqlen_prefill,
) )
lm_head_indices = input_lengths - 1
# We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation. # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
self.model.forward( self.model.forward(
@ -1514,11 +1535,13 @@ class FlashCausalLM(Model):
position_ids=position_ids, position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill, cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=self.kv_cache, kv_cache=self.kv_cache,
block_tables=block_tables, block_tables=None, # block_table is not used in hpu pageattn. remove it to avoid shape change in hpu graph
seqlen=seqlen,
slots=slots, slots=slots,
lm_head_indices=None, seqlen=trim_seqlen_metadata(seqlen),
prefill_cache_indices=None, prefill_cache_indices=None,
lm_head_indices=lm_head_indices,
adapter_data=None,
hpu_attention_meta=None,
) )
def forward( def forward(
@ -1606,7 +1629,7 @@ class FlashCausalLM(Model):
position_ids=position_ids, position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill, cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache, kv_cache=kv_cache,
block_tables=block_tables, block_tables=None,
slots=slots, slots=slots,
seqlen=trim_seqlen_metadata(seqlen), seqlen=trim_seqlen_metadata(seqlen),
prefill_cache_indices=batch.prefill_cache_indices, prefill_cache_indices=batch.prefill_cache_indices,
@ -1637,9 +1660,7 @@ class FlashCausalLM(Model):
batch.prepare_for_prefill() batch.prepare_for_prefill()
else: else:
batch.prepare_for_decode(self.dtype, self.use_contiguous_pa) batch.prepare_for_decode(self.dtype, self.use_contiguous_pa)
prefill_logprobs = batch.prefill_next_token_indices is not None prefill_logprobs = batch.prefill_next_token_indices is not None
# Update adapter indices for speculative tokens (if present) # Update adapter indices for speculative tokens (if present)
adapter_meta = batch.adapter_meta adapter_meta = batch.adapter_meta
if batch.speculative_ids is not None: if batch.speculative_ids is not None:

View File

@ -462,7 +462,7 @@ class FlashVlmCausalLM(FlashCausalLM):
position_ids=position_ids, position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill, cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache, kv_cache=kv_cache,
block_tables=block_tables, block_tables=None, # block_table is not used in hpu pageattn. remove it to avoid shape change in hpu graph
slots=slots, slots=slots,
seqlen=trim_seqlen_metadata(seqlen), seqlen=trim_seqlen_metadata(seqlen),
hpu_attention_meta=batch.hpu_attn_meta, hpu_attention_meta=batch.hpu_attn_meta,

View File

@ -288,7 +288,7 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
position_ids=position_ids, position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill, cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache, kv_cache=kv_cache,
block_tables=block_tables, block_tables=None, # block_table is not used in hpu pageattn. remove it to avoid shape change in hpu graph
slots=slots, slots=slots,
seqlen=trim_seqlen_metadata(seqlen), seqlen=trim_seqlen_metadata(seqlen),
hpu_attention_meta=batch.hpu_attn_meta, hpu_attention_meta=batch.hpu_attn_meta,

View File

@ -33,13 +33,11 @@ try:
from text_generation_server.models.flash_vlm_causal_lm import ( from text_generation_server.models.flash_vlm_causal_lm import (
FlashVlmCausalLMBatch, FlashVlmCausalLMBatch,
) )
from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch
VLM_BATCH_TYPES = { VLM_BATCH_TYPES = {
PaliGemmaBatch, PaliGemmaBatch,
VlmCausalLMBatch, VlmCausalLMBatch,
FlashVlmCausalLMBatch, FlashVlmCausalLMBatch,
IdeficsCausalLMBatch,
FlashMllamaCausalLMBatch, FlashMllamaCausalLMBatch,
} }
except (ImportError, NotImplementedError): except (ImportError, NotImplementedError):