mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 20:34:54 +00:00)
fix prefill logprobs

commit ea4b739a9f (parent 3924b87a04)
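"Prefill logprobs" are the log-probabilities TGI reports for the prompt (prefill) tokens themselves. With chunked prefill, a request's prompt can be processed over several forward passes, so the per-request bookkeeping has to distinguish "was prefilling during the pass that just ran" from "is still prefilling afterwards"; the hunks below introduce that distinction. As background, a minimal, framework-agnostic sketch of how prompt-token logprobs are derived from prompt logits (an illustration only, not TGI's implementation):

import torch

def prompt_logprobs(logits: torch.Tensor, prompt_ids: torch.Tensor) -> torch.Tensor:
    # logits: [prompt_len, vocab]; logits[t] scores the token at position t + 1,
    # so the first prompt token has no logprob and is skipped.
    logprobs = torch.log_softmax(logits[:-1], dim=-1)
    return logprobs.gather(1, prompt_ids[1:].unsqueeze(1)).squeeze(1)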
@@ -1757,6 +1757,7 @@ class FlashCausalLM(Model):
 
         finished_prefilling = True
         next_chunk_lengths = []
+        current_prefilling_mask = batch.prefilling_mask
         if prefill:
             if get_support_chunking():
                 next_prefilling_mask = []
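This hunk snapshots batch.prefilling_mask into current_prefilling_mask before the prefill branch rewrites the mask for the next scheduling step, so later code can still tell which requests were prefilling during this forward pass. A minimal sketch of the pattern (the Batch class and field names below are simplified stand-ins, not the real TGI types):

from dataclasses import dataclass
from typing import List

@dataclass
class Batch:
    prefilling_mask: List[bool]

def step(batch: Batch, chunk_finished: List[bool]) -> List[bool]:
    # Snapshot the mask *before* it is rewritten for the next step.
    current_prefilling_mask = batch.prefilling_mask
    # The scheduler then replaces the mask with "still prefilling next step".
    batch.prefilling_mask = [
        prefilling and not finished
        for prefilling, finished in zip(batch.prefilling_mask, chunk_finished)
    ]
    # Without the snapshot, "was this request prefilling during this pass?"
    # would be answered with the next pass's state.
    return current_prefilling_mask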
@@ -1998,6 +1999,7 @@ class FlashCausalLM(Model):
             batch.next_token_chooser.do_sample,
             batch.next_token_chooser.seeds,
             batch.top_n_tokens,
+            current_prefilling_mask,
             batch.prefilling_mask,
             accepted_ids,
             batch_top_token_ids,
@@ -2021,7 +2023,8 @@ class FlashCausalLM(Model):
             do_sample,
             seed,
             top_n_tokens,
-            request_prefilling,
+            request_was_prefilling,
+            request_is_prefilling,
             n_accepted_ids,
             top_token_ids,
             top_token_logprobs,
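The previous two hunks are the two halves of one change: the snapshot mask and the updated batch.prefilling_mask are both threaded into the per-request zip, and the loop unpacks them as request_was_prefilling and request_is_prefilling instead of the single, now ambiguous request_prefilling. A sketch of the pairing (the lists are made-up stand-ins for the batch fields):

current_prefilling_mask = [True, False]   # state during the pass that just ran
next_prefilling_mask = [False, False]     # state for the next pass (batch.prefilling_mask)
accepted_ids = [1, 1]

iterator = zip(current_prefilling_mask, next_prefilling_mask, accepted_ids)
for request_was_prefilling, request_is_prefilling, n_accepted_ids in iterator:
    if request_was_prefilling:
        pass  # prefill logprobs for this request came out of this pass
    if request_is_prefilling:
        pass  # prompt chunks remain; skip decode-side bookkeeping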
@@ -2032,7 +2035,7 @@ class FlashCausalLM(Model):
             # this state to be stable
             if request.id % self.world_size == self.rank:
                 # Prefill
-                if request_prefilling and request.prefill_logprobs:
+                if request_was_prefilling and request.prefill_logprobs:
                     out_start_index = batch.prefill_cu_outlens[i]
                     out_end_index = batch.prefill_cu_outlens[i + 1]
 
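The condition now keys off request_was_prefilling: prefill logprobs can only be extracted for requests whose prompt logits were produced by the pass that just ran, not for requests that merely remain in the prefilling state. The out_start_index/out_end_index pair slices a flattened per-batch buffer using cumulative output lengths; an illustration with made-up data (names mirror the diff, values are invented):

prefill_logprobs = [-0.1, -2.3, -0.7, -1.5, -0.2]  # flattened over the whole batch
prefill_cu_outlens = [0, 2, 5]                     # request 0 -> [0:2], request 1 -> [2:5]
was_prefilling = [True, True]

for i, request_was_prefilling in enumerate(was_prefilling):
    if request_was_prefilling:
        out_start_index = prefill_cu_outlens[i]
        out_end_index = prefill_cu_outlens[i + 1]
        request_logprobs = prefill_logprobs[out_start_index:out_end_index]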
@@ -2072,7 +2075,7 @@ class FlashCausalLM(Model):
                     batch.prefill_logprob_tokens[i] = None
 
             # If it is, the tokens we decoded should be ignored
-            if request_prefilling:
+            if request_is_prefilling:
                 # Make sure that we do not stop as even though this request did not create a token, it is still
                 # processing
                 stopped = False
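Here the guard switches to request_is_prefilling: a request that still has prompt chunks left produced no usable decode token this pass, so whatever was "decoded" for it is ignored and the batch must not be marked as stopped. A toy version of the flag logic (simplified, not the real stopping criteria):

def batch_stopped(prefilling_mask: list, request_finished: list) -> bool:
    stopped = True
    for request_is_prefilling, finished in zip(prefilling_mask, request_finished):
        if request_is_prefilling or not finished:
            # Still work to do: either more prompt chunks or more tokens.
            stopped = False
    return stopped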
@@ -165,7 +165,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
                         f"Batch ID {request.cached_batch.id} not found in cache."
                     )
                 start_concat = time.time_ns()
-                batch = self.model.batch_type.concatenate([batch, cached_batch])
+                batch = self.model.batch_type.concatenate([cached_batch, batch])
                 concat_ns = time.time_ns() - start_concat
 
         generations, next_batch, timings = self.model.generate_token(batch)
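The concatenation order flips from [batch, cached_batch] to [cached_batch, batch], presumably so the requests already in flight keep their positions and per-request state (such as the prefilling masks and cumulative offsets above) stays index-aligned across chunked prefill steps. A toy illustration of why the order matters (SimpleBatch is a stand-in; the real batch_type.concatenate also merges tensors, KV-cache slots, etc.):

class SimpleBatch:
    def __init__(self, request_ids):
        self.request_ids = list(request_ids)

    @classmethod
    def concatenate(cls, batches):
        merged = []
        for b in batches:
            merged.extend(b.request_ids)
        return cls(merged)

cached_batch = SimpleBatch([0, 1])  # requests already being processed
batch = SimpleBatch([2])            # requests from the new prefill

# [cached_batch, batch] keeps the existing requests at indices 0 and 1;
# the old [batch, cached_batch] order would shift them to 1 and 2.
merged = SimpleBatch.concatenate([cached_batch, batch])
assert merged.request_ids == [0, 1, 2]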