OlivierDehaene 2024-10-09 20:04:06 +02:00
parent d73c5c634d
commit d361197aab


@@ -956,11 +956,13 @@ class FlashCausalLMBatch(Batch):
             no_prefill_logprobs = no_prefill_logprobs and not prefill_logprobs

             if prefill_logprobs:
-                prefill_head_indices.append(torch.arange(
-                    cumulative_length,
-                    cumulative_length + input_length,
-                    dtype=torch.int64
-                ))
+                prefill_head_indices.append(
+                    torch.arange(
+                        cumulative_length,
+                        cumulative_length + input_length,
+                        dtype=torch.int64,
+                    )
+                )
                 prefill_next_token_indices.append(
                     prefill_out_cumulative_length + input_length - 1
                 )
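
Note: this first hunk is a pure formatting change to how prefill_head_indices is built. For readers unfamiliar with the pattern, here is a minimal standalone sketch of what the torch.arange call accumulates: one contiguous index range per request over the flattened prefill tokens. The lengths are hypothetical and this is not the TGI code itself.

    import torch

    input_lengths = [3, 5, 2]  # hypothetical per-request prompt lengths
    prefill_head_indices = []
    cumulative_length = 0
    for input_length in input_lengths:
        # Each request owns positions [cumulative_length, cumulative_length + input_length)
        # in the flattened batch.
        prefill_head_indices.append(
            torch.arange(
                cumulative_length,
                cumulative_length + input_length,
                dtype=torch.int64,
            )
        )
        cumulative_length += input_length

    print(torch.cat(prefill_head_indices))
    # tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])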
@@ -1875,7 +1877,9 @@ class FlashCausalLM(Model):

                 # Logprobs generated by the model are for the next token
                 # So we need to translate the id tensor by 1
-                ids = batch.all_input_ids_tensor[i, cache_length + 1: cache_length + input_length + 1]
+                ids = batch.all_input_ids_tensor[
+                    i, cache_length + 1 : cache_length + input_length + 1
+                ]
                 if len(batch) > 1:
                     prefill_tokens_indices[out_start_index:out_end_index] = ids
                 else:
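
Note: the comment in this hunk states the key alignment fact: the logprob computed at position t scores the token at position t + 1, so the target ids are the input ids shifted left by one. A minimal sketch of that shift, with hypothetical tensors rather than TGI internals:

    import torch

    all_input_ids = torch.tensor([[101, 7, 8, 9, 102]])
    cache_length, input_length = 0, 4
    i = 0
    # Slice shifted by 1: the id at position t + 1 is the target for the
    # logprobs produced at position t.
    ids = all_input_ids[i, cache_length + 1 : cache_length + input_length + 1]
    print(ids)  # tensor([  7,   8,   9, 102])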
@@ -2014,6 +2018,12 @@ class FlashCausalLM(Model):
             top_token_ids,
             top_token_logprobs,
         ) in enumerate(iterator):
+            if all_input_ids[:2] == [1986, 374] and not request_is_prefilling:
+                log_master(
+                    logger.info,
+                    f"{request.id} {next_token_ids} {self.tokenizer.batch_decode(next_token_ids)}",
+                )
+
             # Compute logprobs first as, even though we might skip the token,
             # it can still be required to compute the logprobs
             # modulo on request.id as it is robust to batch.filter whereas the index in the batch is not and we need
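
Note: this hunk adds a debug log gated on a hard-coded input-id prefix. The readable part of the log line comes from tokenizer.batch_decode, which, when handed a flat list of ids, decodes each id separately. A hedged sketch of that call using the Hugging Face transformers API; the model name and ids below are placeholders, not what TGI serves:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model
    next_token_ids = [464, 2068]  # placeholder candidate next-token ids
    # batch_decode over a flat list of ints returns one decoded string per id,
    # which is what makes the log line above human readable.
    print(tokenizer.batch_decode(next_token_ids))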
@@ -2114,7 +2124,6 @@ class FlashCausalLM(Model):
                 _next_token_logprobs = next_token_logprobs[
                     index : index + n_accepted_ids - left
                 ]
-                index += n_accepted_ids

             # Shard generations
             # All generations will be appended in the rust sharded client
@@ -2189,6 +2198,7 @@ class FlashCausalLM(Model):
             )

             # Update values
+            index += n_accepted_ids
             current_cache_length = cache_length + input_length
             batch.cache_lengths[i] = current_cache_length
             current_input_length = new_input_length
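
Note: the last two hunks are one logical change: index += n_accepted_ids moves from just after the logprob slice (around line 2114) down to the "Update values" block at the end of the per-request loop body, so everything in between still reads from the un-advanced cursor. A sketch of the cursor pattern the move enforces, with hypothetical values rather than TGI internals:

    next_token_logprobs = [-0.1, -0.2, -0.3, -0.4, -0.5]
    accepted_per_request = [2, 3]  # hypothetical speculative acceptance counts

    index = 0
    for n_accepted_ids in accepted_per_request:
        # Consume a variable-size slice for this request.
        request_logprobs = next_token_logprobs[index : index + n_accepted_ids]
        print(request_logprobs)
        # ... any other per-request work can still index from `index` here ...
        # Advance the cursor exactly once, at the end of the iteration,
        # mirroring where the patched code now increments it.
        index += n_accepted_ids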