fix prefill logprobs

OlivierDehaene 2024-10-07 17:12:31 +02:00
parent 3924b87a04
commit ea4b739a9f
2 changed files with 7 additions and 4 deletions


@@ -1757,6 +1757,7 @@ class FlashCausalLM(Model):
         finished_prefilling = True
         next_chunk_lengths = []
+        current_prefilling_mask = batch.prefilling_mask
         if prefill:
             if get_support_chunking():
                 next_prefilling_mask = []
@@ -1998,6 +1999,7 @@ class FlashCausalLM(Model):
             batch.next_token_chooser.do_sample,
             batch.next_token_chooser.seeds,
             batch.top_n_tokens,
+            current_prefilling_mask,
             batch.prefilling_mask,
             accepted_ids,
             batch_top_token_ids,
@@ -2021,7 +2023,8 @@ class FlashCausalLM(Model):
             do_sample,
             seed,
             top_n_tokens,
-            request_prefilling,
+            request_was_prefilling,
+            request_is_prefilling,
             n_accepted_ids,
             top_token_ids,
             top_token_logprobs,
@@ -2032,7 +2035,7 @@ class FlashCausalLM(Model):
             # this state to be stable
             if request.id % self.world_size == self.rank:
                 # Prefill
-                if request_prefilling and request.prefill_logprobs:
+                if request_was_prefilling and request.prefill_logprobs:
                     out_start_index = batch.prefill_cu_outlens[i]
                     out_end_index = batch.prefill_cu_outlens[i + 1]
@@ -2072,7 +2075,7 @@ class FlashCausalLM(Model):
                     batch.prefill_logprob_tokens[i] = None
             # If it is, the tokens we decoded should be ignored
-            if request_prefilling:
+            if request_is_prefilling:
                 # Make sure that we do not stop as even though this request did not create a token, it is still
                 # processing
                 stopped = False
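
The gist of the change, as a minimal sketch (the classes, the step function, and chunks_left below are simplified stand-ins, not the actual FlashCausalLM.generate_token): with chunked prefill, batch.prefilling_mask can be updated during a step, so the per-request loop needs both the pre-step snapshot (request_was_prefilling: this step's logits came from prompt tokens and may yield prefill logprobs) and the post-step value (request_is_prefilling: prompt chunks remain, so the request must not be marked stopped).

# Illustrative sketch only; simplified stand-ins for the real batch objects.
from dataclasses import dataclass, field
from typing import List


@dataclass
class Request:
    id: int
    prefill_logprobs: bool


@dataclass
class Batch:
    requests: List[Request]
    prefilling_mask: List[bool] = field(default_factory=list)


def step(batch: Batch, chunks_left: List[int]) -> None:
    # Snapshot taken *before* the mask is updated by this step.
    current_prefilling_mask = list(batch.prefilling_mask)

    # Pretend forward pass: consume one prompt chunk per request and update the mask.
    batch.prefilling_mask = [n - 1 > 0 for n in chunks_left]

    for request, was_prefilling, is_prefilling in zip(
        batch.requests, current_prefilling_mask, batch.prefilling_mask
    ):
        if was_prefilling and request.prefill_logprobs:
            print(f"request {request.id}: this step's logits are prompt logprobs")
        if is_prefilling:
            print(f"request {request.id}: prompt chunks remain, do not stop it")


batch = Batch(
    requests=[Request(0, prefill_logprobs=True), Request(1, prefill_logprobs=False)],
    prefilling_mask=[True, True],
)
step(batch, chunks_left=[1, 2])  # request 0 finishes prefill, request 1 keeps going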


@@ -165,7 +165,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
                     f"Batch ID {request.cached_batch.id} not found in cache."
                 )
             start_concat = time.time_ns()
-            batch = self.model.batch_type.concatenate([batch, cached_batch])
+            batch = self.model.batch_type.concatenate([cached_batch, batch])
             concat_ns = time.time_ns() - start_concat
             generations, next_batch, timings = self.model.generate_token(batch)
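
For the second file, a hypothetical illustration (not TGI's actual batch implementation) of why the argument order to concatenate matters: assuming the merged request order follows the order of the list passed in, [cached_batch, batch] keeps the already-running requests ahead of the newly submitted ones.

# Hypothetical stand-in for the batch type; only shows that list order fixes
# the request order (and therefore per-request indices) in the merged batch.
from typing import List


class Batch:
    def __init__(self, request_ids: List[str]):
        self.request_ids = request_ids

    @classmethod
    def concatenate(cls, batches: List["Batch"]) -> "Batch":
        merged: List[str] = []
        for b in batches:
            merged.extend(b.request_ids)
        return cls(merged)


cached_batch = Batch(["running-0", "running-1"])  # batch already held in the cache
batch = Batch(["incoming-2"])                     # batch from the current request

# Old order [batch, cached_batch] -> ['incoming-2', 'running-0', 'running-1']
# New order [cached_batch, batch] -> ['running-0', 'running-1', 'incoming-2']
print(Batch.concatenate([cached_batch, batch]).request_ids)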