Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-12 04:44:52 +00:00)
omfg
This commit is contained in:
parent d73c5c634d
commit d361197aab
@@ -956,11 +956,13 @@ class FlashCausalLMBatch(Batch):
             no_prefill_logprobs = no_prefill_logprobs and not prefill_logprobs
 
             if prefill_logprobs:
-                prefill_head_indices.append(torch.arange(
-                    cumulative_length,
-                    cumulative_length + input_length,
-                    dtype=torch.int64
-                ))
+                prefill_head_indices.append(
+                    torch.arange(
+                        cumulative_length,
+                        cumulative_length + input_length,
+                        dtype=torch.int64,
+                    )
+                )
                 prefill_next_token_indices.append(
                     prefill_out_cumulative_length + input_length - 1
                 )
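Note: the hunk above only changes layout, not behavior. `prefill_head_indices` collects one `torch.arange` range per request so that prefill logprobs can later be gathered for exactly those positions in the flattened batch. A minimal standalone sketch of that indexing pattern, with hypothetical per-request lengths (not the batch code itself):

    import torch

    # Hypothetical prompt lengths of three requests in one flattened prefill batch.
    input_lengths = [3, 5, 2]

    prefill_head_indices = []
    cumulative_length = 0
    for input_length in input_lengths:
        # Positions belonging to this request inside the flattened batch.
        prefill_head_indices.append(
            torch.arange(
                cumulative_length,
                cumulative_length + input_length,
                dtype=torch.int64,
            )
        )
        cumulative_length += input_length

    # Concatenated, this addresses every prefill position: tensor([0, 1, ..., 9])
    print(torch.cat(prefill_head_indices))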
@@ -1875,7 +1877,9 @@ class FlashCausalLM(Model):
 
                     # Logprobs generated by the model are for the next token
                     # So we need to translate the id tensor by 1
-                    ids = batch.all_input_ids_tensor[i, cache_length + 1: cache_length + input_length + 1]
+                    ids = batch.all_input_ids_tensor[
+                        i, cache_length + 1 : cache_length + input_length + 1
+                    ]
                     if len(batch) > 1:
                         prefill_tokens_indices[out_start_index:out_end_index] = ids
                     else:
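The comment in this hunk is the key detail: the logprob produced at position t scores the token at position t + 1, so the id slice starts at `cache_length + 1`. A toy illustration of that one-position shift, using made-up tensors rather than the model's state:

    import torch

    # Pretend this row of all_input_ids_tensor holds a prompt of 6 token ids.
    all_input_ids_tensor = torch.tensor([[10, 11, 12, 13, 14, 15]])

    i, cache_length, input_length = 0, 0, 5
    # Logprobs at positions [0, 5) score the *next* tokens, i.e. ids at [1, 6).
    ids = all_input_ids_tensor[i, cache_length + 1 : cache_length + input_length + 1]
    print(ids)  # tensor([11, 12, 13, 14, 15])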
@@ -2014,6 +2018,12 @@ class FlashCausalLM(Model):
                 top_token_ids,
                 top_token_logprobs,
             ) in enumerate(iterator):
+                if all_input_ids[:2] == [1986, 374] and not request_is_prefilling:
+                    log_master(
+                        logger.info,
+                        f"{request.id} {next_token_ids} {self.tokenizer.batch_decode(next_token_ids)}",
+                    )
+
                 # Compute logprobs first as, even though we might skip the token,
                 # it can still be required to compute the logprobs
                 # modulo on request.id as it is robust to batch.filter whereas the index in the batch is not and we need
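The added block is debug instrumentation: it fires only for requests whose prompt starts with token ids [1986, 374] and which are past prefilling, logging the accepted next-token ids alongside their decoded text. To check offline which text those ids correspond to, one could decode them with the model's tokenizer; a hedged sketch (the checkpoint name is a placeholder, not taken from this commit):

    from transformers import AutoTokenizer

    # Placeholder checkpoint: substitute whichever model the server is running.
    tokenizer = AutoTokenizer.from_pretrained("some-org/some-model")

    # Decode the ids the guard matches on, and decode a batch of next-token ids
    # the same way the logging call does.
    print(tokenizer.decode([1986, 374]))
    print(tokenizer.batch_decode([[1986], [374]]))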
@@ -2114,7 +2124,6 @@ class FlashCausalLM(Model):
                     _next_token_logprobs = next_token_logprobs[
                         index : index + n_accepted_ids - left
                     ]
-                    index += n_accepted_ids
 
                     # Shard generations
                     # All generations will be appended in the rust sharded client
@@ -2189,6 +2198,7 @@ class FlashCausalLM(Model):
                     )
 
                 # Update values
+                index += n_accepted_ids
                 current_cache_length = cache_length + input_length
                 batch.cache_lengths[i] = current_cache_length
                 current_input_length = new_input_length
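The last two hunks are one logical change: `index += n_accepted_ids` moves out of the branch that slices `_next_token_logprobs` and into the shared `# Update values` section, so the cursor into the flattened `next_token_ids` / `next_token_logprobs` tensors advances exactly once per request regardless of which branch ran. A standalone sketch of that cursor pattern over hypothetical data, not the actual batch state:

    # Flattened accepted tokens for a batch of 3 requests
    # (speculation can accept more than one token per request).
    next_token_ids = [101, 102, 103, 104, 105, 106]
    accepted_per_request = [2, 1, 3]

    index = 0
    for n_accepted_ids in accepted_per_request:
        # Per-request slice of the flattened list.
        _next_token_ids = next_token_ids[index : index + n_accepted_ids]
        print(_next_token_ids)

        # Update values: advance the cursor once per request, unconditionally,
        # so it stays aligned even for requests that take a different branch.
        index += n_accepted_ids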