mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00

commit 57f55fe834 ("idk at this point")
parent 3ace1b2f8d
@@ -758,11 +758,12 @@ class FlashCausalLMBatch(Batch):
             input_ids[start_index:end_index] = batch.input_ids
             position_ids[start_index:end_index] = batch.position_ids
+            slots[slots_start_index:slots_end_index] = batch.slots
             slot_indices[start_index:end_index] = (
                 batch.slot_indices + cumulative_slots
             )
             input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
-            slots[slots_start_index:slots_end_index] = batch.slots
+            cache_lengths_tensor[start_index:end_index] = batch.cache_lengths_tensor

             # Copy over adapter indices
             adapter_start_index = cumulative_adapter_indices_size
@@ -779,7 +780,6 @@ class FlashCausalLMBatch(Batch):
                 batch.adapter_meta.adapter_segments,
                 batch.adapter_meta.segment_indices,
             )
-            cache_lengths_tensor[start_index:end_index] = batch.cache_lengths_tensor

             # Update
             cumulative_slots += len(batch.slots)
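
Net effect of the two hunks above: in the batch concatenation path, the slots copy now happens together with the other per-request copies (right after position_ids), and cache_lengths_tensor is filled where slots used to be instead of trailing after the adapter bookkeeping. Below is a minimal sketch, not the TGI implementation, of the cumulative-offset pattern these copies rely on; concatenate_slots and the SimpleNamespace batches are made up for illustration, with slot_indices assumed to hold one entry per request.

import torch
from types import SimpleNamespace

def concatenate_slots(batches):
    # Preallocate the concatenated tensors once, then copy each batch in.
    total_slots = sum(len(b.slots) for b in batches)
    total_requests = sum(len(b.slot_indices) for b in batches)
    slots = torch.empty(total_slots, dtype=torch.int64)
    slot_indices = torch.empty(total_requests, dtype=torch.int64)

    cumulative_slots = 0
    start_index = 0
    for b in batches:
        end_index = start_index + len(b.slot_indices)
        slots_start_index = cumulative_slots
        slots_end_index = cumulative_slots + len(b.slots)

        slots[slots_start_index:slots_end_index] = b.slots
        # Indices were local to their batch, so shift them by the number of
        # slots already written for previous batches.
        slot_indices[start_index:end_index] = b.slot_indices + cumulative_slots

        cumulative_slots += len(b.slots)
        start_index = end_index
    return slots, slot_indices

b1 = SimpleNamespace(slots=torch.tensor([3, 4, 5]), slot_indices=torch.tensor([0]))
b2 = SimpleNamespace(slots=torch.tensor([9, 10]), slot_indices=torch.tensor([0]))
slots, slot_indices = concatenate_slots([b1, b2])
# slots -> [3, 4, 5, 9, 10]; slot_indices -> [0, 3]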
@@ -1614,13 +1614,12 @@ class FlashCausalLM(Model):
             input_lengths_tensor=input_lengths,
             cache_lengths_tensor=cache_lengths_tensor,
         ):
-            max_k = (input_lengths + cache_lengths_tensor).max().item()
             seqlen = Seqlen(
                 input_lengths=input_lengths,
                 cache_lengths=cache_lengths_tensor,
                 cu_seqlen_q=cu_seqlen_prefill,
                 max_q=max_s,
-                max_k=max_k,
+                max_k=batch.max_current_length,
             )
             logits, speculative_logits = self.model.forward(
                 input_ids=input_ids,
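
The hunk above drops the device-side computation of max_k and passes batch.max_current_length to Seqlen instead. One plausible reading, not stated in the commit itself: (input_lengths + cache_lengths_tensor).max().item() runs a reduction on the GPU and then blocks on the scalar read-back, whereas max_current_length is assumed to be an integer the batch already tracks on the host, so building Seqlen no longer needs that synchronization. A rough sketch of the two paths with made-up lengths:

import torch

# Stand-ins for the tensors in the diff; on a real server these live on the GPU.
input_lengths = torch.tensor([5, 9, 3])
cache_lengths_tensor = torch.tensor([12, 0, 7])

# Old path: reduce on the device, then .item() copies the scalar back to the
# host, which for CUDA tensors forces a synchronization point per forward pass.
max_k = (input_lengths + cache_lengths_tensor).max().item()

# New path: reuse a host-side integer the batch is assumed to keep up to date
# (batch.max_current_length in the diff), so no device read-back is required.
max_current_length = int(max(c + l for c, l in zip((12, 0, 7), (5, 9, 3))))

assert max_k == max_current_length == 17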
@@ -1852,46 +1851,44 @@ class FlashCausalLM(Model):
                 request_was_prefilling,
                 request_is_prefilling,
             ) in enumerate(iterator):
-                # Indexing metadata
-                _start_index = cumulative_length
-                end_index = cumulative_length + input_length
-
-                if prefill:
+                if prefill and finished_prefilling:
+                    # Indexing metadata
+                    _start_index = cumulative_length
+                    end_index = cumulative_length + input_length
+
+                    # Initialize position_ids
+                    # In decode, we do not need this as we can just increment position ids
+                    next_position_ids[i] = batch.position_ids[end_index - 1]
+
+                    # Initialize adapter indices
+                    # In decode, we only have one token per row in the batch, so grab last index
+                    next_adapter_indices[i] = batch.adapter_meta.adapter_indices[
+                        end_index - 1
+                    ]
+
+                # Used to gather prefill logprobs
+                # Copy batch.all_input_ids_tensor to prefill_token_indices
+                if request.prefill_logprobs and request_was_prefilling:
                     # Indexing metadata
                     out_start_index = batch.prefill_cu_outlens[i]
                     out_end_index = batch.prefill_cu_outlens[i + 1]

-                    if finished_prefilling:
-                        # Initialize position_ids
-                        # In decode, we do not need this as we can just increment position ids
-                        next_position_ids[i] = batch.position_ids[end_index - 1]
-
-                        # Initialize adapter indices
-                        # In decode, we only have one token per row in the batch, so grab last index
-                        next_adapter_indices[i] = batch.adapter_meta.adapter_indices[
-                            end_index - 1
-                        ]
-
-                    # Used to gather prefill logprobs
-                    # Copy batch.all_input_ids_tensor to prefill_token_indices
-                    if request.prefill_logprobs and request_was_prefilling:
-                        # Logprobs generated by the model are for the next token
-                        # So we need to translate the id tensor by 1
-                        ids = batch.all_input_ids_tensor[i, cache_length + 1: cache_length + input_length + 1]
-                        if len(batch) > 1:
-                            prefill_tokens_indices[out_start_index : out_end_index] = ids
-                        else:
-                            # Set prefill_tokens_indices to the correct slice
-                            prefill_tokens_indices = ids
+                    # Logprobs generated by the model are for the next token
+                    # So we need to translate the id tensor by 1
+                    ids = batch.all_input_ids_tensor[i, cache_length + 1: cache_length + input_length + 1]
+                    if len(batch) > 1:
+                        prefill_tokens_indices[out_start_index : out_end_index] = ids
+                    else:
+                        # Set prefill_tokens_indices to the correct slice
+                        prefill_tokens_indices = ids

                 if not request_is_prefilling:
                     # Only save tokens if we are done prefilling for this request
                     for j in range(n_accepted_ids):
                         batch.all_input_ids_tensor[i, cache_length + input_length + j] = (
-                            next_input_ids[index]
+                            next_input_ids[index + j]
                         )
-                        index += 1
+                    index += n_accepted_ids

                 cumulative_length += input_length

                 # Update values
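
The tail of the hunk above also fixes how accepted speculative tokens are written back: each of the n_accepted_ids tokens for a request is now read at next_input_ids[index + j], and the running index advances by the whole group (index += n_accepted_ids) rather than by one per token. A small self-contained sketch of that indexing pattern using plain Python lists; the function name and arguments are hypothetical, only the index arithmetic mirrors the diff:

def scatter_accepted_tokens(all_input_ids, next_input_ids, accepted_ids,
                            cache_lengths, input_lengths):
    """Copy each request's accepted tokens into its own row of token ids.

    all_input_ids: one pre-sized list of token ids per request.
    next_input_ids: flat list with accepted_ids[i] new tokens per request.
    """
    index = 0
    for i, n_accepted_ids in enumerate(accepted_ids):
        offset = cache_lengths[i] + input_lengths[i]
        for j in range(n_accepted_ids):
            # Read relative to the start of this request's group of tokens.
            all_input_ids[i][offset + j] = next_input_ids[index + j]
        # Advance once per request, mirroring index += n_accepted_ids.
        index += n_accepted_ids
    return all_input_ids

# Two requests: the first accepted 2 speculative tokens, the second 1.
rows = [[11, 12, 0, 0, 0], [21, 22, 23, 0, 0]]
print(scatter_accepted_tokens(rows, [101, 102, 201], [2, 1], [0, 0], [2, 3]))
# [[11, 12, 101, 102, 0], [21, 22, 23, 201, 0]]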