fix logprobs?

OlivierDehaene 2024-10-09 17:33:15 +02:00
parent 08953c5975
commit 3ace1b2f8d

@@ -956,7 +956,11 @@ class FlashCausalLMBatch(Batch):
             no_prefill_logprobs = no_prefill_logprobs and not prefill_logprobs

             if prefill_logprobs:
-                prefill_head_indices.append(request_position_ids + cumulative_length)
+                prefill_head_indices.append(torch.arange(
+                    cumulative_length,
+                    cumulative_length + input_length,
+                    dtype=torch.int64
+                ))
                 prefill_next_token_indices.append(
                     prefill_out_cumulative_length + input_length - 1
                 )
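Note (editorial, not part of the commit): the hunk above replaces the reuse of request_position_ids with an explicit torch.arange over the request's slice of the flattened batch, built in int64. A minimal self-contained sketch of what such head indices select, with made-up sizes:

import torch

# Toy values, for illustration only.
cumulative_length = 5                 # tokens contributed by earlier requests in the batch
input_length = 3                      # tokens of the current request in this forward pass
hidden_states = torch.randn(10, 4)    # flattened (total_tokens, hidden_size) batch

# Same construction as the new code above: one index per prefill position of the request.
head_indices = torch.arange(
    cumulative_length,
    cumulative_length + input_length,
    dtype=torch.int64,                # int64 so the tensor can be used directly for indexing
)
print(head_indices.tolist())               # [5, 6, 7]
print(hidden_states[head_indices].shape)   # torch.Size([3, 4]) -> rows whose logits give prefill logprobs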
@@ -966,7 +970,7 @@ class FlashCausalLMBatch(Batch):
                 prefill_head_indices.append(
                     torch.tensor(
                         [cumulative_length + input_length - 1],
-                        dtype=torch.int32,
+                        dtype=torch.int64,
                     )
                 )
                 prefill_next_token_indices.append(prefill_out_cumulative_length)
@@ -1029,9 +1033,7 @@ class FlashCausalLMBatch(Batch):
             prefill_head_indices = cu_seqlen_prefill[1:] - 1
             prefill_next_token_indices = None
         else:
-            prefill_head_indices = torch.tensor(
-                torch.cat(prefill_head_indices), dtype=torch.int64, device=device
-            )
+            prefill_head_indices = torch.cat(prefill_head_indices).to(device)
             prefill_next_token_indices = torch.tensor(
                 prefill_next_token_indices, dtype=torch.int64, device=device
             )
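Note (editorial): the hunk above drops the torch.tensor(...) wrapper around an existing tensor. Re-wrapping a tensor with torch.tensor() makes an extra copy and triggers PyTorch's "copy construct" UserWarning; concatenating once and moving with .to(device) is the idiomatic form. A small sketch with placeholder values:

import torch

# Hypothetical per-request index tensors, as built in the loop further up.
parts = [
    torch.arange(0, 3, dtype=torch.int64),   # a request that keeps every prefill position
    torch.tensor([7], dtype=torch.int64),    # a request that keeps only its last position
]
device = "cuda" if torch.cuda.is_available() else "cpu"

prefill_head_indices = torch.cat(parts).to(device)
print(prefill_head_indices)   # tensor([0, 1, 2, 7]) on the chosen device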
@@ -1822,6 +1824,7 @@ class FlashCausalLM(Model):

         # Zipped iterator
         iterator = zip(
+            batch.requests,
             batch.prompt_lengths,
             batch.cache_lengths,
             batch.input_lengths,
@@ -1840,6 +1843,7 @@ class FlashCausalLM(Model):
         # Cumulative length
         cumulative_length = 0
         for i, (
+            request,
             prompt_length,
             cache_length,
             input_length,
@@ -1849,7 +1853,7 @@ class FlashCausalLM(Model):
             request_is_prefilling,
         ) in enumerate(iterator):
             # Indexing metadata
-            start_index = cumulative_length
+            _start_index = cumulative_length
             end_index = cumulative_length + input_length

             if prefill:
@@ -1869,25 +1873,16 @@ class FlashCausalLM(Model):
                 ]

                 # Used to gather prefill logprobs
-                # Copy batch.input_ids to prefill_token_indices
-                if prefill_logprobs:
-                    # If the request was prefilling and cache_length == 0, the first token is a bogus token
-                    # and needs to be removed. We do so by incrementing the start_index
-                    if request_was_prefilling and cache_length == 0:
-                        start_index += 1
-
-                    # If the request was prefilling, and it is done prefilling, the last token was generated and is
-                    # therefore not part of the prefill. We remove it by decrementing out_end_index
-                    if request_was_prefilling and not request_is_prefilling:
-                        out_end_index -= 1
-
+                # Copy batch.all_input_ids_tensor to prefill_token_indices
+                if request.prefill_logprobs and request_was_prefilling:
+                    # Logprobs generated by the model are for the next token
+                    # So we need to translate the id tensor by 1
+                    ids = batch.all_input_ids_tensor[i, cache_length + 1: cache_length + input_length + 1]
                     if len(batch) > 1:
-                        prefill_tokens_indices[out_start_index:out_end_index] = (
-                            batch.input_ids[start_index:end_index]
-                        )
+                        prefill_tokens_indices[out_start_index : out_end_index] = ids
                     else:
                         # Set prefill_tokens_indices to the correct slice
-                        prefill_tokens_indices = batch.input_ids[start_index:end_index]
+                        prefill_tokens_indices = ids

             if not request_is_prefilling:
                 # Only save tokens if we are done prefilling for this request
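Note (editorial): the shift by one in the new slice is the usual next-token alignment: the logits the model produces at position t score the token at position t + 1, so the target ids for prompt logprobs are the prompt shifted by one. A standalone sketch of that alignment with toy ids and hypothetical names:

import torch
import torch.nn.functional as F

vocab_size = 50
prompt = torch.tensor([11, 7, 23, 5])           # toy prompt ids
logits = torch.randn(len(prompt), vocab_size)   # one logits row per prompt position
logprobs = F.log_softmax(logits, dim=-1)

# Position t predicts prompt[t + 1], so the targets are the ids shifted by one,
# just like all_input_ids_tensor[i, cache_length + 1 : ...] in the hunk above.
targets = prompt[1:]
prompt_logprobs = logprobs[:-1].gather(1, targets.unsqueeze(1)).squeeze(1)

# The first prompt token has no preceding position to score it, which is why the
# server later pads the front of the logprob list with float("nan").
print([float("nan")] + prompt_logprobs.tolist())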
@@ -2031,30 +2026,30 @@ class FlashCausalLM(Model):
                 if request_was_prefilling and request.prefill_logprobs:
                     out_start_index = batch.prefill_cu_outlens[i]
                     out_end_index = batch.prefill_cu_outlens[i + 1]
-                    # log_master(logger.info, f"{prefill_logprobs}")
                     if not request_is_prefilling:
-                        # If the request is done prefilling, then the last logprob is a generated token
+                        # The request is done prefilling, meaning that we started generating new tokens
+                        # The last logprob is a logprob for a generated token that was not part of the prompt
                         # We need to remove it
                         out_end_index -= 1

                     request_prefill_logprobs = prefill_logprobs[
                         out_start_index:out_end_index
                     ]
+                    # Logprobs generated by the model are for the next token
+                    # So we need to translate the id tensor by 1
                     prefill_token_ids = all_input_ids[
-                        cache_length : cache_length + input_length
+                        cache_length + 1 : cache_length + input_length + 1
                     ]

                     past_prefill_logprob_tokens = batch.prefill_logprob_tokens[i]
                     if past_prefill_logprob_tokens is None:
-                        # Remove generated token to only have prefill and add nan for first prompt token
+                        # add nan for cached prompt tokens/first token
                         request_prefill_logprobs = [float("nan")] * (
                             cache_length + 1
                         ) + request_prefill_logprobs
                         prefill_token_ids = (
-                            all_input_ids[:cache_length] + prefill_token_ids
+                            all_input_ids[:cache_length + 1] + prefill_token_ids
                         )

                     prefill_texts = self.tokenizer.batch_decode(
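Note (editorial): a toy walk-through of the list bookkeeping in the hunk above, with made-up values; the variable names mirror the diff, the numbers do not come from it:

all_input_ids = [11, 7, 23, 5, 9]    # full prompt so far
cache_length, input_length = 2, 3    # 2 tokens already cached, 3 prefilled in this pass

# Logprobs returned for this chunk after dropping the generated token's entry:
# they score the ids at positions cache_length + 1 .. cache_length + input_length - 1.
request_prefill_logprobs = [-1.2, -0.3]
prefill_token_ids = all_input_ids[cache_length + 1 : cache_length + input_length + 1]  # [5, 9]

# First chunk recorded for this request: pad the cached prefix plus the first token with NaN.
request_prefill_logprobs = [float("nan")] * (cache_length + 1) + request_prefill_logprobs
prefill_token_ids = all_input_ids[: cache_length + 1] + prefill_token_ids

assert len(prefill_token_ids) == len(request_prefill_logprobs)   # 5 == 5
print(prefill_token_ids)   # [11, 7, 23, 5, 9] -> what tokenizer.batch_decode receives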
@@ -2063,10 +2058,6 @@ class FlashCausalLM(Model):
                         skip_special_tokens=False,
                     )
-
-                    # log_master(logger.info, f"{prefill_token_ids}")
-                    # log_master(logger.info, f"{request_prefill_logprobs}")
-                    # log_master(logger.info, f"{prefill_texts}")
                     prefill_logprob_tokens = Tokens(
                         prefill_token_ids,
                         request_prefill_logprobs,