fix logprobs?

OlivierDehaene 2024-10-09 17:33:15 +02:00
parent 08953c5975
commit 3ace1b2f8d


@@ -956,7 +956,11 @@ class FlashCausalLMBatch(Batch):
no_prefill_logprobs = no_prefill_logprobs and not prefill_logprobs
if prefill_logprobs:
prefill_head_indices.append(request_position_ids + cumulative_length)
prefill_head_indices.append(torch.arange(
cumulative_length,
cumulative_length + input_length,
dtype=torch.int64
))
prefill_next_token_indices.append(
prefill_out_cumulative_length + input_length - 1
)
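With prefill logprobs enabled, the head indices for a request now cover every prompt position explicitly via torch.arange, created as int64 from the start, instead of reusing request_position_ids + cumulative_length. A standalone sketch (toy lengths, not TGI code) of how those flattened indices line up across a batch:

import torch

# Hypothetical per-request input lengths in a flattened (ragged) batch.
input_lengths = [3, 5, 2]

prefill_head_indices = []
cumulative_length = 0
for input_length in input_lengths:
    # One index per prompt position of this request, offset into the
    # flattened batch; int64 so it can be used directly for indexing.
    prefill_head_indices.append(
        torch.arange(
            cumulative_length,
            cumulative_length + input_length,
            dtype=torch.int64,
        )
    )
    cumulative_length += input_length

head_indices = torch.cat(prefill_head_indices)
print(head_indices)  # tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])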
@@ -966,7 +970,7 @@ class FlashCausalLMBatch(Batch):
prefill_head_indices.append(
torch.tensor(
[cumulative_length + input_length - 1],
dtype=torch.int32,
dtype=torch.int64,
)
)
prefill_next_token_indices.append(prefill_out_cumulative_length)
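A guess at why the single-index branch moves from int32 to int64 (the commit message does not say): the per-position indices built above are already int64, and these head indices are later concatenated and used to index into the model's hidden states, which is simplest when every piece shares the int64 dtype PyTorch expects for index tensors. A minimal, self-contained illustration:

import torch

hidden_states = torch.randn(10, 8)  # 10 flattened prompt positions, hidden size 8

cumulative_length, input_length = 3, 5
# Without prefill logprobs, only the last prompt position of the request is
# kept; an int64 tensor can be used directly as an index.
head_index = torch.tensor([cumulative_length + input_length - 1], dtype=torch.int64)

last_hidden = hidden_states[head_index]  # shape (1, 8)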
@@ -1029,9 +1033,7 @@ class FlashCausalLMBatch(Batch):
prefill_head_indices = cu_seqlen_prefill[1:] - 1
prefill_next_token_indices = None
else:
prefill_head_indices = torch.tensor(
torch.cat(prefill_head_indices), dtype=torch.int64, device=device
)
prefill_head_indices = torch.cat(prefill_head_indices).to(device)
prefill_next_token_indices = torch.tensor(
prefill_next_token_indices, dtype=torch.int64, device=device
)
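The old code wrapped the concatenated indices in torch.tensor(...), which copies an existing tensor and makes PyTorch emit a UserWarning suggesting clone().detach(); since every piece is already an int64 tensor, a plain torch.cat followed by a device move is enough. A small sketch of the two patterns (toy tensors):

import torch

pieces = [
    torch.arange(0, 3, dtype=torch.int64),
    torch.tensor([7], dtype=torch.int64),
]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Old pattern: re-wrapping a tensor copies it and triggers a UserWarning.
# head_indices = torch.tensor(torch.cat(pieces), dtype=torch.int64, device=device)

# New pattern: concatenate the int64 pieces and move them to the device.
head_indices = torch.cat(pieces).to(device)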
@@ -1822,6 +1824,7 @@ class FlashCausalLM(Model):
# Zipped iterator
iterator = zip(
batch.requests,
batch.prompt_lengths,
batch.cache_lengths,
batch.input_lengths,
@@ -1840,6 +1843,7 @@ class FlashCausalLM(Model):
# Cumulative length
cumulative_length = 0
for i, (
request,
prompt_length,
cache_length,
input_length,
@@ -1849,7 +1853,7 @@ class FlashCausalLM(Model):
request_is_prefilling,
) in enumerate(iterator):
# Indexing metadata
start_index = cumulative_length
_start_index = cumulative_length
end_index = cumulative_length + input_length
if prefill:
@@ -1869,25 +1873,16 @@ class FlashCausalLM(Model):
]
# Used to gather prefill logprobs
# Copy batch.input_ids to prefill_token_indices
if prefill_logprobs:
# If the request was prefilling and cache_length == 0, the first token is a bogus token
# and needs to be removed. We do so by incrementing the start_index
if request_was_prefilling and cache_length == 0:
start_index += 1
# If the request was prefilling, and it is done prefilling, the last token was generated and is
# therefore not part of the prefill. We remove it by decrementing out_end_index
if request_was_prefilling and not request_is_prefilling:
out_end_index -= 1
# Copy batch.all_input_ids_tensor to prefill_token_indices
if request.prefill_logprobs and request_was_prefilling:
# Logprobs generated by the model are for the next token
# So we need to translate the id tensor by 1
ids = batch.all_input_ids_tensor[i, cache_length + 1: cache_length + input_length + 1]
if len(batch) > 1:
prefill_tokens_indices[out_start_index:out_end_index] = (
batch.input_ids[start_index:end_index]
)
prefill_tokens_indices[out_start_index : out_end_index] = ids
else:
# Set prefill_tokens_indices to the correct slice
prefill_tokens_indices = batch.input_ids[start_index:end_index]
prefill_tokens_indices = ids
if not request_is_prefilling:
# Only save tokens if we are done prefilling for this request
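The core of the fix is visible in the hunk above: the logprob the model produces at prompt position t scores the token at position t + 1, so the ids gathered for prefill logprobs come from all_input_ids_tensor shifted by one (cache_length + 1 to cache_length + input_length + 1) rather than from batch.input_ids at the raw positions. A toy, self-contained illustration of that shift (not TGI code):

import torch
import torch.nn.functional as F

# A prompt of 4 token ids and the logits the model produced at each of the
# 4 prompt positions (toy vocab size 10).
prompt_ids = torch.tensor([5, 2, 7, 1])
logits = torch.randn(4, 10)
logprobs = F.log_softmax(logits, dim=-1)

# The distribution at position t predicts the token at position t + 1, so the
# targets are the prompt ids shifted left by one; the very first prompt token
# has no logprob at all.
targets = prompt_ids[1:]
prompt_logprobs = logprobs[:-1].gather(1, targets.unsqueeze(1)).squeeze(1)
print(prompt_logprobs.shape)  # torch.Size([3])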
@@ -2031,30 +2026,30 @@ class FlashCausalLM(Model):
if request_was_prefilling and request.prefill_logprobs:
out_start_index = batch.prefill_cu_outlens[i]
out_end_index = batch.prefill_cu_outlens[i + 1]
# log_master(logger.info, f"{prefill_logprobs}")
if not request_is_prefilling:
# If the request is done prefilling, then the last logprob is a generated token
# The request is done prefilling, meaning that we started generating new tokens
# The last logprob is a logprob for a generated token that was not part of the prompt
# We need to remove it
out_end_index -= 1
request_prefill_logprobs = prefill_logprobs[
out_start_index:out_end_index
]
# Logprobs generated by the model are for the next token
# So we need to translate the id tensor by 1
prefill_token_ids = all_input_ids[
cache_length : cache_length + input_length
cache_length + 1 : cache_length + input_length + 1
]
past_prefill_logprob_tokens = batch.prefill_logprob_tokens[i]
if past_prefill_logprob_tokens is None:
# Remove generated token to only have prefill and add nan for first prompt token
# add nan for cached prompt tokens/first token
request_prefill_logprobs = [float("nan")] * (
cache_length + 1
) + request_prefill_logprobs
prefill_token_ids = (
all_input_ids[:cache_length] + prefill_token_ids
all_input_ids[:cache_length + 1] + prefill_token_ids
)
prefill_texts = self.tokenizer.batch_decode(
@@ -2063,10 +2058,6 @@ class FlashCausalLM(Model):
skip_special_tokens=False,
)
# log_master(logger.info, f"{prefill_token_ids}")
# log_master(logger.info, f"{request_prefill_logprobs}")
# log_master(logger.info, f"{prefill_texts}")
prefill_logprob_tokens = Tokens(
prefill_token_ids,
request_prefill_logprobs,
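The NaN padding in the last hunk keeps the returned token ids and logprobs aligned: the cached prompt tokens and the very first token have no logprob, so they receive NaN placeholders before the Tokens object is built. A standalone sketch with toy numbers (assumed values, not from a real batch):

cache_length = 2
input_length = 3
all_input_ids = [11, 12, 13, 14, 15, 16]

# Logprobs just produced for the uncached prompt positions, already shifted
# by one as in the hunks above.
request_prefill_logprobs = [-0.7, -1.3, -0.2]
prefill_token_ids = all_input_ids[cache_length + 1 : cache_length + input_length + 1]

# Pad with NaN for the cached tokens plus the first prompt token.
request_prefill_logprobs = [float("nan")] * (cache_length + 1) + request_prefill_logprobs
prefill_token_ids = all_input_ids[:cache_length + 1] + prefill_token_ids

assert len(prefill_token_ids) == len(request_prefill_logprobs)  # both 6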