diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
index cb879c9c..e032242c 100644
--- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
@@ -1552,15 +1552,15 @@ class FlashCausalLM(Model):
         input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
         position_ids = torch.arange(bs, dtype=torch.int32, device=self.device)
         block_tables = torch.arange(
-            block_num, dtype=torch.int32, device=self.device
+            start=1, end=block_num + 1, dtype=torch.int32, device=self.device
         ).reshape(bs, -1)
         slots = []
         past_len = (
-            len(block_tables[0]) * BLOCK_SIZE - 3
+            len(block_tables[0]) * BLOCK_SIZE - 1
         )  # for decode, we only need to pass the past token
         # fetch the last blocked to warmup block num
         for i in range(bs):
-            slots.extend(BLOCK_SIZE * block_tables[i][-1] + BLOCK_SIZE - 2)
+            slots.append(BLOCK_SIZE * block_tables[i][-1] + BLOCK_SIZE - 1)
         slots = torch.tensor(slots, dtype=torch.int64, device=self.device)

         input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device)
@@ -1575,12 +1575,17 @@ class FlashCausalLM(Model):
             cache_lengths=cache_lengths_tensor,
             cu_seqlen_q=cu_seqlen_prefill,
         )
+
+        block_num = cache_lengths_tensor // BLOCK_SIZE + 1
+        block_tables_valid = []
+        for i, bt in enumerate(block_tables.tolist()):
+            block_tables_valid.append(bt[0 : block_num[i]])
         hpu_attention_meta = prepare_for_decode(
             self.dtype,
             self.use_contiguous_pa,
             self.device,
             slots,
-            block_tables,
+            block_tables_valid,
             bs,
         )