mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-29 22:12:07 +00:00
warmup decode
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent
ba7a131e04
commit
7900be5ac3
@ -1552,15 +1552,15 @@ class FlashCausalLM(Model):
|
|||||||
input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
|
input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
|
||||||
position_ids = torch.arange(bs, dtype=torch.int32, device=self.device)
|
position_ids = torch.arange(bs, dtype=torch.int32, device=self.device)
|
||||||
block_tables = torch.arange(
|
block_tables = torch.arange(
|
||||||
block_num, dtype=torch.int32, device=self.device
|
start=1, end=block_num + 1, dtype=torch.int32, device=self.device
|
||||||
).reshape(bs, -1)
|
).reshape(bs, -1)
|
||||||
slots = []
|
slots = []
|
||||||
past_len = (
|
past_len = (
|
||||||
len(block_tables[0]) * BLOCK_SIZE - 3
|
len(block_tables[0]) * BLOCK_SIZE - 1
|
||||||
) # for decode, we only need to pass the past token
|
) # for decode, we only need to pass the past token
|
||||||
# fetch the last blocked to warmup block num
|
# fetch the last blocked to warmup block num
|
||||||
for i in range(bs):
|
for i in range(bs):
|
||||||
slots.extend(BLOCK_SIZE * block_tables[i][-1] + BLOCK_SIZE - 2)
|
slots.append(BLOCK_SIZE * block_tables[i][-1] + BLOCK_SIZE - 1)
|
||||||
slots = torch.tensor(slots, dtype=torch.int64, device=self.device)
|
slots = torch.tensor(slots, dtype=torch.int64, device=self.device)
|
||||||
|
|
||||||
input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device)
|
input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device)
|
||||||
@ -1575,12 +1575,17 @@ class FlashCausalLM(Model):
|
|||||||
cache_lengths=cache_lengths_tensor,
|
cache_lengths=cache_lengths_tensor,
|
||||||
cu_seqlen_q=cu_seqlen_prefill,
|
cu_seqlen_q=cu_seqlen_prefill,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
block_num = cache_lengths_tensor // BLOCK_SIZE + 1
|
||||||
|
block_tables_valid = []
|
||||||
|
for i, bt in enumerate(block_tables.tolist()):
|
||||||
|
block_tables_valid.append(bt[0 : block_num[i]])
|
||||||
hpu_attention_meta = prepare_for_decode(
|
hpu_attention_meta = prepare_for_decode(
|
||||||
self.dtype,
|
self.dtype,
|
||||||
self.use_contiguous_pa,
|
self.use_contiguous_pa,
|
||||||
self.device,
|
self.device,
|
||||||
slots,
|
slots,
|
||||||
block_tables,
|
block_tables_valid,
|
||||||
bs,
|
bs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user