idk at this point

OlivierDehaene 2024-10-09 19:17:18 +02:00
parent 3ace1b2f8d
commit 57f55fe834


@@ -758,11 +758,12 @@ class FlashCausalLMBatch(Batch):
             input_ids[start_index:end_index] = batch.input_ids
             position_ids[start_index:end_index] = batch.position_ids
-            slots[slots_start_index:slots_end_index] = batch.slots
             slot_indices[start_index:end_index] = (
                 batch.slot_indices + cumulative_slots
             )
             input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
+            slots[slots_start_index:slots_end_index] = batch.slots
+            cache_lengths_tensor[start_index:end_index] = batch.cache_lengths_tensor

             # Copy over adapter indices
             adapter_start_index = cumulative_adapter_indices_size
@@ -779,7 +780,6 @@ class FlashCausalLMBatch(Batch):
                 batch.adapter_meta.adapter_segments,
                 batch.adapter_meta.segment_indices,
             )
-            cache_lengths_tensor[start_index:end_index] = batch.cache_lengths_tensor

             # Update
             cumulative_slots += len(batch.slots)
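
Both hunks above reshuffle the per-batch tensor copies during batch concatenation: the slots and cache_lengths_tensor copies now happen together in the first hunk, and the stray cache_lengths_tensor copy further down is dropped. The underlying pattern is unchanged: every sub-batch is written into a preallocated buffer at a cumulative offset, and its slot_indices are shifted by the number of slots already taken by earlier batches. A minimal, self-contained sketch of that offset pattern (illustrative names only, not the code from this commit):

import torch

# Standalone illustration of the concatenation offsets used above:
# local slot indices become global once shifted by the slots consumed
# by the batches that came before.
def concat_slot_indices(batches):
    total = sum(len(b["slot_indices"]) for b in batches)
    slot_indices = torch.empty(total, dtype=torch.int64)
    cumulative_slots = 0
    start_index = 0
    for b in batches:
        end_index = start_index + len(b["slot_indices"])
        # Shift each batch's local slot indices into the global slots tensor
        slot_indices[start_index:end_index] = b["slot_indices"] + cumulative_slots
        cumulative_slots += len(b["slots"])
        start_index = end_index
    return slot_indices

batches = [
    {"slots": torch.arange(4), "slot_indices": torch.tensor([0, 2])},
    {"slots": torch.arange(3), "slot_indices": torch.tensor([0, 1])},
]
print(concat_slot_indices(batches))  # tensor([0, 2, 4, 5])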
@@ -1614,13 +1614,12 @@ class FlashCausalLM(Model):
             input_lengths_tensor=input_lengths,
             cache_lengths_tensor=cache_lengths_tensor,
         ):
-            max_k = (input_lengths + cache_lengths_tensor).max().item()
             seqlen = Seqlen(
                 input_lengths=input_lengths,
                 cache_lengths=cache_lengths_tensor,
                 cu_seqlen_q=cu_seqlen_prefill,
                 max_q=max_s,
-                max_k=max_k,
+                max_k=batch.max_current_length,
             )
             logits, speculative_logits = self.model.forward(
                 input_ids=input_ids,
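
This hunk swaps the recomputed max_k, (input_lengths + cache_lengths_tensor).max().item(), for batch.max_current_length when building Seqlen. The two quantities should agree; the difference is that the old form does a reduction followed by .item(), which forces a device-to-host sync when the tensors live on the GPU, while the batch attribute is presumably already known host-side. A toy check of the assumed equality (illustrative values, not repository code):

import torch

# Per-request new tokens and already-cached tokens (toy values)
input_lengths = torch.tensor([3, 1, 5])
cache_lengths_tensor = torch.tensor([10, 0, 7])

# Old form: reduce, then .item() (a sync point on GPU tensors)
max_k_old = (input_lengths + cache_lengths_tensor).max().item()

# New form: the same value, assumed to be tracked on the batch
# (batch.max_current_length in the diff above)
max_current_length = 13
assert max_k_old == max_current_length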
@@ -1852,46 +1851,44 @@ class FlashCausalLM(Model):
             request_was_prefilling,
             request_is_prefilling,
         ) in enumerate(iterator):
-            # Indexing metadata
-            _start_index = cumulative_length
-            end_index = cumulative_length + input_length
-
-            if prefill and finished_prefilling:
-                # Initialize position_ids
-                # In decode, we do not need this as we can just increment position ids
-                next_position_ids[i] = batch.position_ids[end_index - 1]
-
-                # Initialize adapter indices
-                # In decode, we only have one token per row in the batch, so grab last index
-                next_adapter_indices[i] = batch.adapter_meta.adapter_indices[
-                    end_index - 1
-                ]
-
-            # Used to gather prefill logprobs
-            # Copy batch.all_input_ids_tensor to prefill_token_indices
-            if request.prefill_logprobs and request_was_prefilling:
-                # Indexing metadata
-                out_start_index = batch.prefill_cu_outlens[i]
-                out_end_index = batch.prefill_cu_outlens[i + 1]
-
-                # Logprobs generated by the model are for the next token
-                # So we need to translate the id tensor by 1
-                ids = batch.all_input_ids_tensor[i, cache_length + 1: cache_length + input_length + 1]
-                if len(batch) > 1:
-                    prefill_tokens_indices[out_start_index : out_end_index] = ids
-                else:
-                    # Set prefill_tokens_indices to the correct slice
-                    prefill_tokens_indices = ids
+            if prefill:
+                # Indexing metadata
+                _start_index = cumulative_length
+                end_index = cumulative_length + input_length
+
+                if finished_prefilling:
+                    # Initialize position_ids
+                    # In decode, we do not need this as we can just increment position ids
+                    next_position_ids[i] = batch.position_ids[end_index - 1]
+
+                    # Initialize adapter indices
+                    # In decode, we only have one token per row in the batch, so grab last index
+                    next_adapter_indices[i] = batch.adapter_meta.adapter_indices[
+                        end_index - 1
+                    ]
+
+                # Used to gather prefill logprobs
+                # Copy batch.all_input_ids_tensor to prefill_token_indices
+                if request.prefill_logprobs and request_was_prefilling:
+                    # Logprobs generated by the model are for the next token
+                    # So we need to translate the id tensor by 1
+                    ids = batch.all_input_ids_tensor[i, cache_length + 1: cache_length + input_length + 1]
+                    if len(batch) > 1:
+                        prefill_tokens_indices[out_start_index : out_end_index] = ids
+                    else:
+                        # Set prefill_tokens_indices to the correct slice
+                        prefill_tokens_indices = ids
             if not request_is_prefilling:
                 # Only save tokens if we are done prefilling for this request
                 for j in range(n_accepted_ids):
                     batch.all_input_ids_tensor[i, cache_length + input_length + j] = (
-                        next_input_ids[index]
+                        next_input_ids[index + j]
                     )
-                    index += 1
+            index += n_accepted_ids

             cumulative_length += input_length

         # Update values
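
The tail of this hunk changes how the read offset into next_input_ids advances: the old loop stepped index by 1 per saved token, while the new code reads next_input_ids[index + j] and then bumps index by n_accepted_ids in one step. Indentation is lost in this view, but the apparent intent is that the offset advances once per request, keeping later requests aligned with their accepted tokens even when a request is still prefilling and nothing is saved for it. A standalone sketch of that bookkeeping (hypothetical data, not the code from this commit):

# next_input_ids holds the accepted tokens of every request, flattened back to back
next_input_ids = [11, 12, 21, 31, 32, 33]
n_accepted_ids_per_request = [2, 1, 3]
request_is_prefilling = [False, True, False]

index = 0
saved = {}
for i, (n_accepted_ids, prefilling) in enumerate(
    zip(n_accepted_ids_per_request, request_is_prefilling)
):
    if not prefilling:
        # Only save tokens if we are done prefilling for this request
        saved[i] = [next_input_ids[index + j] for j in range(n_accepted_ids)]
    # Advance unconditionally so later requests stay aligned with their tokens
    index += n_accepted_ids

print(saved)  # {0: [11, 12], 2: [31, 32, 33]}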