Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-09-11 20:34:54 +00:00

Commit message: idk at this point

This commit is contained in:
parent: 3ace1b2f8d
commit: 57f55fe834
@@ -758,11 +758,12 @@ class FlashCausalLMBatch(Batch):
             input_ids[start_index:end_index] = batch.input_ids
             position_ids[start_index:end_index] = batch.position_ids
-            slots[slots_start_index:slots_end_index] = batch.slots
             slot_indices[start_index:end_index] = (
                 batch.slot_indices + cumulative_slots
             )
             input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
+            slots[slots_start_index:slots_end_index] = batch.slots
+            cache_lengths_tensor[start_index:end_index] = batch.cache_lengths_tensor

             # Copy over adapter indices
             adapter_start_index = cumulative_adapter_indices_size
@@ -779,7 +780,6 @@ class FlashCausalLMBatch(Batch):
                 batch.adapter_meta.adapter_segments,
                 batch.adapter_meta.segment_indices,
             )
-            cache_lengths_tensor[start_index:end_index] = batch.cache_lengths_tensor

             # Update
             cumulative_slots += len(batch.slots)
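The two hunks above adjust FlashCausalLMBatch.concatenate, which merges several batches by copying per-request tensors into preallocated flat tensors at cumulative offsets (the change groups the slots and cache_lengths_tensor copies with the other per-token copies). A minimal, self-contained sketch of that copy-with-offsets pattern, assuming sub-batches are plain dicts; the merge_flat helper below is hypothetical, not TGI's API:

```python
# Sketch of the copy-with-offsets pattern used in FlashCausalLMBatch.concatenate;
# tensor names mirror the diff, but this standalone helper is hypothetical.
import torch

def merge_flat(batches):
    total = sum(b["input_ids"].numel() for b in batches)
    total_slots = sum(b["slots"].numel() for b in batches)
    input_ids = torch.empty(total, dtype=torch.int64)
    slots = torch.empty(total_slots, dtype=torch.int64)
    slot_indices = torch.empty(total, dtype=torch.int64)

    cumulative_length = 0
    cumulative_slots = 0
    for b in batches:
        start, end = cumulative_length, cumulative_length + b["input_ids"].numel()
        s_start, s_end = cumulative_slots, cumulative_slots + b["slots"].numel()
        input_ids[start:end] = b["input_ids"]
        slots[s_start:s_end] = b["slots"]
        # Slot indices point into the merged `slots` tensor, so shift them by the
        # number of slots already copied (the role of `cumulative_slots` in the diff).
        slot_indices[start:end] = b["slot_indices"] + cumulative_slots
        cumulative_length = end
        cumulative_slots = s_end
    return input_ids, slots, slot_indices
```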
@@ -1614,13 +1614,12 @@ class FlashCausalLM(Model):
             input_lengths_tensor=input_lengths,
             cache_lengths_tensor=cache_lengths_tensor,
         ):
-            max_k = (input_lengths + cache_lengths_tensor).max().item()
             seqlen = Seqlen(
                 input_lengths=input_lengths,
                 cache_lengths=cache_lengths_tensor,
                 cu_seqlen_q=cu_seqlen_prefill,
                 max_q=max_s,
-                max_k=max_k,
+                max_k=batch.max_current_length,
             )
             logits, speculative_logits = self.model.forward(
                 input_ids=input_ids,
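This hunk swaps a per-forward reduction for metadata the batch already carries: max_k was recomputed as (input_lengths + cache_lengths_tensor).max().item(), while batch.max_current_length is assumed to hold the same maximum total (cached + new) sequence length as host-side bookkeeping, presumably avoiding the extra reduction and .item() sync each step. A toy sketch of the equivalence, with illustrative names and values only:

```python
# Sketch: both sides compute the longest total (cached + new) sequence in the batch;
# only the mechanism differs. Values below are illustrative.
import torch

input_lengths = torch.tensor([3, 7, 2])          # new tokens per request
cache_lengths_tensor = torch.tensor([10, 0, 5])  # tokens already in the KV cache

# Old: reduce over the tensors and pull the result back as a Python int every forward.
max_k_old = (input_lengths + cache_lengths_tensor).max().item()

# New: reuse a maximum tracked as plain Python metadata
# (assumed equivalent to batch.max_current_length in the diff).
max_current_length = max(
    c + l for c, l in zip(cache_lengths_tensor.tolist(), input_lengths.tolist())
)

assert max_k_old == max_current_length == 13
```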
@@ -1852,16 +1851,11 @@ class FlashCausalLM(Model):
             request_was_prefilling,
             request_is_prefilling,
         ) in enumerate(iterator):
-            if prefill and finished_prefilling:
-                # Indexing metadata
-                _start_index = cumulative_length
-                end_index = cumulative_length + input_length
-
             if prefill:
                 # Indexing metadata
                 out_start_index = batch.prefill_cu_outlens[i]
                 out_end_index = batch.prefill_cu_outlens[i + 1]

                 if finished_prefilling:
                     # Initialize position_ids
                     # In decode, we do not need this as we can just increment position ids
                     next_position_ids[i] = batch.position_ids[end_index - 1]
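The removed block above dropped redundant indexing, but the surviving code still relies on the flat-batch convention that request i's tokens occupy [cumulative_length, cumulative_length + input_length) in the concatenated tensors, with the last position at end_index - 1. A small sketch of that convention with toy values (not TGI code):

```python
# Sketch of the flat-batch indexing the loop above relies on: requests are
# concatenated into one 1-D tensor, so request i's tokens live at
# [cumulative_length, cumulative_length + input_length). Illustrative values only.
import torch

input_lengths = [4, 2, 3]                        # tokens processed this step, per request
position_ids = torch.arange(sum(input_lengths))  # stand-in for batch.position_ids

cumulative_length = 0
last_positions = []
for input_length in input_lengths:
    end_index = cumulative_length + input_length
    # The last position of each request sits at end_index - 1 in the flat tensor,
    # which is what next_position_ids[i] is initialized from once prefill finishes.
    last_positions.append(position_ids[end_index - 1].item())
    cumulative_length += input_length

print(last_positions)  # [3, 5, 8]
```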
@@ -1875,6 +1869,10 @@ class FlashCausalLM(Model):
             # Used to gather prefill logprobs
             # Copy batch.all_input_ids_tensor to prefill_token_indices
             if request.prefill_logprobs and request_was_prefilling:
+                # Indexing metadata
+                out_start_index = batch.prefill_cu_outlens[i]
+                out_end_index = batch.prefill_cu_outlens[i + 1]
+
                 # Logprobs generated by the model are for the next token
                 # So we need to translate the id tensor by 1
                 ids = batch.all_input_ids_tensor[i, cache_length + 1: cache_length + input_length + 1]
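The lines added in this hunk recompute out_start_index/out_end_index for the prefill-logprobs path; the existing comment explains the off-by-one in the id slice: the logit produced at position t scores the token at position t + 1, so the ids paired with prefill logprobs are taken one step ahead of the positions that produced them. A toy sketch of that shift (illustrative values, not TGI code):

```python
# Sketch of the "translate by 1" the comment describes.
import torch

cache_length = 2   # tokens already in the KV cache for this request
input_length = 4   # tokens run through prefill this step
all_input_ids = torch.tensor([11, 12, 13, 14, 15, 16, 17])  # stand-in for one row of
                                                            # batch.all_input_ids_tensor

# Positions processed this step are [cache_length, cache_length + input_length),
# and each of them predicts the *next* token, hence the +1 on both bounds.
ids = all_input_ids[cache_length + 1 : cache_length + input_length + 1]
print(ids.tolist())  # [14, 15, 16, 17]
```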
@@ -1888,10 +1886,9 @@ class FlashCausalLM(Model):
                 # Only save tokens if we are done prefilling for this request
                 for j in range(n_accepted_ids):
                     batch.all_input_ids_tensor[i, cache_length + input_length + j] = (
-                        next_input_ids[index]
+                        next_input_ids[index + j]
                     )
-                    index += 1
-
+            index += n_accepted_ids
             cumulative_length += input_length

             # Update values
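This last hunk fixes speculative-decoding bookkeeping: next_input_ids is flat across requests, with n_accepted_ids entries per request, so tokens are read at next_input_ids[index + j] and the cursor advances by n_accepted_ids in one step instead of being incremented inside the save loop. A toy sketch of why the flat cursor must advance by the full count; whether the advance is conditional on saving is not visible in the hunk, so the unconditional advance below is an assumption:

```python
# next_input_ids holds n_accepted_ids entries per request, back to back.
next_input_ids = [101, 102, 103, 201, 301, 302]  # request 0: 3 ids, request 1: 1, request 2: 2
n_accepted_per_request = [3, 1, 2]
save_request = [True, False, True]  # request 1 is skipped, like a still-prefilling request

saved = {}
index = 0
for i, n_accepted_ids in enumerate(n_accepted_per_request):
    if save_request[i]:
        for j in range(n_accepted_ids):
            saved.setdefault(i, []).append(next_input_ids[index + j])
    # Advance by the full count even when nothing was saved, so request i + 1
    # starts reading at the right offset.
    index += n_accepted_ids

print(saved)  # {0: [101, 102, 103], 2: [301, 302]}
```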