Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-05-12)
commit ba7a131e04 (parent fd70ad703e)

    add warmup_decode

    Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
@@ -1494,6 +1494,10 @@ class FlashCausalLM(Model):
             for seqlen in [32, 64, 128, 256, 512, 1024]:
                 self.warmup_prefill(seqlen, bs)
 
+        for bs in [1, 2, 4, 8]:
+            for block_num in [1, 2, 4, 8, 16]:
+                self.warmup_decode(bs, block_num * bs)
+
         return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
 
     def warmup_prefill(self, prompt_len: int, bs: int):
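For context, the loop added above walks a fixed grid of decode shapes so the HPU graph cache is populated before serving begins. A minimal sketch of the combinations it enumerates (plain Python; the printout is illustrative, not part of the diff):

# The decode warmup covers 4 batch sizes x 5 block multiples = 20 graph shapes.
# block_num is multiplied by bs so each sequence in the batch gets block_num blocks.
for bs in [1, 2, 4, 8]:
    for block_num in [1, 2, 4, 8, 16]:
        total_blocks = block_num * bs
        print(f"warmup_decode(bs={bs}, block_num={total_blocks})")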
@@ -1544,6 +1548,57 @@ class FlashCausalLM(Model):
             hpu_attention_meta=None,
         )
 
+    def warmup_decode(self, bs: int, block_num: int):
+        input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
+        position_ids = torch.arange(bs, dtype=torch.int32, device=self.device)
+        block_tables = torch.arange(
+            block_num, dtype=torch.int32, device=self.device
+        ).reshape(bs, -1)
+        slots = []
+        past_len = (
+            len(block_tables[0]) * BLOCK_SIZE - 3
+        )  # for decode, only the number of already-cached tokens matters
+        # use the last block of each sequence to warm up the full block count
+        for i in range(bs):
+            slots.append(BLOCK_SIZE * block_tables[i][-1] + BLOCK_SIZE - 2)
+        slots = torch.tensor(slots, dtype=torch.int64, device=self.device)
+
+        input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device)
+        cache_lengths_tensor = (
+            torch.ones(bs, dtype=torch.int32, device=self.device) * past_len
+        )
+        cu_seqlen_prefill = torch.zeros(bs + 1, device=self.device, dtype=torch.int32)
+        torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:])
+
+        seqlen = Seqlen(
+            input_lengths=input_lengths,
+            cache_lengths=cache_lengths_tensor,
+            cu_seqlen_q=cu_seqlen_prefill,
+        )
+        hpu_attention_meta = prepare_for_decode(
+            self.dtype,
+            self.use_contiguous_pa,
+            self.device,
+            slots,
+            block_tables,
+            bs,
+        )
+
+        # decode passes cu_seqlen_prefill=None, so this warms up the paged-attention decode path
+        self.model.forward(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            cu_seqlen_prefill=None,
+            kv_cache=self.kv_cache,
+            block_tables=None,  # block_table is not used in hpu pageattn. remove it to avoid shape change in hpu graph
+            slots=slots,
+            seqlen=trim_seqlen_metadata(seqlen),
+            prefill_cache_indices=None,
+            lm_head_indices=None,
+            adapter_data=None,
+            hpu_attention_meta=hpu_attention_meta,
+        )
+
     def forward(
         self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
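A quick worked example of the slot and past_len arithmetic in warmup_decode, for bs=2 and block_num=4 (BLOCK_SIZE = 64 is assumed here purely for illustration; the real constant comes from the backend's attention module):

# Sketch of the values warmup_decode(bs=2, block_num=4) would build.
BLOCK_SIZE = 64  # assumed for illustration

# block_tables: torch.arange(4).reshape(2, -1) -> [[0, 1], [2, 3]]
block_tables = [[0, 1], [2, 3]]

# the cache is treated as holding all but the last few positions of each sequence
past_len = len(block_tables[0]) * BLOCK_SIZE - 3  # 2 * 64 - 3 = 125

# one slot per sequence, near the end of its last block
slots = [BLOCK_SIZE * row[-1] + BLOCK_SIZE - 2 for row in block_tables]
print(past_len)  # 125
print(slots)     # [126, 254]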