From 57f55fe8346bf9dbcf2c598eb610b2966249cd07 Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Wed, 9 Oct 2024 19:17:18 +0200
Subject: [PATCH] idk at this point

---
 .../models/flash_causal_lm.py | 63 +++++++++----------
 1 file changed, 30 insertions(+), 33 deletions(-)

diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index b7202c04..05bad924 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -758,11 +758,12 @@ class FlashCausalLMBatch(Batch):
 
             input_ids[start_index:end_index] = batch.input_ids
             position_ids[start_index:end_index] = batch.position_ids
+            slots[slots_start_index:slots_end_index] = batch.slots
             slot_indices[start_index:end_index] = (
                 batch.slot_indices + cumulative_slots
             )
             input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
-            slots[slots_start_index:slots_end_index] = batch.slots
+            cache_lengths_tensor[start_index:end_index] = batch.cache_lengths_tensor
 
             # Copy over adapter indices
             adapter_start_index = cumulative_adapter_indices_size
@@ -779,7 +780,6 @@ class FlashCausalLMBatch(Batch):
                 batch.adapter_meta.adapter_segments,
                 batch.adapter_meta.segment_indices,
             )
-            cache_lengths_tensor[start_index:end_index] = batch.cache_lengths_tensor
 
             # Update
             cumulative_slots += len(batch.slots)
@@ -1614,13 +1614,12 @@ class FlashCausalLM(Model):
                 input_lengths_tensor=input_lengths,
                 cache_lengths_tensor=cache_lengths_tensor,
             ):
-                max_k = (input_lengths + cache_lengths_tensor).max().item()
                 seqlen = Seqlen(
                     input_lengths=input_lengths,
                     cache_lengths=cache_lengths_tensor,
                     cu_seqlen_q=cu_seqlen_prefill,
                     max_q=max_s,
-                    max_k=max_k,
+                    max_k=batch.max_current_length,
                 )
                 logits, speculative_logits = self.model.forward(
                     input_ids=input_ids,
@@ -1852,46 +1851,44 @@ class FlashCausalLM(Model):
             request_was_prefilling,
             request_is_prefilling,
         ) in enumerate(iterator):
-            # Indexing metadata
-            _start_index = cumulative_length
-            end_index = cumulative_length + input_length
+            if prefill and finished_prefilling:
+                # Indexing metadata
+                _start_index = cumulative_length
+                end_index = cumulative_length + input_length
 
-            if prefill:
+                # Initialize position_ids
+                # In decode, we do not need this as we can just increment position ids
+                next_position_ids[i] = batch.position_ids[end_index - 1]
+
+                # Initialize adapter indices
+                # In decode, we only have one token per row in the batch, so grab last index
+                next_adapter_indices[i] = batch.adapter_meta.adapter_indices[
+                    end_index - 1
+                ]
+
+            # Used to gather prefill logprobs
+            # Copy batch.all_input_ids_tensor to prefill_token_indices
+            if request.prefill_logprobs and request_was_prefilling:
                 # Indexing metadata
                 out_start_index = batch.prefill_cu_outlens[i]
                 out_end_index = batch.prefill_cu_outlens[i + 1]
 
-                if finished_prefilling:
-                    # Initialize position_ids
-                    # In decode, we do not need this as we can just increment position ids
-                    next_position_ids[i] = batch.position_ids[end_index - 1]
-
-                    # Initialize adapter indices
-                    # In decode, we only have one token per row in the batch, so grab last index
-                    next_adapter_indices[i] = batch.adapter_meta.adapter_indices[
-                        end_index - 1
-                    ]
-
-                # Used to gather prefill logprobs
-                # Copy batch.all_input_ids_tensor to prefill_token_indices
-                if request.prefill_logprobs and request_was_prefilling:
-                    # Logprobs generated by the model are for the next token
-                    # So we need to translate the id tensor by 1
-                    ids = batch.all_input_ids_tensor[i, cache_length + 1: cache_length + input_length + 1]
-                    if len(batch) > 1:
-                        prefill_tokens_indices[out_start_index : out_end_index] = ids
-                    else:
-                        # Set prefill_tokens_indices to the correct slice
-                        prefill_tokens_indices = ids
+                # Logprobs generated by the model are for the next token
+                # So we need to translate the id tensor by 1
+                ids = batch.all_input_ids_tensor[i, cache_length + 1: cache_length + input_length + 1]
+                if len(batch) > 1:
+                    prefill_tokens_indices[out_start_index : out_end_index] = ids
+                else:
+                    # Set prefill_tokens_indices to the correct slice
+                    prefill_tokens_indices = ids
 
             if not request_is_prefilling:
                 # Only save tokens if we are done prefilling for this request
                 for j in range(n_accepted_ids):
                     batch.all_input_ids_tensor[i, cache_length + input_length + j] = (
-                        next_input_ids[index]
+                        next_input_ids[index + j]
                     )
-                index += 1
-
+                index += n_accepted_ids
             cumulative_length += input_length
 
         # Update values
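
Note on the last hunk: with speculative decoding a request can accept several ids in one step, so the flat next_input_ids buffer has to be read at index + j and index advanced by n_accepted_ids rather than by 1. The sketch below is a standalone illustration of that bookkeeping, not the TGI implementation; the helper name save_accepted_tokens and its simplified arguments are invented for the example.

# Standalone sketch (not TGI code) of the accepted-token bookkeeping fixed in
# the last hunk: each request may accept several ids per step, so the flat
# buffer is read at `index + j` and `index` advances by `n_accepted`, not by 1.
import torch


def save_accepted_tokens(all_input_ids, next_input_ids, n_accepted_ids, seq_lens):
    # all_input_ids: [batch, max_len] per-request token history
    # next_input_ids: flat buffer of accepted ids for the whole batch
    # n_accepted_ids / seq_lens: accepted count and current length per request
    index = 0
    for i, (n_accepted, seq_len) in enumerate(zip(n_accepted_ids, seq_lens)):
        for j in range(n_accepted):
            # Append each accepted token after the tokens already stored.
            all_input_ids[i, seq_len + j] = next_input_ids[index + j]
        # Advance by everything this request consumed, not by one.
        index += n_accepted


if __name__ == "__main__":
    all_ids = torch.zeros((2, 8), dtype=torch.long)
    flat = torch.tensor([11, 12, 13, 21])  # request 0 accepted 3 ids, request 1 accepted 1
    save_accepted_tokens(all_ids, flat, n_accepted_ids=[3, 1], seq_lens=[2, 4])
    print(all_ids)
    # row 0 -> [0, 0, 11, 12, 13, 0, 0, 0]; row 1 -> [0, 0, 0, 0, 21, 0, 0, 0]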