Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-12 04:44:52 +00:00)
fix prefill logprobs

commit ea4b739a9f
parent 3924b87a04
@@ -1757,6 +1757,7 @@ class FlashCausalLM(Model):
 
         finished_prefilling = True
         next_chunk_lengths = []
+        current_prefilling_mask = batch.prefilling_mask
         if prefill:
             if get_support_chunking():
                 next_prefilling_mask = []
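
The added line binds the batch's current mask to a local name before the step rebuilds it (the same hunk starts collecting next_prefilling_mask, suggesting the attribute is replaced later in the step). A minimal sketch of the idea, with a toy class and made-up values rather than the repository's code:

# Toy illustration: the attribute is reassigned to a new list later in the step,
# so the local name still reflects what each request was doing at the start.
class ToyBatch:
    def __init__(self, prefilling_mask):
        self.prefilling_mask = prefilling_mask

batch = ToyBatch([True, True])                    # both requests start the step prefilling
current_prefilling_mask = batch.prefilling_mask   # snapshot taken before the update
batch.prefilling_mask = [False, True]             # rebuilt after the forward pass
print(current_prefilling_mask)                    # [True, True]
print(batch.prefilling_mask)                      # [False, True]
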
@@ -1998,6 +1999,7 @@ class FlashCausalLM(Model):
             batch.next_token_chooser.do_sample,
             batch.next_token_chooser.seeds,
             batch.top_n_tokens,
+            current_prefilling_mask,
             batch.prefilling_mask,
             accepted_ids,
             batch_top_token_ids,
@@ -2021,7 +2023,8 @@ class FlashCausalLM(Model):
             do_sample,
             seed,
             top_n_tokens,
-            request_prefilling,
+            request_was_prefilling,
+            request_is_prefilling,
             n_accepted_ids,
             top_token_ids,
             top_token_logprobs,
@@ -2032,7 +2035,7 @@ class FlashCausalLM(Model):
             # this state to be stable
             if request.id % self.world_size == self.rank:
                 # Prefill
-                if request_prefilling and request.prefill_logprobs:
+                if request_was_prefilling and request.prefill_logprobs:
                     out_start_index = batch.prefill_cu_outlens[i]
                     out_end_index = batch.prefill_cu_outlens[i + 1]
 
@@ -2072,7 +2075,7 @@ class FlashCausalLM(Model):
                 batch.prefill_logprob_tokens[i] = None
 
             # If it is, the tokens we decoded should be ignored
-            if request_prefilling:
+            if request_is_prefilling:
                 # Make sure that we do not stop as even though this request did not create a token, it is still
                 # processing
                 stopped = False
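
Read together, the renames split one flag into two per-request flags for each decoding step: request_was_prefilling (the state captured in current_prefilling_mask before the step) gates the prefill-logprob block, while request_is_prefilling (the state after the step) keeps a request from producing a token or triggering stopping criteria while prefill chunks remain. A hypothetical walk-through with made-up values, not the repository's code:

# Assume a two-request batch: request 0 finishes prefilling this step,
# request 1 still has prefill chunks left.
current_prefilling_mask = [True, True]   # state when the step started ("was prefilling")
prefilling_mask = [False, True]          # state after this step's chunk ("is prefilling")

for i, (request_was_prefilling, request_is_prefilling) in enumerate(
    zip(current_prefilling_mask, prefilling_mask)
):
    if request_was_prefilling:
        print(f"request {i}: a prefill chunk was processed, prefill logprobs may be emitted")
    if request_is_prefilling:
        print(f"request {i}: still prefilling, no token is produced and stopping is skipped")
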
@@ -165,7 +165,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
                         f"Batch ID {request.cached_batch.id} not found in cache."
                     )
                 start_concat = time.time_ns()
-                batch = self.model.batch_type.concatenate([batch, cached_batch])
+                batch = self.model.batch_type.concatenate([cached_batch, batch])
                 concat_ns = time.time_ns() - start_concat
 
         generations, next_batch, timings = self.model.generate_token(batch)
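
This one-line change only swaps the concatenation order so that the cached batch precedes the incoming one. The toy sketch below (a made-up batch type, not the repository's) illustrates how the order passed to concatenate fixes the positions of per-request state in the merged batch:

from dataclasses import dataclass
from typing import List

@dataclass
class ToyBatch:
    request_ids: List[int]  # stands in for all per-request, positionally indexed state

    @classmethod
    def concatenate(cls, batches: List["ToyBatch"]) -> "ToyBatch":
        merged: List[int] = []
        for b in batches:
            merged.extend(b.request_ids)
        return cls(merged)

cached_batch = ToyBatch([0, 1])  # requests the server is already working on
batch = ToyBatch([2])            # batch built from the incoming request
print(ToyBatch.concatenate([cached_batch, batch]).request_ids)  # [0, 1, 2]
print(ToyBatch.concatenate([batch, cached_batch]).request_ids)  # [2, 0, 1]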