Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 20:34:54 +00:00).
fix logprobs?

Commit 3ace1b2f8d (parent 08953c5975).
```diff
@@ -956,7 +956,11 @@ class FlashCausalLMBatch(Batch):
             no_prefill_logprobs = no_prefill_logprobs and not prefill_logprobs
 
             if prefill_logprobs:
-                prefill_head_indices.append(request_position_ids + cumulative_length)
+                prefill_head_indices.append(torch.arange(
+                    cumulative_length,
+                    cumulative_length + input_length,
+                    dtype=torch.int64
+                ))
                 prefill_next_token_indices.append(
                     prefill_out_cumulative_length + input_length - 1
                 )
@@ -966,7 +970,7 @@ class FlashCausalLMBatch(Batch):
                 prefill_head_indices.append(
                     torch.tensor(
                         [cumulative_length + input_length - 1],
-                        dtype=torch.int32,
+                        dtype=torch.int64,
                     )
                 )
                 prefill_next_token_indices.append(prefill_out_cumulative_length)
@@ -1029,9 +1033,7 @@ class FlashCausalLMBatch(Batch):
            prefill_head_indices = cu_seqlen_prefill[1:] - 1
            prefill_next_token_indices = None
        else:
-            prefill_head_indices = torch.tensor(
-                torch.cat(prefill_head_indices), dtype=torch.int64, device=device
-            )
+            prefill_head_indices = torch.cat(prefill_head_indices).to(device)
            prefill_next_token_indices = torch.tensor(
                prefill_next_token_indices, dtype=torch.int64, device=device
            )
```
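With prefix caching, a request's position ids start at its cache length rather than at zero, so offsetting `request_position_ids` by `cumulative_length` presumably no longer lands on the request's rows in the flattened batch; the replacement builds those row indices explicitly with `torch.arange`, and unifying everything on `int64` lets the later `torch.cat(prefill_head_indices).to(device)` work without the `torch.tensor` re-wrap. Below is a minimal standalone sketch of that bookkeeping, with a hypothetical helper and made-up inputs rather than the actual `FlashCausalLMBatch` code path.

```python
import torch

def build_prefill_head_indices(input_lengths, wants_prefill_logprobs):
    # Rows of the flattened prefill batch whose logits must be kept.
    prefill_head_indices = []
    # For each request, where its "next token" logit lands inside the kept logits.
    prefill_next_token_indices = []
    cumulative_length = 0
    prefill_out_cumulative_length = 0
    for input_length, prefill_logprobs in zip(input_lengths, wants_prefill_logprobs):
        if prefill_logprobs:
            # Keep every prompt position of this request: absolute rows in the
            # flattened batch, already int64 so they can be concatenated directly.
            prefill_head_indices.append(
                torch.arange(
                    cumulative_length,
                    cumulative_length + input_length,
                    dtype=torch.int64,
                )
            )
            # Its next token is sampled from the logit of its last prompt position.
            prefill_next_token_indices.append(
                prefill_out_cumulative_length + input_length - 1
            )
            prefill_out_cumulative_length += input_length
        else:
            # Only the last position is needed when prefill logprobs are off.
            prefill_head_indices.append(
                torch.tensor(
                    [cumulative_length + input_length - 1], dtype=torch.int64
                )
            )
            prefill_next_token_indices.append(prefill_out_cumulative_length)
            prefill_out_cumulative_length += 1
        cumulative_length += input_length
    # Same dtype everywhere, so a plain cat is enough (no torch.tensor re-wrap).
    return torch.cat(prefill_head_indices), prefill_next_token_indices

heads, next_token_idx = build_prefill_head_indices([3, 2], [True, False])
print(heads)           # tensor([0, 1, 2, 4])
print(next_token_idx)  # [2, 3]
```

For input lengths `[3, 2]` with prefill logprobs requested only for the first request, the kept rows are `[0, 1, 2, 4]`: every prompt position of the first request plus only the last position of the second.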
```diff
@@ -1822,6 +1824,7 @@ class FlashCausalLM(Model):
 
         # Zipped iterator
         iterator = zip(
             batch.requests,
+            batch.prompt_lengths,
             batch.cache_lengths,
             batch.input_lengths,
@@ -1840,6 +1843,7 @@ class FlashCausalLM(Model):
         # Cumulative length
         cumulative_length = 0
         for i, (
             request,
+            prompt_length,
             cache_length,
             input_length,
@@ -1849,7 +1853,7 @@ class FlashCausalLM(Model):
             request_is_prefilling,
         ) in enumerate(iterator):
             # Indexing metadata
-            start_index = cumulative_length
+            _start_index = cumulative_length
             end_index = cumulative_length + input_length
 
             if prefill:
```
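The zipped iterator now also carries `batch.prompt_lengths`, and the per-request `start_index` that the old logprob copy relied on survives only as the unused `_start_index`. A tiny illustration of the iteration pattern with invented per-request lists (the real batch zips many more parallel fields):

```python
# Invented per-request metadata for two requests.
requests = ["req-0", "req-1"]
prompt_lengths = [7, 5]   # full prompt size, now threaded through the loop
cache_lengths = [4, 0]    # prompt tokens already in the KV cache
input_lengths = [3, 5]    # prompt tokens fed through the model this step

cumulative_length = 0
for i, (request, prompt_length, cache_length, input_length) in enumerate(
    zip(requests, prompt_lengths, cache_lengths, input_lengths)
):
    # Mirrors `_start_index`: kept for readability, no longer read afterwards.
    _start_index = cumulative_length
    end_index = cumulative_length + input_length
    print(i, request, prompt_length, cache_length, _start_index, end_index)
    cumulative_length += input_length
```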
```diff
@@ -1869,25 +1873,16 @@
                 ]
 
                 # Used to gather prefill logprobs
-                # Copy batch.input_ids to prefill_token_indices
-                if prefill_logprobs:
-                    # If the request was prefilling and cache_length == 0, the first token is a bogus token
-                    # and needs to be removed. We do so by incrementing the start_index
-                    if request_was_prefilling and cache_length == 0:
-                        start_index += 1
-
-                    # If the request was prefilling, and it is done prefilling, the last token was generated and is
-                    # therefore not part of the prefill. We remove it by decrementing out_end_index
-                    if request_was_prefilling and not request_is_prefilling:
-                        out_end_index -= 1
-
+                # Copy batch.all_input_ids_tensor to prefill_token_indices
+                if request.prefill_logprobs and request_was_prefilling:
+                    # Logprobs generated by the model are for the next token
+                    # So we need to translate the id tensor by 1
+                    ids = batch.all_input_ids_tensor[i, cache_length + 1: cache_length + input_length + 1]
                     if len(batch) > 1:
-                        prefill_tokens_indices[out_start_index:out_end_index] = (
-                            batch.input_ids[start_index:end_index]
-                        )
+                        prefill_tokens_indices[out_start_index : out_end_index] = ids
                     else:
                         # Set prefill_tokens_indices to the correct slice
-                        prefill_tokens_indices = batch.input_ids[start_index:end_index]
+                        prefill_tokens_indices = ids
 
             if not request_is_prefilling:
                 # Only save tokens if we are done prefilling for this request
```
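The old code copied the unshifted `batch.input_ids[start_index:end_index]` and then compensated with the `start_index` and `out_end_index` adjustments; the new code instead reads a window of `batch.all_input_ids_tensor` shifted right by one, because the logprob produced at prompt position t scores the token at position t + 1. A sketch of that alignment on made-up numbers (not the server's real tensors):

```python
import torch

torch.manual_seed(0)

# One request, invented numbers.
all_input_ids = torch.tensor([10, 11, 12, 13, 14, 15])  # full prompt so far
cache_length = 2   # prompt tokens already in the KV cache
input_length = 3   # prompt tokens run through the model in this chunk
vocab_size = 32

# The model emits one distribution per chunk position, and the distribution
# at position t is over the token at position t + 1.
logprobs = torch.log_softmax(torch.randn(input_length, vocab_size), dim=-1)

# So the ids whose logprobs we want are the chunk's ids shifted right by one,
# which is exactly the new slice in the diff above.
ids = all_input_ids[cache_length + 1 : cache_length + input_length + 1]
print(ids)  # tensor([13, 14, 15])

# Pairing position t with token t + 1 gives the per-token prefill logprobs.
token_logprobs = logprobs.gather(1, ids.unsqueeze(1)).squeeze(1)
print(token_logprobs.shape)  # torch.Size([3])
```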
```diff
@@ -2031,30 +2026,30 @@
                 if request_was_prefilling and request.prefill_logprobs:
                     out_start_index = batch.prefill_cu_outlens[i]
                     out_end_index = batch.prefill_cu_outlens[i + 1]
 
-                    # log_master(logger.info, f"{prefill_logprobs}")
-
                     if not request_is_prefilling:
-                        # If the request is done prefilling, then the last logprob is a generated token
+                        # The request is done prefilling, meaning that we started generating new tokens
+                        # The last logprob is a logprob for a generated token that was not part of the prompt
                         # We need to remove it
                         out_end_index -= 1
 
                     request_prefill_logprobs = prefill_logprobs[
                         out_start_index:out_end_index
                     ]
+                    # Logprobs generated by the model are for the next token
+                    # So we need to translate the id tensor by 1
                     prefill_token_ids = all_input_ids[
-                        cache_length : cache_length + input_length
+                        cache_length + 1 : cache_length + input_length + 1
                     ]
 
                     past_prefill_logprob_tokens = batch.prefill_logprob_tokens[i]
 
                     if past_prefill_logprob_tokens is None:
-                        # Remove generated token to only have prefill and add nan for first prompt token
+                        # add nan for cached prompt tokens/first token
                         request_prefill_logprobs = [float("nan")] * (
                             cache_length + 1
                         ) + request_prefill_logprobs
                         prefill_token_ids = (
-                            all_input_ids[:cache_length] + prefill_token_ids
+                            all_input_ids[:cache_length + 1] + prefill_token_ids
                         )
 
                     prefill_texts = self.tokenizer.batch_decode(
```
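When `past_prefill_logprob_tokens` is `None`, this is the first prefill chunk reported for the request: no logprob exists for the tokens served from the cache, nor for the first token of the chunk, since the position that would have predicted it was never run. The new slices therefore pad those `cache_length + 1` positions with NaN and prepend their ids. A small worked example with invented values (not the server's real data structures):

```python
# Invented values for one request; the real code works on batch fields.
all_input_ids = [10, 11, 12, 13, 14, 15, 16, 17]  # prompt ids seen so far
cache_length = 2    # ids 10, 11 came from the prefix cache
input_length = 4    # ids 12..15 went through the model in this chunk

# One logprob per chunk position, each scoring the *next* prompt token.
request_prefill_logprobs = [-0.7, -1.2, -0.4, -2.1]
prefill_token_ids = all_input_ids[cache_length + 1 : cache_length + input_length + 1]
print(prefill_token_ids)  # [13, 14, 15, 16]

# First chunk reported for this request: pad the cache_length + 1 tokens that
# have no logprob with NaN and prepend their ids, as the new code does.
request_prefill_logprobs = [float("nan")] * (cache_length + 1) + request_prefill_logprobs
prefill_token_ids = all_input_ids[: cache_length + 1] + prefill_token_ids

print(prefill_token_ids)         # [10, 11, 12, 13, 14, 15, 16]
print(request_prefill_logprobs)  # [nan, nan, nan, -0.7, -1.2, -0.4, -2.1]
```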
```diff
@@ -2063,10 +2058,6 @@
                         skip_special_tokens=False,
                     )
 
-                    # log_master(logger.info, f"{prefill_token_ids}")
-                    # log_master(logger.info, f"{request_prefill_logprobs}")
-                    # log_master(logger.info, f"{prefill_texts}")
-
                     prefill_logprob_tokens = Tokens(
                         prefill_token_ids,
                         request_prefill_logprobs,
```