From ba049c9d49465a362d5ecfa4a58366bcc635696d Mon Sep 17 00:00:00 2001
From: "Wang, Yi A"
Date: Fri, 11 Apr 2025 06:10:17 -0700
Subject: [PATCH] improve performance

Signed-off-by: Wang, Yi A
---
 .../models/flash_causal_lm.py                 | 52 ++++++++++++-------
 .../models/flash_vlm_causal_lm.py             |  4 +-
 .../models/mllama_causal_lm.py                |  4 +-
 3 files changed, 36 insertions(+), 24 deletions(-)

diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
index d4ff3f707..5c7b8bc01 100644
--- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
@@ -89,7 +89,7 @@ def get_sliding_windows() -> int:
 
 
 def prepare_for_decode(
-    dtype, use_contiguous_pa, device, slot, block_tables, batch_size, bucketing_ctx
+    dtype, use_contiguous_pa, device, slots, block_tables, batch_size, bucketing_ctx
 ):
     # Prepare values if we need to continue decoding
     # need for HPUPagedAttentionMetadata preparation
@@ -105,7 +105,7 @@ def prepare_for_decode(
         padding = target_len - input_len
         return input + [v] * padding
 
-    last_block_usage = slot % BLOCK_SIZE + 1
+    last_block_usage = [slot % BLOCK_SIZE + 1 for slot in slots]
     block_groups = [[i] * len(bt) for i, bt in enumerate(block_tables)]
     block_usage = [
         [BLOCK_SIZE] * (len(bt) - 1) + [lbu]
@@ -964,7 +964,7 @@ class FlashCausalLMBatch(Batch):
         )
 
     def prepare_for_decode(self, dtype, use_contiguous_pa, bucketing_ctx):
-        block_num = self.cache_lengths_tensor // BLOCK_SIZE + 1
+        block_num = [length // BLOCK_SIZE + 1 for length in self.cache_lengths]
         block_tables = []
         for i, bt in enumerate(self.block_tables):
             block_tables.append(bt[0 : block_num[i]])
@@ -984,7 +984,7 @@ class FlashCausalLMBatch(Batch):
             dtype,
             use_contiguous_pa,
             self.block_tables_tensor.device,
-            slots,
+            slots.cpu(),
             block_tables,
             padded_bs,
             bucketing_ctx,
         )
@@ -1616,7 +1616,6 @@ class FlashCausalLM(Model):
             block_tables.append(block_array)
             past_len.append(blocks[i] * BLOCK_SIZE - 1)
             start_idx += blocks[i]
-        slots = torch.tensor(slots, dtype=batch.slots.dtype, device=self.device)
         input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device)
         cache_lengths_tensor = torch.tensor(
             past_len, dtype=torch.int32, device=self.device
         )
@@ -1641,13 +1640,14 @@ class FlashCausalLM(Model):
             batch_size,
             bucketing_ctx=None,
         )
+        slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype, device=self.device)
         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
         self.model.forward(
             input_ids=input_ids,
             position_ids=position_ids,
             cu_seqlen_prefill=None,
             kv_cache=self.kv_cache,
-            slots=slots,
+            slots=slots_tensor,
             seqlen=trim_seqlen_metadata(seqlen),
             lm_head_indices=None,
             adapter_data=None,
@@ -1866,8 +1866,8 @@ class FlashCausalLM(Model):
             for i in range(len(batch)):
                 batch.all_input_ids_tensor[
                     i,
-                    batch.cache_lengths_tensor[i]
-                    + batch.input_lengths[i] : batch.cache_lengths_tensor[i]
+                    batch.cache_lengths[i]
+                    + batch.input_lengths[i] : batch.cache_lengths[i]
                     + batch.input_lengths[i]
                     + accepted_ids[i],
                 ] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]]
@@ -1915,14 +1915,36 @@ class FlashCausalLM(Model):
                 }
             )
             idx = len(prev_batches) - 1
+        if batch.speculative_logits is not None:
+            accepted_ids_cpu = accepted_ids.cpu()
 
         for req_idx, req in enumerate(batch.requests):
+            new_input_length = 1
+            if batch.speculative_logits is not None:
+                new_cache_length = (
+                    batch.cache_lengths[req_idx]
+                    + batch.input_lengths[req_idx]
+                    + accepted_ids_cpu[req_idx]
+                    - 1
+                )
+            else:
+                new_cache_length = (
+                    batch.cache_lengths[req_idx] + batch.input_lengths[req_idx]
+                )
+            batch.cache_lengths[req_idx] = new_cache_length
+            batch.max_input_length = max(
+                batch.max_input_length, new_input_length
+            )
+            batch.input_lengths[req_idx] = new_input_length
+            current_length = new_cache_length + new_input_length
+            batch.max_current_length = max(
+                batch.max_current_length, current_length
+            )
+
             requests_to_generate.append(
                 {
                     "idx": idx,
                     "request_id": req.id,
-                    "cache_length": batch.cache_lengths[req_idx],
-                    "input_length": batch.input_lengths[req_idx],
                     "prefix_offset": batch.prefix_offsets[req_idx],
                     "read_offset": batch.read_offsets[req_idx],
                     "stopping_criteria": batch.stopping_criterias[req_idx],
@@ -2029,8 +2051,6 @@ class FlashCausalLM(Model):
         for i, req_data in enumerate(requests_to_generate):
             idx = req_data["idx"]
             request_id = req_data["request_id"]
-            cache_length = req_data["cache_length"]
-            input_length = req_data["input_length"]
             prefix_offset = req_data["prefix_offset"]
             read_offset = req_data["read_offset"]
             stopping_criteria = req_data["stopping_criteria"]
@@ -2041,9 +2061,6 @@ class FlashCausalLM(Model):
             n_accepted_ids = prev_batches[idx]["accepted_ids"][idx_accept_ids[idx]]
             top_token_ids = req_data["top_token_ids"]
             top_token_logprobs = req_data["top_token_logprobs"]
-
-            new_input_length = 1
-            new_cache_length = cache_length + input_length + n_accepted_ids - 1
             # Append next token to all tokens
             next_token_texts = []
             left = 0
@@ -2159,11 +2176,6 @@ class FlashCausalLM(Model):
             # Update values
             indexs[idx] += n_accepted_ids
             idx_accept_ids[idx] += 1
-            batch.cache_lengths[i] = new_cache_length
-            batch.max_input_length = max(batch.max_input_length, new_input_length)
-            batch.input_lengths[i] = new_input_length
-            current_length = new_cache_length + new_input_length
-            batch.max_current_length = max(batch.max_current_length, current_length)
             batch.prefix_offsets[i] = prefix_offset
             batch.read_offsets[i] = read_offset
 
diff --git a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
index cdda751aa..c885816b8 100644
--- a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
@@ -408,7 +408,6 @@ class FlashVlmCausalLM(FlashCausalLM):
             block_tables.append(block_array)
             past_len.append(blocks[i] * BLOCK_SIZE - 1)
             start_idx += blocks[i]
-        slots = torch.tensor(slots, dtype=batch.slots.dtype, device=self.device)
         input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device)
         cache_lengths_tensor = torch.tensor(
             past_len, dtype=torch.int32, device=self.device
         )
@@ -433,13 +432,14 @@ class FlashVlmCausalLM(FlashCausalLM):
             batch_size,
             bucketing_ctx=None,
         )
+        slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype, device=self.device)
         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
         self.model.forward(
             input_ids=input_ids,
             position_ids=position_ids,
             cu_seqlen_prefill=None,
             kv_cache=self.kv_cache,
-            slots=slots,
+            slots=slots_tensor,
             seqlen=trim_seqlen_metadata(seqlen),
             hpu_attention_meta=hpu_attention_meta,
             lm_head_indices=None,
diff --git a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
index d21cc39d2..6a0661851 100644
--- a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
@@ -247,7 +247,6 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
             block_tables.append(block_array)
             past_len.append(blocks[i] * BLOCK_SIZE - 1)
             start_idx += blocks[i]
-        slots = torch.tensor(slots, dtype=batch.slots.dtype, device=self.device)
         input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device)
         cache_lengths_tensor = torch.tensor(
             past_len, dtype=torch.int32, device=self.device
         )
@@ -279,12 +278,13 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
         indices, cross_attention_len = generate_cross_attention_states(
             cross_attention_states, image_indices, seqlen, 1, False
         )
+        slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype, device=self.device)
        self.model.forward(
             input_ids=input_ids,
             position_ids=position_ids,
             cu_seqlen_prefill=None,
             kv_cache=self.kv_cache,
-            slots=slots,
+            slots=slots_tensor,
             seqlen=trim_seqlen_metadata(seqlen),
             hpu_attention_meta=hpu_attention_meta,
             lm_head_indices=None,
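
Note below the patch (an illustrative sketch, not part of the change itself): the recurring pattern here is to keep per-request length bookkeeping in plain Python lists (cache_lengths, input_lengths, block_num) and to materialize a device tensor such as slots_tensor only once, right before model.forward, rather than doing the arithmetic on device tensors and indexing them from Python, which costs one host/device synchronization per element on HPU. The accepted_ids.cpu() call follows the same idea, batching a single transfer before the request loop instead of syncing per request. A minimal, runnable sketch with made-up values and a CPU tensor standing in for the HPU one:

    import torch

    BLOCK_SIZE = 128
    cache_lengths = [100, 260, 515]  # host-side ints, as the patch keeps them
    block_tables = [[0, 1], [2, 3, 4], [5, 6, 7, 8, 9]]

    # Old style: arithmetic on a device tensor, then per-element reads from
    # Python. Each int(n) copies a scalar back from the device (a sync point
    # on an accelerator; harmless on CPU, shown here only for illustration).
    cache_lengths_tensor = torch.tensor(cache_lengths, dtype=torch.int32)
    block_num = cache_lengths_tensor // BLOCK_SIZE + 1
    trimmed_old = [bt[: int(n)] for bt, n in zip(block_tables, block_num)]

    # New style: the same result from pure Python arithmetic, no device
    # round-trips; a tensor is built once at the end only where a kernel
    # actually needs one.
    block_num = [length // BLOCK_SIZE + 1 for length in cache_lengths]
    trimmed_new = [bt[:n] for bt, n in zip(block_tables, block_num)]

    assert trimmed_old == trimmed_new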