Mirror of https://github.com/huggingface/text-generation-inference.git
commit 7900be5ac3 (parent ba7a131e04)

    warmup decode

    Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
@@ -1552,15 +1552,15 @@ class FlashCausalLM(Model):
         input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
         position_ids = torch.arange(bs, dtype=torch.int32, device=self.device)
         block_tables = torch.arange(
-            block_num, dtype=torch.int32, device=self.device
+            start=1, end=block_num + 1, dtype=torch.int32, device=self.device
         ).reshape(bs, -1)
         slots = []
         past_len = (
-            len(block_tables[0]) * BLOCK_SIZE - 3
+            len(block_tables[0]) * BLOCK_SIZE - 1
         )  # for decode, we only need to pass the past token
         # fetch the last blocked to warmup block num
         for i in range(bs):
-            slots.extend(BLOCK_SIZE * block_tables[i][-1] + BLOCK_SIZE - 2)
+            slots.append(BLOCK_SIZE * block_tables[i][-1] + BLOCK_SIZE - 1)
         slots = torch.tensor(slots, dtype=torch.int64, device=self.device)
 
         input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device)
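The hunk above changes the decode warmup in three ways: block ids now start at 1 instead of 0, past_len now fills the last block up to its final slot, and slots collects exactly one slot index per sequence via append (extend over a 0-d tensor would fail). Below is a minimal standalone sketch of the new slot computation; BLOCK_SIZE, bs, and block_num are placeholder values chosen for illustration, not the values the server actually uses.

import torch

BLOCK_SIZE = 16  # placeholder page size; the server defines its own BLOCK_SIZE
bs = 2           # warmup batch size (placeholder)
block_num = 8    # total KV-cache blocks spread over the batch (placeholder)

# Block ids start at 1 so that block 0 is never referenced during warmup.
block_tables = torch.arange(
    start=1, end=block_num + 1, dtype=torch.int32
).reshape(bs, -1)

# Decode adds exactly one token per sequence, so the past occupies every slot
# of the sequence's blocks except the final one: BLOCK_SIZE - 1 tokens beyond
# the earlier full blocks.
past_len = len(block_tables[0]) * BLOCK_SIZE - 1

slots = []
for i in range(bs):
    # One slot per sequence: the last slot of the sequence's last block.
    # append() stores the scalar; extend() would try to iterate a 0-d tensor
    # and raise a TypeError, which is the bug this hunk fixes.
    slots.append(BLOCK_SIZE * block_tables[i][-1] + BLOCK_SIZE - 1)

slots = torch.tensor(slots, dtype=torch.int64)
print(slots)  # tensor([ 79, 143]) for the placeholder values above

With block ids starting at 1 and bs=2, the rows of block_tables are [1..4] and [5..8], so the warmed-up slots are 16*4 + 15 = 79 and 16*8 + 15 = 143.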
@@ -1575,12 +1575,17 @@ class FlashCausalLM(Model):
             cache_lengths=cache_lengths_tensor,
             cu_seqlen_q=cu_seqlen_prefill,
         )
 
+        block_num = cache_lengths_tensor // BLOCK_SIZE + 1
+        block_tables_valid = []
+        for i, bt in enumerate(block_tables.tolist()):
+            block_tables_valid.append(bt[0 : block_num[i]])
+
         hpu_attention_meta = prepare_for_decode(
             self.dtype,
             self.use_contiguous_pa,
             self.device,
             slots,
-            block_tables,
+            block_tables_valid,
             bs,
         )
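The second hunk stops passing the padded block_tables straight into prepare_for_decode and instead trims each row to the blocks the sequence actually occupies: a sequence with cache length c, plus the one token being decoded, needs c // BLOCK_SIZE + 1 blocks. A minimal sketch of that trimming, with placeholder cache lengths and table contents:

import torch

BLOCK_SIZE = 16  # placeholder page size

# Cached tokens per sequence (placeholder values); decode adds one token, so
# a sequence with cache length c needs c // BLOCK_SIZE + 1 blocks in total.
cache_lengths_tensor = torch.tensor([5, 40], dtype=torch.int32)
block_num = cache_lengths_tensor // BLOCK_SIZE + 1  # -> tensor([1, 3])

# Padded block table: every row is allocated to the widest sequence, and the
# trailing zeros are padding, not real blocks.
block_tables = torch.tensor(
    [[1, 0, 0, 0],
     [2, 3, 4, 0]],
    dtype=torch.int32,
)

block_tables_valid = []
for i, bt in enumerate(block_tables.tolist()):
    # Keep only the first block_num[i] entries; a 0-d int tensor is a valid
    # Python slice bound because it implements __index__.
    block_tables_valid.append(bt[0 : block_num[i]])

print(block_tables_valid)  # [[1], [2, 3, 4]]

Passing block_tables_valid rather than the padded table keeps prepare_for_decode from treating padding zeros as live KV-cache blocks.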