warmup decode

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
Wang, Yi A 2025-03-26 20:19:13 -07:00
parent ba7a131e04
commit 7900be5ac3

View File

@ -1552,15 +1552,15 @@ class FlashCausalLM(Model):
input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
position_ids = torch.arange(bs, dtype=torch.int32, device=self.device)
block_tables = torch.arange(
block_num, dtype=torch.int32, device=self.device
start=1, end=block_num + 1, dtype=torch.int32, device=self.device
).reshape(bs, -1)
slots = []
past_len = (
len(block_tables[0]) * BLOCK_SIZE - 3
len(block_tables[0]) * BLOCK_SIZE - 1
) # for decode, we only need to pass the past token
# fetch the last blocked to warmup block num
for i in range(bs):
slots.extend(BLOCK_SIZE * block_tables[i][-1] + BLOCK_SIZE - 2)
slots.append(BLOCK_SIZE * block_tables[i][-1] + BLOCK_SIZE - 1)
slots = torch.tensor(slots, dtype=torch.int64, device=self.device)
input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device)
@ -1575,12 +1575,17 @@ class FlashCausalLM(Model):
cache_lengths=cache_lengths_tensor,
cu_seqlen_q=cu_seqlen_prefill,
)
block_num = cache_lengths_tensor // BLOCK_SIZE + 1
block_tables_valid = []
for i, bt in enumerate(block_tables.tolist()):
block_tables_valid.append(bt[0 : block_num[i]])
hpu_attention_meta = prepare_for_decode(
self.dtype,
self.use_contiguous_pa,
self.device,
slots,
block_tables,
block_tables_valid,
bs,
)