From ea4b739a9f9f903a13200ba97c073f2bbda11e56 Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Mon, 7 Oct 2024 17:12:31 +0200
Subject: [PATCH] fix prefill logprobs

---
 server/text_generation_server/models/flash_causal_lm.py | 9 ++++++---
 server/text_generation_server/server.py                 | 2 +-
 2 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 9a34dfc5..4e9f9c66 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -1757,6 +1757,7 @@ class FlashCausalLM(Model):
 
         finished_prefilling = True
         next_chunk_lengths = []
+        current_prefilling_mask = batch.prefilling_mask
         if prefill:
             if get_support_chunking():
                 next_prefilling_mask = []
@@ -1998,6 +1999,7 @@ class FlashCausalLM(Model):
             batch.next_token_chooser.do_sample,
             batch.next_token_chooser.seeds,
             batch.top_n_tokens,
+            current_prefilling_mask,
             batch.prefilling_mask,
             accepted_ids,
             batch_top_token_ids,
@@ -2021,7 +2023,8 @@ class FlashCausalLM(Model):
             do_sample,
             seed,
             top_n_tokens,
-            request_prefilling,
+            request_was_prefilling,
+            request_is_prefilling,
             n_accepted_ids,
             top_token_ids,
             top_token_logprobs,
@@ -2032,7 +2035,7 @@ class FlashCausalLM(Model):
             # this state to be stable
             if request.id % self.world_size == self.rank:
                 # Prefill
-                if request_prefilling and request.prefill_logprobs:
+                if request_was_prefilling and request.prefill_logprobs:
                     out_start_index = batch.prefill_cu_outlens[i]
                     out_end_index = batch.prefill_cu_outlens[i + 1]
 
@@ -2072,7 +2075,7 @@ class FlashCausalLM(Model):
                 batch.prefill_logprob_tokens[i] = None
 
             # If it is, the tokens we decoded should be ignored
-            if request_prefilling:
+            if request_is_prefilling:
                 # Make sure that we do not stop as even though this request did not create a token, it is still
                 # processing
                 stopped = False
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index da85d19d..aef00fb5 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -165,7 +165,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
                     f"Batch ID {request.cached_batch.id} not found in cache."
                 )
                 start_concat = time.time_ns()
-                batch = self.model.batch_type.concatenate([batch, cached_batch])
+                batch = self.model.batch_type.concatenate([cached_batch, batch])
                 concat_ns = time.time_ns() - start_concat
 
         generations, next_batch, timings = self.model.generate_token(batch)
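
Note on the change: `batch.prefilling_mask` is rewritten inside the same `generate_token` pass that later iterates over requests, so the patch snapshots the pre-update value (`current_prefilling_mask`) and passes both to the per-request loop. The old value (`request_was_prefilling`) decides whether prefill logprobs should be emitted for this step, while the new value (`request_is_prefilling`) decides whether the request is still prefilling and must not be stopped. Below is a minimal, standalone Python sketch of that snapshot-before-mutate pattern; `Batch`, `step`, and `chunks_left` are hypothetical names for illustration only, not TGI code.

    # Minimal sketch (not TGI code) of the snapshot-before-mutate pattern used
    # in the patch above. All names here are hypothetical.
    from dataclasses import dataclass
    from typing import List


    @dataclass
    class Batch:
        # True while a request still has prompt chunks left to prefill.
        prefilling_mask: List[bool]


    def step(batch: Batch, chunks_left: List[int]) -> List[str]:
        # Snapshot the mask *before* it is rewritten for the next scheduling
        # step, mirroring `current_prefilling_mask = batch.prefilling_mask`.
        was_prefilling = batch.prefilling_mask
        # The batch is then updated in place for the next step.
        batch.prefilling_mask = [n > 0 for n in chunks_left]
        is_prefilling = batch.prefilling_mask

        actions = []
        for was, still in zip(was_prefilling, is_prefilling):
            if was:
                # The request ran (part of) its prefill in this step: its
                # prefill logprobs are available now and should be returned.
                actions.append("emit prefill logprobs")
            if still:
                # The request still has prompt left: any token decoded in this
                # step is ignored and the request must not be stopped.
                actions.append("keep request alive, drop decoded token")
        return actions


    # Request 0 finished its prefill this step; request 1 still has a chunk left.
    batch = Batch(prefilling_mask=[True, True])
    print(step(batch, chunks_left=[0, 1]))

The server.py hunk only swaps the concatenation order so that the cached (already running) batch's requests come before the newly prefilled ones in the combined batch.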