mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-05-12 13:02:12 +00:00
add warmup_decode
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent
fd70ad703e
commit
ba7a131e04
@ -1494,6 +1494,10 @@ class FlashCausalLM(Model):
|
||||
for seqlen in [32, 64, 128, 256, 512, 1024]:
|
||||
self.warmup_prefill(seqlen, bs)
|
||||
|
||||
for bs in [1, 2, 4, 8]:
|
||||
for block_num in [1, 2, 4, 8, 16]:
|
||||
self.warmup_decode(bs, block_num * bs)
|
||||
|
||||
return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
|
||||
|
||||
def warmup_prefill(self, prompt_len: int, bs: int):
|
||||
@ -1544,6 +1548,57 @@ class FlashCausalLM(Model):
|
||||
hpu_attention_meta=None,
|
||||
)
|
||||
|
||||
def warmup_decode(self, bs: int, block_num: int):
|
||||
input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
|
||||
position_ids = torch.arange(bs, dtype=torch.int32, device=self.device)
|
||||
block_tables = torch.arange(
|
||||
block_num, dtype=torch.int32, device=self.device
|
||||
).reshape(bs, -1)
|
||||
slots = []
|
||||
past_len = (
|
||||
len(block_tables[0]) * BLOCK_SIZE - 3
|
||||
) # for decode, we only need to pass the past token
|
||||
# fetch the last blocked to warmup block num
|
||||
for i in range(bs):
|
||||
slots.extend(BLOCK_SIZE * block_tables[i][-1] + BLOCK_SIZE - 2)
|
||||
slots = torch.tensor(slots, dtype=torch.int64, device=self.device)
|
||||
|
||||
input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device)
|
||||
cache_lengths_tensor = (
|
||||
torch.ones(bs, dtype=torch.int32, device=self.device) * past_len
|
||||
)
|
||||
cu_seqlen_prefill = torch.zeros(bs + 1, device=self.device, dtype=torch.int32)
|
||||
torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:])
|
||||
|
||||
seqlen = Seqlen(
|
||||
input_lengths=input_lengths,
|
||||
cache_lengths=cache_lengths_tensor,
|
||||
cu_seqlen_q=cu_seqlen_prefill,
|
||||
)
|
||||
hpu_attention_meta = prepare_for_decode(
|
||||
self.dtype,
|
||||
self.use_contiguous_pa,
|
||||
self.device,
|
||||
slots,
|
||||
block_tables,
|
||||
bs,
|
||||
)
|
||||
|
||||
# We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
|
||||
self.model.forward(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
cu_seqlen_prefill=None,
|
||||
kv_cache=self.kv_cache,
|
||||
block_tables=None, # block_table is not used in hpu pageattn. remove it to avoid shape change in hpu graph
|
||||
slots=slots,
|
||||
seqlen=trim_seqlen_metadata(seqlen),
|
||||
prefill_cache_indices=None,
|
||||
lm_head_indices=None,
|
||||
adapter_data=None,
|
||||
hpu_attention_meta=hpu_attention_meta,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
|
Loading…
Reference in New Issue
Block a user