Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-12 04:44:52 +00:00
omfg
This commit is contained in:
parent d73c5c634d
commit d361197aab
@@ -956,11 +956,13 @@ class FlashCausalLMBatch(Batch):
             no_prefill_logprobs = no_prefill_logprobs and not prefill_logprobs
 
             if prefill_logprobs:
-                prefill_head_indices.append(torch.arange(
-                    cumulative_length,
-                    cumulative_length + input_length,
-                    dtype=torch.int64
-                ))
+                prefill_head_indices.append(
+                    torch.arange(
+                        cumulative_length,
+                        cumulative_length + input_length,
+                        dtype=torch.int64,
+                    )
+                )
                 prefill_next_token_indices.append(
                     prefill_out_cumulative_length + input_length - 1
                 )
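The hunk above is purely a re-wrap of the torch.arange call. As an illustration of what these indices mean, here is a minimal standalone sketch, not TGI code: the batch of input_lengths and the simple running offsets are invented, and it assumes every request keeps full prefill logprobs.

# Minimal sketch (invented batch): how per-request head indices and
# next-token indices accumulate with running offsets, as in the hunk above.
import torch

input_lengths = [3, 5, 2]  # hypothetical prefill lengths for three requests
prefill_head_indices = []
prefill_next_token_indices = []
cumulative_length = 0
prefill_out_cumulative_length = 0

for input_length in input_lengths:
    # Indices of every prefill position of this request in the flattened batch
    prefill_head_indices.append(
        torch.arange(
            cumulative_length,
            cumulative_length + input_length,
            dtype=torch.int64,
        )
    )
    # Index of the last prefill position, whose logits produce the next token
    prefill_next_token_indices.append(
        prefill_out_cumulative_length + input_length - 1
    )
    cumulative_length += input_length
    prefill_out_cumulative_length += input_length

print(torch.cat(prefill_head_indices))  # tensor([0, 1, ..., 9])
print(prefill_next_token_indices)       # [2, 7, 9]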
@@ -1875,9 +1877,11 @@ class FlashCausalLM(Model):
 
                 # Logprobs generated by the model are for the next token
                 # So we need to translate the id tensor by 1
-                ids = batch.all_input_ids_tensor[i, cache_length + 1: cache_length + input_length + 1]
+                ids = batch.all_input_ids_tensor[
+                    i, cache_length + 1 : cache_length + input_length + 1
+                ]
                 if len(batch) > 1:
-                    prefill_tokens_indices[out_start_index : out_end_index] = ids
+                    prefill_tokens_indices[out_start_index:out_end_index] = ids
                 else:
                     # Set prefill_tokens_indices to the correct slice
                     prefill_tokens_indices = ids
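The wrapped slice implements the shift described in the comments: the logprob produced at position t scores the token at position t + 1. A toy example with invented ids and lengths (not the real contents of batch.all_input_ids_tensor):

# Toy illustration of the shift-by-one target ids used for prefill logprobs.
import torch

cache_length = 2      # tokens already in the KV cache (hypothetical)
input_length = 4      # tokens prefilled in this step (hypothetical)
all_input_ids = torch.tensor([[11, 12, 13, 14, 15, 16, 17]])

# ids[t] is the token whose logprob is read from the model output at position t
ids = all_input_ids[0, cache_length + 1 : cache_length + input_length + 1]
print(ids)  # tensor([14, 15, 16, 17])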
@@ -2014,6 +2018,12 @@ class FlashCausalLM(Model):
             top_token_ids,
             top_token_logprobs,
         ) in enumerate(iterator):
+            if all_input_ids[:2] == [1986, 374] and not request_is_prefilling:
+                log_master(
+                    logger.info,
+                    f"{request.id} {next_token_ids} {self.tokenizer.batch_decode(next_token_ids)}",
+                )
+
             # Compute logprobs first as, even though we might skip the token,
             # it can still be required to compute the logprobs
             # modulo on request.id as it is robust to batch.filter whereas the index in the batch is not and we need
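The added block is a temporary debug hook: it fires only for requests whose prompt starts with the hard-coded ids [1986, 374] and only once the request has left the prefill phase. A rough standalone rendering of the gating, where a plain print stands in for log_master/logger.info and every value is invented:

# Standalone rendering of the debug gate added in the hunk above.
all_input_ids = [1986, 374, 1112, 2048]   # hypothetical prompt ids
request_is_prefilling = False             # request has finished prefill
request_id = 0                            # hypothetical request id
next_token_ids = [9906]                   # hypothetical accepted ids this step

if all_input_ids[:2] == [1986, 374] and not request_is_prefilling:
    # The commit also decodes next_token_ids with self.tokenizer.batch_decode
    print(f"{request_id} {next_token_ids}")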
@@ -2046,7 +2056,7 @@ class FlashCausalLM(Model):
                         cache_length + 1
                     ) + request_prefill_logprobs
                     prefill_token_ids = (
-                        all_input_ids[:cache_length + 1] + prefill_token_ids
+                        all_input_ids[: cache_length + 1] + prefill_token_ids
                     )
 
                     prefill_texts = self.tokenizer.batch_decode(
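Only the slice spacing changes here. For context, a minimal sketch with invented ids of what the reformatted line does: the prompt prefix up to cache_length + 1 is stitched back in front of the chunk's ids before the result is handed to the tokenizer for decoding (the concrete values are assumptions, not the real ones in generate_token):

# Invented values: rebuild the full token list that prefill_texts will decode.
cache_length = 2
all_input_ids = [101, 102, 103, 104, 105]   # hypothetical full prompt ids
prefill_token_ids = [104, 105]              # hypothetical ids covered by this chunk

prefill_token_ids = all_input_ids[: cache_length + 1] + prefill_token_ids
print(prefill_token_ids)  # [101, 102, 103, 104, 105]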
@@ -2114,7 +2124,6 @@ class FlashCausalLM(Model):
                 _next_token_logprobs = next_token_logprobs[
                     index : index + n_accepted_ids - left
                 ]
-            index += n_accepted_ids
 
             # Shard generations
             # All generations will be appended in the rust sharded client
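This hunk only removes the index increment; the next hunk re-adds it under "# Update values". A sketch of the flat-offset bookkeeping involved, with invented numbers (the hunk does not show how left is computed, so it is treated as given):

# next_token_logprobs holds the accepted tokens of every request back to back;
# each request reads its slice minus `left` trailing entries, and `index`
# advances by the full n_accepted_ids so the next request starts correctly.
next_token_logprobs = [-0.1, -0.2, -0.3, -0.4, -0.5, -0.6]  # hypothetical, 2 requests
requests = [
    {"n_accepted_ids": 4, "left": 1},  # keeps 3 logprobs
    {"n_accepted_ids": 2, "left": 0},  # keeps 2 logprobs
]

index = 0
for r in requests:
    n_accepted_ids, left = r["n_accepted_ids"], r["left"]
    _next_token_logprobs = next_token_logprobs[index : index + n_accepted_ids - left]
    print(_next_token_logprobs)  # [-0.1, -0.2, -0.3] then [-0.5, -0.6]
    # The commit moves this increment down to the "Update values" block;
    # the offset arithmetic itself is unchanged.
    index += n_accepted_ids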
@@ -2189,6 +2198,7 @@ class FlashCausalLM(Model):
                 )
 
             # Update values
+            index += n_accepted_ids
             current_cache_length = cache_length + input_length
             batch.cache_lengths[i] = current_cache_length
             current_input_length = new_input_length
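The increment now sits with the other per-request updates at the end of the loop iteration. A toy view of that bookkeeping with invented lengths (new_input_length and the single-element cache_lengths list are stand-ins for the real batch fields):

# Hypothetical per-request values at the end of one generate_token iteration.
n_accepted_ids = 2    # accepted tokens for this request this step
cache_length = 8      # tokens already in the KV cache before this step
input_length = 4      # tokens fed to the model this step
new_input_length = 1  # tokens to feed next step (hypothetical)
cache_lengths = [0]   # stands in for batch.cache_lengths
i, index = 0, 0

# Update values: advance the flat logprob offset and record that everything
# just run through the model is now part of this request's KV cache.
index += n_accepted_ids
current_cache_length = cache_length + input_length
cache_lengths[i] = current_cache_length
current_input_length = new_input_length
print(index, cache_lengths, current_input_length)  # 2 [12] 1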