fix filter and concat

This commit is contained in:
OlivierDehaene 2024-09-25 15:34:08 +02:00
parent e4f9110e14
commit a85f5ebecd
No known key found for this signature in database
GPG Key ID: BB104D67809DA93C

View File

@ -573,11 +573,13 @@ class FlashCausalLMBatch(Batch):
all_input_ids = [] all_input_ids = []
prefix_ids = [] prefix_ids = []
prompt_lengths = []
postfix_lengths = [] postfix_lengths = []
prefix_lengths = [] prefix_lengths = []
prefix_offsets = [] prefix_offsets = []
read_offsets = [] read_offsets = []
stopping_criterias = [] stopping_criterias = []
top_n_tokens = [] top_n_tokens = []
adapter_set = set() adapter_set = set()
@ -595,14 +597,15 @@ class FlashCausalLMBatch(Batch):
requests.append(self.requests[idx]) requests.append(self.requests[idx])
# Get length # Get length
request_input_length = self.postfix_lengths[idx] request_postfix_length = self.postfix_lengths[idx]
prefix_length = self.prefix_lengths[idx] prefix_length = self.prefix_lengths[idx]
max_seqlen = max(max_seqlen, request_input_length) max_seqlen = max(max_seqlen, request_postfix_length)
all_input_ids.append(self.all_input_ids[idx]) all_input_ids.append(self.all_input_ids[idx])
prefix_ids.append(self.prefix_ids[idx]) prefix_ids.append(self.prefix_ids[idx])
postfix_lengths.append(request_input_length) prompt_lengths.append(self.prompt_lengths[idx])
postfix_lengths.append(request_postfix_length)
prefix_lengths.append(prefix_length) prefix_lengths.append(prefix_length)
prefix_offsets.append(self.prefix_offsets[idx]) prefix_offsets.append(self.prefix_offsets[idx])
read_offsets.append(self.read_offsets[idx]) read_offsets.append(self.read_offsets[idx])
@ -626,17 +629,17 @@ class FlashCausalLMBatch(Batch):
start_slots.append(cumulative_max_length) start_slots.append(cumulative_max_length)
# Copy to tensor (CPU) # Copy to tensor (CPU)
slot_indices[i] = cumulative_max_length + request_input_length - 1 slot_indices[i] = cumulative_max_length + request_postfix_length - 1
# Set slice # Set slice
slot_filtering_indices[ slot_filtering_indices[
self.start_slots[idx] : self.start_slots[idx] self.start_slots[idx] : self.start_slots[idx]
+ request_input_length + request_postfix_length
+ remaining_tokens + remaining_tokens
- 1 - 1
] = True ] = True
cumulative_max_length += request_input_length + remaining_tokens - 1 cumulative_max_length += request_postfix_length + remaining_tokens - 1
max_blocks = max(max_blocks, len(request_block_table)) max_blocks = max(max_blocks, len(request_block_table))
@ -647,6 +650,7 @@ class FlashCausalLMBatch(Batch):
all_input_ids_tensor = self.all_input_ids_tensor[indices] all_input_ids_tensor = self.all_input_ids_tensor[indices]
block_tables_tensor = self.block_tables_tensor[indices] block_tables_tensor = self.block_tables_tensor[indices]
postfix_lengths_tensor = self.postfix_lengths_tensor[indices] postfix_lengths_tensor = self.postfix_lengths_tensor[indices]
prompt_lengths_tensor = self.prompt_lengths_tensor[indices]
slots = self.slots[slot_filtering_indices] slots = self.slots[slot_filtering_indices]
prefix_lengths_tensor = self.prefix_lengths_tensor[indices] prefix_lengths_tensor = self.prefix_lengths_tensor[indices]
next_token_chooser = self.next_token_chooser.filter(indices) next_token_chooser = self.next_token_chooser.filter(indices)
@ -683,6 +687,8 @@ class FlashCausalLMBatch(Batch):
prefill_head_indices=None, prefill_head_indices=None,
prefill_next_token_indices=None, prefill_next_token_indices=None,
prefill_cu_outlens=None, prefill_cu_outlens=None,
prompt_lengths=prompt_lengths,
prompt_lengths_tensor=prompt_lengths_tensor,
postfix_lengths=postfix_lengths, postfix_lengths=postfix_lengths,
postfix_lengths_tensor=postfix_lengths_tensor, postfix_lengths_tensor=postfix_lengths_tensor,
prefix_lengths=prefix_lengths, prefix_lengths=prefix_lengths,
@ -732,12 +738,13 @@ class FlashCausalLMBatch(Batch):
max_length = max( max_length = max(
max_length, max_length,
max( max(
input_length prefix_length
+ postfix_length
+ stopping_criteria.max_new_tokens + stopping_criteria.max_new_tokens
+ speculative_length + speculative_length
- stopping_criteria.current_tokens - stopping_criteria.current_tokens
for input_length, stopping_criteria in zip( for prefix_length, postfix_length, stopping_criteria in zip(
b.postfix_lengths, b.stopping_criterias b.prefix_lengths, b.postfix_lengths, b.stopping_criterias
) )
), ),
) )
@ -746,6 +753,9 @@ class FlashCausalLMBatch(Batch):
position_ids = batches[0].position_ids.new_empty(total_batch_size) position_ids = batches[0].position_ids.new_empty(total_batch_size)
slots = batches[0].slots.new_empty(total_slots) slots = batches[0].slots.new_empty(total_slots)
slot_indices = batches[0].slot_indices.new_empty(total_batch_size) slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
prompt_lengths_tensor = batches[0].prompt_lengths_tensor.new_empty(
total_batch_size
)
postfix_lengths_tensor = batches[0].postfix_lengths_tensor.new_empty( postfix_lengths_tensor = batches[0].postfix_lengths_tensor.new_empty(
total_batch_size total_batch_size
) )
@ -776,6 +786,7 @@ class FlashCausalLMBatch(Batch):
all_input_ids = [] all_input_ids = []
prefix_ids = [] prefix_ids = []
prompt_lengths = []
postfix_lengths = [] postfix_lengths = []
prefix_offsets = [] prefix_offsets = []
read_offsets = [] read_offsets = []
@ -809,6 +820,7 @@ class FlashCausalLMBatch(Batch):
input_ids[start_index:end_index] = batch.input_ids input_ids[start_index:end_index] = batch.input_ids
position_ids[start_index:end_index] = batch.position_ids position_ids[start_index:end_index] = batch.position_ids
slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
prompt_lengths_tensor[start_index:end_index] = batch.prompt_lengths_tensor
postfix_lengths_tensor[start_index:end_index] = batch.postfix_lengths_tensor postfix_lengths_tensor[start_index:end_index] = batch.postfix_lengths_tensor
top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
slots[slots_start_index:slots_end_index] = batch.slots slots[slots_start_index:slots_end_index] = batch.slots
@ -845,6 +857,7 @@ class FlashCausalLMBatch(Batch):
all_input_ids.extend(batch.all_input_ids) all_input_ids.extend(batch.all_input_ids)
prefix_ids.extend(batch.prefix_ids) prefix_ids.extend(batch.prefix_ids)
prompt_lengths.extend(batch.prompt_lengths)
postfix_lengths.extend(batch.postfix_lengths) postfix_lengths.extend(batch.postfix_lengths)
prefix_offsets.extend(batch.prefix_offsets) prefix_offsets.extend(batch.prefix_offsets)
read_offsets.extend(batch.read_offsets) read_offsets.extend(batch.read_offsets)
@ -898,6 +911,8 @@ class FlashCausalLMBatch(Batch):
prefill_head_indices=None, prefill_head_indices=None,
prefill_next_token_indices=None, prefill_next_token_indices=None,
prefill_cu_outlens=None, prefill_cu_outlens=None,
prompt_lengths=prompt_lengths,
prompt_lengths_tensor=prompt_lengths_tensor,
postfix_lengths=postfix_lengths, postfix_lengths=postfix_lengths,
postfix_lengths_tensor=postfix_lengths_tensor, postfix_lengths_tensor=postfix_lengths_tensor,
prefix_offsets=prefix_offsets, prefix_offsets=prefix_offsets,