Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-12 04:44:52 +00:00
fix logprobs?

Commit 3ace1b2f8d (parent 08953c5975)
@@ -956,7 +956,11 @@ class FlashCausalLMBatch(Batch):
             no_prefill_logprobs = no_prefill_logprobs and not prefill_logprobs
 
             if prefill_logprobs:
-                prefill_head_indices.append(request_position_ids + cumulative_length)
+                prefill_head_indices.append(torch.arange(
+                    cumulative_length,
+                    cumulative_length + input_length,
+                    dtype=torch.int64
+                ))
                 prefill_next_token_indices.append(
                     prefill_out_cumulative_length + input_length - 1
                 )
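The hunk above swaps the per-request position ids for an explicit torch.arange over the current chunk. A minimal standalone sketch (toy lengths, outside the real batch class) of what the new indices look like once every request's chunk is laid out in the flattened prefill batch:

import torch

# Hypothetical sizes for two requests sharing one flattened prefill batch
input_lengths = [3, 4]          # tokens processed this step, per request
cumulative_length = 0
prefill_head_indices = []

for input_length in input_lengths:
    # One index per token of the chunk, expressed in the flattened batch,
    # mirroring the torch.arange(...) built in the hunk above
    prefill_head_indices.append(
        torch.arange(
            cumulative_length,
            cumulative_length + input_length,
            dtype=torch.int64,
        )
    )
    cumulative_length += input_length

head_indices = torch.cat(prefill_head_indices)
print(head_indices)  # tensor([0, 1, 2, 3, 4, 5, 6])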
@@ -966,7 +970,7 @@ class FlashCausalLMBatch(Batch):
                 prefill_head_indices.append(
                     torch.tensor(
                         [cumulative_length + input_length - 1],
-                        dtype=torch.int32,
+                        dtype=torch.int64,
                     )
                 )
                 prefill_next_token_indices.append(prefill_out_cumulative_length)
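One plausible motivation for the dtype change above (my reading, not stated in the commit) is that these per-request index tensors are later concatenated and used as indexing or gather inputs, which conventionally expect int64. A small illustration with toy tensors:

import torch

logits = torch.randn(5, 8)                      # toy flattened model output
idx64 = torch.tensor([4], dtype=torch.int64)
idx32 = torch.tensor([4], dtype=torch.int32)

# int64 indices work everywhere
print(torch.index_select(logits, 0, idx64).shape)   # torch.Size([1, 8])

# gather-style ops insist on int64 indices, so keeping every index tensor
# at int64 (as the hunk above does) avoids mixing dtypes downstream
try:
    torch.gather(logits, 0, idx32.unsqueeze(1).expand(1, 8))
except RuntimeError as err:
    print("int32 index rejected:", err)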
@@ -1029,9 +1033,7 @@ class FlashCausalLMBatch(Batch):
             prefill_head_indices = cu_seqlen_prefill[1:] - 1
             prefill_next_token_indices = None
         else:
-            prefill_head_indices = torch.tensor(
-                torch.cat(prefill_head_indices), dtype=torch.int64, device=device
-            )
+            prefill_head_indices = torch.cat(prefill_head_indices).to(device)
             prefill_next_token_indices = torch.tensor(
                 prefill_next_token_indices, dtype=torch.int64, device=device
             )
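A side-by-side of the two constructions replaced above, with toy tensors: calling torch.tensor() on an existing tensor makes an extra copy (and warns on recent PyTorch), while torch.cat(...).to(device) keeps the already-int64 result and just moves it. This is a sketch of the behaviour, not the repository's code:

import warnings
import torch

pieces = [torch.arange(0, 3, dtype=torch.int64),
          torch.arange(3, 7, dtype=torch.int64)]
device = torch.device("cpu")   # stand-in for the model device

# New style from the hunk above: concatenate once, then move as needed
head_indices = torch.cat(pieces).to(device)
print(head_indices.dtype)      # torch.int64, inherited from the pieces

# Old style: torch.tensor() on a tensor input re-copies the data and,
# on recent PyTorch, warns that .clone().detach() is preferred
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    _ = torch.tensor(torch.cat(pieces), dtype=torch.int64, device=device)
print(f"warnings emitted: {len(caught)}")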
@@ -1822,6 +1824,7 @@ class FlashCausalLM(Model):
 
         # Zipped iterator
         iterator = zip(
+            batch.requests,
             batch.prompt_lengths,
             batch.cache_lengths,
             batch.input_lengths,
@@ -1840,6 +1843,7 @@ class FlashCausalLM(Model):
         # Cumulative length
         cumulative_length = 0
         for i, (
+            request,
             prompt_length,
             cache_length,
             input_length,
@@ -1849,7 +1853,7 @@ class FlashCausalLM(Model):
             request_is_prefilling,
         ) in enumerate(iterator):
             # Indexing metadata
-            start_index = cumulative_length
+            _start_index = cumulative_length
             end_index = cumulative_length + input_length
 
             if prefill:
@@ -1869,25 +1873,16 @@ class FlashCausalLM(Model):
                 ]
 
                 # Used to gather prefill logprobs
-                # Copy batch.input_ids to prefill_token_indices
-                if prefill_logprobs:
-                    # If the request was prefilling and cache_length == 0, the first token is a bogus token
-                    # and needs to be removed. We do so by incrementing the start_index
-                    if request_was_prefilling and cache_length == 0:
-                        start_index += 1
-
-                    # If the request was prefilling, and it is done prefilling, the last token was generated and is
-                    # therefore not part of the prefill. We remove it by decrementing out_end_index
-                    if request_was_prefilling and not request_is_prefilling:
-                        out_end_index -= 1
-
-                    if len(batch) > 1:
-                        prefill_tokens_indices[out_start_index:out_end_index] = (
-                            batch.input_ids[start_index:end_index]
-                        )
-                    else:
-                        # Set prefill_tokens_indices to the correct slice
-                        prefill_tokens_indices = batch.input_ids[start_index:end_index]
+                # Copy batch.all_input_ids_tensor to prefill_token_indices
+                if request.prefill_logprobs and request_was_prefilling:
+                    # Logprobs generated by the model are for the next token
+                    # So we need to translate the id tensor by 1
+                    ids = batch.all_input_ids_tensor[i, cache_length + 1 : cache_length + input_length + 1]
+                    if len(batch) > 1:
+                        prefill_tokens_indices[out_start_index:out_end_index] = ids
+                    else:
+                        # Set prefill_tokens_indices to the correct slice
+                        prefill_tokens_indices = ids
 
                 if not request_is_prefilling:
                     # Only save tokens if we are done prefilling for this request
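The slice arithmetic in this hunk is the core of the fix: the logprob emitted at position t scores the token at position t + 1, so the ids gathered for prefill logprobs come from the prompt shifted left by one. A toy, model-free sketch of the shift (hypothetical values):

import torch

# Toy prompt ids for one request; cache_length tokens are already cached
# and input_length new tokens are processed in this prefill chunk
all_input_ids = torch.tensor([11, 12, 13, 14, 15, 16])
cache_length, input_length = 2, 3

# Positions processed this step are 2, 3, 4 -> the output at position t
# scores the token that appears at position t + 1
target_ids = all_input_ids[cache_length + 1 : cache_length + input_length + 1]
print(target_ids)  # tensor([14, 15, 16])

# With the old, unshifted slice, each logprob would be paired with the token
# the model had just read, not the one it predicted
wrong_ids = all_input_ids[cache_length : cache_length + input_length]
print(wrong_ids)   # tensor([13, 14, 15])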
@@ -2031,30 +2026,30 @@ class FlashCausalLM(Model):
             if request_was_prefilling and request.prefill_logprobs:
                 out_start_index = batch.prefill_cu_outlens[i]
                 out_end_index = batch.prefill_cu_outlens[i + 1]
-
-                # log_master(logger.info, f"{prefill_logprobs}")
-
                 if not request_is_prefilling:
-                    # If the request is done prefilling, then the last logprob is a generated token
+                    # The request is done prefilling, meaning that we started generating new tokens
+                    # The last logprob is a logprob for a generated token that was not part of the prompt
                     # We need to remove it
                     out_end_index -= 1
 
                 request_prefill_logprobs = prefill_logprobs[
                     out_start_index:out_end_index
                 ]
+                # Logprobs generated by the model are for the next token
+                # So we need to translate the id tensor by 1
                 prefill_token_ids = all_input_ids[
-                    cache_length : cache_length + input_length
+                    cache_length + 1 : cache_length + input_length + 1
                 ]
 
                 past_prefill_logprob_tokens = batch.prefill_logprob_tokens[i]
 
                 if past_prefill_logprob_tokens is None:
-                    # Remove generated token to only have prefill and add nan for first prompt token
+                    # Add nan for cached prompt tokens/first token
                     request_prefill_logprobs = [float("nan")] * (
                         cache_length + 1
                     ) + request_prefill_logprobs
                     prefill_token_ids = (
-                        all_input_ids[:cache_length] + prefill_token_ids
+                        all_input_ids[:cache_length + 1] + prefill_token_ids
                    )
 
                 prefill_texts = self.tokenizer.batch_decode(
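Continuing the toy example, the tail of this hunk pads the returned lists for the first chunk of a request: there is no logprob for tokens already in the cache nor for the very first prompt token, so cache_length + 1 NaNs are prepended and the ids are completed with the unshifted prefix. A plain-Python sketch (made-up logprob values):

# Suppose the shifted gather produced one logprob per token of this chunk
all_input_ids = [11, 12, 13, 14, 15, 16]
cache_length, input_length = 2, 3
request_prefill_logprobs = [-0.7, -1.2, -0.4]           # for tokens 14, 15, 16
prefill_token_ids = all_input_ids[cache_length + 1 : cache_length + input_length + 1]

# First chunk for this request: pad with NaN for the cached tokens plus the
# first prompt token, and prepend the corresponding unshifted ids
request_prefill_logprobs = [float("nan")] * (cache_length + 1) + request_prefill_logprobs
prefill_token_ids = all_input_ids[: cache_length + 1] + prefill_token_ids

print(prefill_token_ids)          # [11, 12, 13, 14, 15, 16]
print(request_prefill_logprobs)   # [nan, nan, nan, -0.7, -1.2, -0.4]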
@@ -2063,10 +2058,6 @@ class FlashCausalLM(Model):
                     skip_special_tokens=False,
                 )
 
-                # log_master(logger.info, f"{prefill_token_ids}")
-                # log_master(logger.info, f"{request_prefill_logprobs}")
-                # log_master(logger.info, f"{prefill_texts}")
-
                 prefill_logprob_tokens = Tokens(
                     prefill_token_ids,
                     request_prefill_logprobs,