add warmup_decode

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
Wang, Yi A 2025-03-26 17:39:26 -07:00
parent fd70ad703e
commit ba7a131e04

View File

@ -1494,6 +1494,10 @@ class FlashCausalLM(Model):
for seqlen in [32, 64, 128, 256, 512, 1024]:
self.warmup_prefill(seqlen, bs)
for bs in [1, 2, 4, 8]:
for block_num in [1, 2, 4, 8, 16]:
self.warmup_decode(bs, block_num * bs)
return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
def warmup_prefill(self, prompt_len: int, bs: int):
@ -1544,6 +1548,57 @@ class FlashCausalLM(Model):
hpu_attention_meta=None,
)
def warmup_decode(self, bs: int, block_num: int):
input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
position_ids = torch.arange(bs, dtype=torch.int32, device=self.device)
block_tables = torch.arange(
block_num, dtype=torch.int32, device=self.device
).reshape(bs, -1)
slots = []
past_len = (
len(block_tables[0]) * BLOCK_SIZE - 3
) # for decode, we only need to pass the past token
# fetch the last blocked to warmup block num
for i in range(bs):
slots.extend(BLOCK_SIZE * block_tables[i][-1] + BLOCK_SIZE - 2)
slots = torch.tensor(slots, dtype=torch.int64, device=self.device)
input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device)
cache_lengths_tensor = (
torch.ones(bs, dtype=torch.int32, device=self.device) * past_len
)
cu_seqlen_prefill = torch.zeros(bs + 1, device=self.device, dtype=torch.int32)
torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:])
seqlen = Seqlen(
input_lengths=input_lengths,
cache_lengths=cache_lengths_tensor,
cu_seqlen_q=cu_seqlen_prefill,
)
hpu_attention_meta = prepare_for_decode(
self.dtype,
self.use_contiguous_pa,
self.device,
slots,
block_tables,
bs,
)
# We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=None,
kv_cache=self.kv_cache,
block_tables=None, # block_table is not used in hpu pageattn. remove it to avoid shape change in hpu graph
slots=slots,
seqlen=trim_seqlen_metadata(seqlen),
prefill_cache_indices=None,
lm_head_indices=None,
adapter_data=None,
hpu_attention_meta=hpu_attention_meta,
)
def forward(
self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: