idk at this point

OlivierDehaene 2024-10-09 19:17:18 +02:00
parent 3ace1b2f8d
commit 57f55fe834

@@ -758,11 +758,12 @@ class FlashCausalLMBatch(Batch):
             input_ids[start_index:end_index] = batch.input_ids
             position_ids[start_index:end_index] = batch.position_ids
+            slots[slots_start_index:slots_end_index] = batch.slots
             slot_indices[start_index:end_index] = (
                 batch.slot_indices + cumulative_slots
             )
             input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
-            slots[slots_start_index:slots_end_index] = batch.slots
+            cache_lengths_tensor[start_index:end_index] = batch.cache_lengths_tensor

             # Copy over adapter indices
             adapter_start_index = cumulative_adapter_indices_size
@@ -779,7 +780,6 @@ class FlashCausalLMBatch(Batch):
                 batch.adapter_meta.adapter_segments,
                 batch.adapter_meta.segment_indices,
             )
-            cache_lengths_tensor[start_index:end_index] = batch.cache_lengths_tensor

             # Update
             cumulative_slots += len(batch.slots)
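These two hunks both touch the batch concatenation path: the `slots` copy moves up next to the other per-request copies, and the `cache_lengths_tensor` copy moves from below the adapter bookkeeping into the same block. The sketch below shows the two running offsets this code maintains; `MiniBatch` and its fields are a hypothetical stand-in for `FlashCausalLMBatch`, not the real class:

import torch

class MiniBatch:
    # Hypothetical stand-in with only the fields used in this sketch.
    def __init__(self, input_ids, slots, slot_indices, cache_lengths):
        self.input_ids = input_ids
        self.slots = slots
        self.slot_indices = slot_indices
        self.cache_lengths_tensor = cache_lengths

def concatenate(batches):
    n = sum(len(b.input_ids) for b in batches)
    n_slots = sum(len(b.slots) for b in batches)
    input_ids = torch.empty(n, dtype=torch.int64)
    cache_lengths_tensor = torch.empty(n, dtype=torch.int64)
    slot_indices = torch.empty(n, dtype=torch.int64)
    slots = torch.empty(n_slots, dtype=torch.int64)

    cumulative_batch_size = 0
    cumulative_slots = 0
    for batch in batches:
        # Requests are indexed per row, slots by a separate cumulative
        # offset because one request may own many KV-cache slots.
        start_index = cumulative_batch_size
        end_index = start_index + len(batch.input_ids)
        slots_start_index = cumulative_slots
        slots_end_index = cumulative_slots + len(batch.slots)

        input_ids[start_index:end_index] = batch.input_ids
        slots[slots_start_index:slots_end_index] = batch.slots
        # Slot indices point into the merged slots tensor, hence the shift.
        slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
        cache_lengths_tensor[start_index:end_index] = batch.cache_lengths_tensor

        cumulative_batch_size += len(batch.input_ids)
        cumulative_slots += len(batch.slots)
    return input_ids, slots, slot_indices, cache_lengths_tensor

a = MiniBatch(torch.tensor([1, 2]), torch.tensor([0, 1, 2]), torch.tensor([0, 1]), torch.tensor([0, 0]))
b = MiniBatch(torch.tensor([3]), torch.tensor([0]), torch.tensor([0]), torch.tensor([4]))
print(concatenate([a, b])[2])  # b's slot index shifted by a's 3 slots: tensor([0, 1, 3])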
@@ -1614,13 +1614,12 @@ class FlashCausalLM(Model):
             input_lengths_tensor=input_lengths,
             cache_lengths_tensor=cache_lengths_tensor,
         ):
-            max_k = (input_lengths + cache_lengths_tensor).max().item()
             seqlen = Seqlen(
                 input_lengths=input_lengths,
                 cache_lengths=cache_lengths_tensor,
                 cu_seqlen_q=cu_seqlen_prefill,
                 max_q=max_s,
-                max_k=max_k,
+                max_k=batch.max_current_length,
             )
             logits, speculative_logits = self.model.forward(
                 input_ids=input_ids,
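The deleted line recomputed `max_k` with a tensor reduction plus `.item()`, which blocks on a device-to-host transfer on every forward call. `batch.max_current_length` is assumed here to be the same maximum, already tracked on the host side as the batch is built. A minimal sketch of that equivalence with made-up numbers:

import torch

# Hypothetical per-request lengths: cache_length is what already sits in the
# KV cache, input_length is what this forward pass processes.
cache_lengths = [0, 128, 32]
input_lengths = [64, 1, 1]

# What the removed line computed; .item() forces a sync when these tensors
# live on the GPU.
max_k_old = (torch.tensor(input_lengths) + torch.tensor(cache_lengths)).max().item()

# What the batch can track incrementally on the host instead (assumed to be
# what batch.max_current_length holds).
max_current_length = max(c + n for c, n in zip(cache_lengths, input_lengths))

assert max_k_old == max_current_length == 129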
@@ -1852,16 +1851,11 @@ class FlashCausalLM(Model):
             request_was_prefilling,
             request_is_prefilling,
         ) in enumerate(iterator):
-            # Indexing metadata
-            _start_index = cumulative_length
-            end_index = cumulative_length + input_length
-
-            if prefill:
-                # Indexing metadata
-                out_start_index = batch.prefill_cu_outlens[i]
-                out_end_index = batch.prefill_cu_outlens[i + 1]
-
-                if finished_prefilling:
-                    # Initialize position_ids
-                    # In decode, we do not need this as we can just increment position ids
-                    next_position_ids[i] = batch.position_ids[end_index - 1]
+            if prefill and finished_prefilling:
+                # Indexing metadata
+                _start_index = cumulative_length
+                end_index = cumulative_length + input_length
+
+                # Initialize position_ids
+                # In decode, we do not need this as we can just increment position ids
+                next_position_ids[i] = batch.position_ids[end_index - 1]
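This hunk collapses the nested `if prefill:` / `if finished_prefilling:` checks into one condition, scopes `_start_index`/`end_index` to the branch that uses them, and pushes the `out_start_index`/`out_end_index` computation down to the prefill-logprobs branch (next hunk). The surviving pattern, as a standalone sketch with made-up tensors rather than the real batch objects:

import torch

# Hypothetical flattened position ids for two requests of lengths 4 and 2;
# in the real code these come from batch.position_ids.
position_ids = torch.tensor([0, 1, 2, 3, 0, 1])
input_lengths = [4, 2]

next_position_ids = torch.empty(len(input_lengths), dtype=torch.int64)
cumulative_length = 0
for i, input_length in enumerate(input_lengths):
    end_index = cumulative_length + input_length
    # Only the last position of each finished prefill is gathered; decode
    # steps afterwards can just increment it.
    next_position_ids[i] = position_ids[end_index - 1]
    cumulative_length += input_length

# +1 stands in for however many tokens were actually accepted.
print(next_position_ids + 1)  # positions for the first decode step: tensor([4, 2])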
@@ -1875,6 +1869,10 @@ class FlashCausalLM(Model):
             # Used to gather prefill logprobs
             # Copy batch.all_input_ids_tensor to prefill_token_indices
             if request.prefill_logprobs and request_was_prefilling:
+                # Indexing metadata
+                out_start_index = batch.prefill_cu_outlens[i]
+                out_end_index = batch.prefill_cu_outlens[i + 1]
+
                 # Logprobs generated by the model are for the next token
                 # So we need to translate the id tensor by 1
                 ids = batch.all_input_ids_tensor[i, cache_length + 1 : cache_length + input_length + 1]
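The added lines are the indexing metadata displaced from the previous hunk, now computed only where it is consumed. A short sketch of the assumed semantics of `prefill_cu_outlens`, a cumulative sum of per-request prefill output lengths, so that request `i` owns the half-open slice `[cu[i], cu[i + 1])` of the flat prefill-logprobs output (lengths below are made up):

prefill_out_lens = [4, 1, 7]  # hypothetical per-request output lengths
prefill_cu_outlens = [0]
for n in prefill_out_lens:
    prefill_cu_outlens.append(prefill_cu_outlens[-1] + n)
# prefill_cu_outlens == [0, 4, 5, 12]

i = 1
out_start_index = prefill_cu_outlens[i]    # 4
out_end_index = prefill_cu_outlens[i + 1]  # 5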
@@ -1888,10 +1886,9 @@ class FlashCausalLM(Model):
                 # Only save tokens if we are done prefilling for this request
                 for j in range(n_accepted_ids):
                     batch.all_input_ids_tensor[i, cache_length + input_length + j] = (
-                        next_input_ids[index]
+                        next_input_ids[index + j]
                     )
-                    index += 1
+                index += n_accepted_ids
             cumulative_length += input_length

         # Update values
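The last hunk rewrites the accepted-token loop so that `index` acts as a per-request base pointer into the flat `next_input_ids` tensor: reads become `index + j` and the pointer advances once by `n_accepted_ids`, rather than being incremented inside the inner loop. A standalone sketch of this layout with made-up data, where `next_input_ids` is assumed to hold the accepted speculative tokens of all requests back to back:

next_input_ids = [11, 12, 13, 21, 22]  # request 0 accepted 3 tokens, request 1 accepted 2
accepted = [3, 2]

index = 0
saved = []
for n_accepted_ids in accepted:
    # New form: j-relative reads against the per-request base pointer.
    saved.append([next_input_ids[index + j] for j in range(n_accepted_ids)])
    index += n_accepted_ids  # old form advanced the pointer by 1 per token instead

assert saved == [[11, 12, 13], [21, 22]]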