From d361197aab814aaa7e24b469b607b1c78090be72 Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Wed, 9 Oct 2024 20:04:06 +0200
Subject: [PATCH] omfg

---
 .../models/flash_causal_lm.py                 | 28 +++++++++++++------
 1 file changed, 19 insertions(+), 9 deletions(-)

diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 7ebe3dea..98de8c79 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -956,11 +956,13 @@ class FlashCausalLMBatch(Batch):
             no_prefill_logprobs = no_prefill_logprobs and not prefill_logprobs
 
             if prefill_logprobs:
-                prefill_head_indices.append(torch.arange(
-                    cumulative_length,
-                    cumulative_length + input_length,
-                    dtype=torch.int64
-                ))
+                prefill_head_indices.append(
+                    torch.arange(
+                        cumulative_length,
+                        cumulative_length + input_length,
+                        dtype=torch.int64,
+                    )
+                )
                 prefill_next_token_indices.append(
                     prefill_out_cumulative_length + input_length - 1
                 )
@@ -1875,9 +1877,11 @@ class FlashCausalLM(Model):
 
                 # Logprobs generated by the model are for the next token
                 # So we need to translate the id tensor by 1
-                ids = batch.all_input_ids_tensor[i, cache_length + 1: cache_length + input_length + 1]
+                ids = batch.all_input_ids_tensor[
+                    i, cache_length + 1 : cache_length + input_length + 1
+                ]
                 if len(batch) > 1:
-                    prefill_tokens_indices[out_start_index : out_end_index] = ids
+                    prefill_tokens_indices[out_start_index:out_end_index] = ids
                 else:
                     # Set prefill_tokens_indices to the correct slice
                     prefill_tokens_indices = ids
@@ -2014,6 +2018,12 @@ class FlashCausalLM(Model):
             top_token_ids,
             top_token_logprobs,
         ) in enumerate(iterator):
+            if all_input_ids[:2] == [1986, 374] and not request_is_prefilling:
+                log_master(
+                    logger.info,
+                    f"{request.id} {next_token_ids} {self.tokenizer.batch_decode(next_token_ids)}",
+                )
+
             # Compute logprobs first as, even though we might skip the token,
             # it can still be required to compute the logprobs
             # modulo on request.id as it is robust to batch.filter whereas the index in the batch is not and we need
@@ -2046,7 +2056,7 @@ class FlashCausalLM(Model):
                         cache_length + 1
                     ) + request_prefill_logprobs
                     prefill_token_ids = (
-                        all_input_ids[:cache_length + 1] + prefill_token_ids
+                        all_input_ids[: cache_length + 1] + prefill_token_ids
                     )
 
                     prefill_texts = self.tokenizer.batch_decode(
@@ -2114,7 +2124,6 @@ class FlashCausalLM(Model):
                 _next_token_logprobs = next_token_logprobs[
                     index : index + n_accepted_ids - left
                 ]
-                index += n_accepted_ids
 
                 # Shard generations
                 # All generations will be appended in the rust sharded client
@@ -2189,6 +2198,7 @@ class FlashCausalLM(Model):
                     )
 
             # Update values
+            index += n_accepted_ids
             current_cache_length = cache_length + input_length
             batch.cache_lengths[i] = current_cache_length
             current_input_length = new_input_length