fix prefill logprobs

This commit is contained in:
OlivierDehaene 2024-10-07 17:12:31 +02:00
parent 3924b87a04
commit ea4b739a9f
No known key found for this signature in database
GPG Key ID: BB104D67809DA93C
2 changed files with 7 additions and 4 deletions

View File

@ -1757,6 +1757,7 @@ class FlashCausalLM(Model):
finished_prefilling = True
next_chunk_lengths = []
current_prefilling_mask = batch.prefilling_mask
if prefill:
if get_support_chunking():
next_prefilling_mask = []
@ -1998,6 +1999,7 @@ class FlashCausalLM(Model):
batch.next_token_chooser.do_sample,
batch.next_token_chooser.seeds,
batch.top_n_tokens,
current_prefilling_mask,
batch.prefilling_mask,
accepted_ids,
batch_top_token_ids,
@ -2021,7 +2023,8 @@ class FlashCausalLM(Model):
do_sample,
seed,
top_n_tokens,
request_prefilling,
request_was_prefilling,
request_is_prefilling,
n_accepted_ids,
top_token_ids,
top_token_logprobs,
@ -2032,7 +2035,7 @@ class FlashCausalLM(Model):
# this state to be stable
if request.id % self.world_size == self.rank:
# Prefill
if request_prefilling and request.prefill_logprobs:
if request_was_prefilling and request.prefill_logprobs:
out_start_index = batch.prefill_cu_outlens[i]
out_end_index = batch.prefill_cu_outlens[i + 1]
@ -2072,7 +2075,7 @@ class FlashCausalLM(Model):
batch.prefill_logprob_tokens[i] = None
# If it is, the tokens we decoded should be ignored
if request_prefilling:
if request_is_prefilling:
# Make sure that we do not stop as even though this request did not create a token, it is still
# processing
stopped = False

View File

@ -165,7 +165,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
f"Batch ID {request.cached_batch.id} not found in cache."
)
start_concat = time.time_ns()
batch = self.model.batch_type.concatenate([batch, cached_batch])
batch = self.model.batch_type.concatenate([cached_batch, batch])
concat_ns = time.time_ns() - start_concat
generations, next_batch, timings = self.model.generate_token(batch)