OlivierDehaene 2024-10-09 20:04:06 +02:00
parent d73c5c634d
commit d361197aab


@@ -956,11 +956,13 @@ class FlashCausalLMBatch(Batch):
             no_prefill_logprobs = no_prefill_logprobs and not prefill_logprobs
             if prefill_logprobs:
-                prefill_head_indices.append(torch.arange(
-                    cumulative_length,
-                    cumulative_length + input_length,
-                    dtype=torch.int64
-                ))
+                prefill_head_indices.append(
+                    torch.arange(
+                        cumulative_length,
+                        cumulative_length + input_length,
+                        dtype=torch.int64,
+                    )
+                )
                 prefill_next_token_indices.append(
                     prefill_out_cumulative_length + input_length - 1
                 )
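
The hunk above is only a formatting change (Black-style wrapping of the torch.arange call); nothing behavioral moves. For context, here is a minimal, standalone sketch, with simplified names and hypothetical lengths, of how per-request position ranges built this way concatenate into one flat index tensor:

import torch

# Hypothetical per-request prompt lengths; in the real batch these come from
# the tokenized inputs of each request.
input_lengths = [3, 5, 2]

prefill_head_indices = []
cumulative_length = 0
for input_length in input_lengths:
    # One contiguous range of positions per request, offset by the running total.
    prefill_head_indices.append(
        torch.arange(
            cumulative_length,
            cumulative_length + input_length,
            dtype=torch.int64,
        )
    )
    cumulative_length += input_length

# Concatenated, this selects every prefill position across the whole batch.
head_indices = torch.cat(prefill_head_indices)
print(head_indices.tolist())  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
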
@@ -1875,9 +1877,11 @@ class FlashCausalLM(Model):
                 # Logprobs generated by the model are for the next token
                 # So we need to translate the id tensor by 1
-                ids = batch.all_input_ids_tensor[i, cache_length + 1: cache_length + input_length + 1]
+                ids = batch.all_input_ids_tensor[
+                    i, cache_length + 1 : cache_length + input_length + 1
+                ]
                 if len(batch) > 1:
-                    prefill_tokens_indices[out_start_index : out_end_index] = ids
+                    prefill_tokens_indices[out_start_index:out_end_index] = ids
                 else:
                     # Set prefill_tokens_indices to the correct slice
                     prefill_tokens_indices = ids
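
Again only reformatting, but the slice it touches encodes the off-by-one between positions and targets: logits computed at position t are scored against the token at position t + 1. A self-contained sketch of that shift, using toy tensors rather than the real batch state:

import torch

# Toy stand-in for batch.all_input_ids_tensor: one padded row of ids per request.
all_input_ids_tensor = torch.tensor([[11, 12, 13, 14, 15, 0, 0]])

i = 0             # index of the request in the batch
cache_length = 1  # tokens already present in the KV cache for this request
input_length = 3  # tokens processed during this prefill step

# The model produces logprobs for the *next* token at each prefill position,
# so the target ids are the input ids shifted by one.
ids = all_input_ids_tensor[i, cache_length + 1 : cache_length + input_length + 1]
print(ids.tolist())  # [13, 14, 15]
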
@@ -2014,6 +2018,12 @@ class FlashCausalLM(Model):
             top_token_ids,
             top_token_logprobs,
         ) in enumerate(iterator):
+            if all_input_ids[:2] == [1986, 374] and not request_is_prefilling:
+                log_master(
+                    logger.info,
+                    f"{request.id} {next_token_ids} {self.tokenizer.batch_decode(next_token_ids)}",
+                )
             # Compute logprobs first as, even though we might skip the token,
             # it can still be required to compute the logprobs
             # modulo on request.id as it is robust to batch.filter whereas the index in the batch is not and we need
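
This hunk adds temporary debug logging, gated on a hard-coded token prefix ([1986, 374]) and on the request no longer being in prefill. A self-contained sketch of the same pattern; the log_master stand-in below only assumes the real helper's behavior (log once, on the master rank), and the values are hypothetical:

import logging
import os

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("debug-decode")

def log_master(log, msg: str):
    # Stand-in for the server's helper: assumed to emit the message only on the
    # master rank so sharded runs do not repeat the same line once per GPU.
    if int(os.getenv("RANK", "0")) == 0:
        log(msg)

# Hypothetical per-request state mirroring the condition added in the diff.
all_input_ids = [1986, 374, 264]
request_is_prefilling = False
next_token_ids = [1102]

if all_input_ids[:2] == [1986, 374] and not request_is_prefilling:
    log_master(logger.info, f"next token ids for this request: {next_token_ids}")
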
@@ -2046,7 +2056,7 @@ class FlashCausalLM(Model):
                         cache_length + 1
                     ) + request_prefill_logprobs
                     prefill_token_ids = (
-                        all_input_ids[:cache_length + 1] + prefill_token_ids
+                        all_input_ids[: cache_length + 1] + prefill_token_ids
                     )
                     prefill_texts = self.tokenizer.batch_decode(
@@ -2114,7 +2124,6 @@ class FlashCausalLM(Model):
                     _next_token_logprobs = next_token_logprobs[
                         index : index + n_accepted_ids - left
                     ]
-                    index += n_accepted_ids
 
             # Shard generations
             # All generations will be appended in the rust sharded client
@@ -2189,6 +2198,7 @@ class FlashCausalLM(Model):
                 )
 
             # Update values
+            index += n_accepted_ids
             current_cache_length = cache_length + input_length
             batch.cache_lengths[i] = current_cache_length
             current_input_length = new_input_length
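
The last two hunks are the functional part of the change: index += n_accepted_ids moves out of the block where _next_token_logprobs is sliced and into the unconditional "Update values" section, so the running offset into the flat per-batch tensors advances for every request. A minimal sketch of why that matters, with toy data and a hypothetical stand-in for the real branch condition:

# Each request may have accepted several (e.g. speculative) tokens; their ids
# and logprobs live in flat per-batch lists, so a running offset must advance
# for *every* request in the loop.
next_token_ids = [101, 102, 103, 201, 301, 302]  # flat, all requests concatenated
n_accepted_per_request = [3, 1, 2]
takes_branch = [True, False, True]  # hypothetical stand-in for the real condition

def slice_per_request(increment_only_in_branch: bool):
    index = 0
    slices = []
    for n_accepted_ids, in_branch in zip(n_accepted_per_request, takes_branch):
        slices.append(next_token_ids[index : index + n_accepted_ids])
        if increment_only_in_branch:
            # Old placement: the offset only advances when the branch is taken,
            # so later requests read another request's tokens.
            if in_branch:
                index += n_accepted_ids
        else:
            # New placement ("Update values"): always advance the offset.
            index += n_accepted_ids
    return slices

print(slice_per_request(True))   # [[101, 102, 103], [201], [201, 301]]  -- wrong
print(slice_per_request(False))  # [[101, 102, 103], [201], [301, 302]]  -- correct
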