From ba7a131e041519326767953dae5d956022423593 Mon Sep 17 00:00:00 2001
From: "Wang, Yi A"
Date: Wed, 26 Mar 2025 17:39:26 -0700
Subject: [PATCH] add warmup_decode

Signed-off-by: Wang, Yi A
---
 .../models/flash_causal_lm.py | 55 +++++++++++++++++++
 1 file changed, 55 insertions(+)

diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
index b26184e4..cb879c9c 100644
--- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
@@ -1494,6 +1494,10 @@ class FlashCausalLM(Model):
         for seqlen in [32, 64, 128, 256, 512, 1024]:
             self.warmup_prefill(seqlen, bs)
 
+        for bs in [1, 2, 4, 8]:
+            for block_num in [1, 2, 4, 8, 16]:
+                self.warmup_decode(bs, block_num * bs)
+
         return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
 
     def warmup_prefill(self, prompt_len: int, bs: int):
@@ -1544,6 +1548,57 @@ class FlashCausalLM(Model):
             hpu_attention_meta=None,
         )
 
+    def warmup_decode(self, bs: int, block_num: int):
+        input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
+        position_ids = torch.arange(bs, dtype=torch.int32, device=self.device)
+        block_tables = torch.arange(
+            block_num, dtype=torch.int32, device=self.device
+        ).reshape(bs, -1)
+        slots = []
+        past_len = (
+            len(block_tables[0]) * BLOCK_SIZE - 3
+        )  # for decode, only the cached (past) tokens matter; one new token is decoded per sequence
+        # use the last block of each sequence so the full block count is warmed up
+        for i in range(bs):
+            slots.append(BLOCK_SIZE * block_tables[i][-1] + BLOCK_SIZE - 2)
+        slots = torch.tensor(slots, dtype=torch.int64, device=self.device)
+
+        input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device)
+        cache_lengths_tensor = (
+            torch.ones(bs, dtype=torch.int32, device=self.device) * past_len
+        )
+        cu_seqlen_prefill = torch.zeros(bs + 1, device=self.device, dtype=torch.int32)
+        torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:])
+
+        seqlen = Seqlen(
+            input_lengths=input_lengths,
+            cache_lengths=cache_lengths_tensor,
+            cu_seqlen_q=cu_seqlen_prefill,
+        )
+        hpu_attention_meta = prepare_for_decode(
+            self.dtype,
+            self.use_contiguous_pa,
+            self.device,
+            slots,
+            block_tables,
+            bs,
+        )
+
+        # `cu_seqlen_prefill` is None for decode; paged attention metadata is passed via `hpu_attention_meta`.
+        self.model.forward(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            cu_seqlen_prefill=None,
+            kv_cache=self.kv_cache,
+            block_tables=None,  # block_tables is not used in HPU paged attention; pass None to avoid shape changes in the HPU graph
+            slots=slots,
+            seqlen=trim_seqlen_metadata(seqlen),
+            prefill_cache_indices=None,
+            lm_head_indices=None,
+            adapter_data=None,
+            hpu_attention_meta=hpu_attention_meta,
+        )
+
     def forward(
         self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
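
For reference, below is a minimal standalone sketch of the dummy decode shapes that warmup_decode builds for each (batch size, block count) bucket. It only reproduces the tensor bookkeeping visible in the patch, not the model forward call or the HPU attention metadata. The BLOCK_SIZE value of 128 and the helper name decode_warmup_shapes are assumptions for illustration; the real block size comes from the Gaudi backend.

import torch

BLOCK_SIZE = 128  # assumption: paged-attention block size; the backend defines the real value

def decode_warmup_shapes(bs: int, total_blocks: int, device: str = "cpu"):
    # one new token per sequence in decode
    input_ids = torch.zeros(bs, dtype=torch.int64, device=device)
    position_ids = torch.arange(bs, dtype=torch.int32, device=device)

    # split the blocks evenly across the batch: shape (bs, total_blocks // bs)
    block_tables = torch.arange(
        total_blocks, dtype=torch.int32, device=device
    ).reshape(bs, -1)

    # the new token lands near the end of each sequence's last block
    slots = torch.tensor(
        [int(BLOCK_SIZE * block_tables[i][-1] + BLOCK_SIZE - 2) for i in range(bs)],
        dtype=torch.int64,
        device=device,
    )

    # decode processes 1 token per sequence; the rest is treated as cached context
    past_len = block_tables.shape[1] * BLOCK_SIZE - 3
    input_lengths = torch.ones(bs, dtype=torch.int32, device=device)
    cache_lengths = torch.full((bs,), past_len, dtype=torch.int32, device=device)
    return input_ids, position_ids, block_tables, slots, input_lengths, cache_lengths

# same (bs, total block count) grid as the warmup loop added by the patch
for bs in [1, 2, 4, 8]:
    for block_num in [1, 2, 4, 8, 16]:
        decode_warmup_shapes(bs, block_num * bs)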