fix filter and concat

commit a85f5ebecd
parent e4f9110e14
Author: OlivierDehaene
Date:   2024-09-25 15:34:08 +02:00


@@ -573,11 +573,13 @@ class FlashCausalLMBatch(Batch):
         all_input_ids = []
         prefix_ids = []
+        prompt_lengths = []
         postfix_lengths = []
         prefix_lengths = []
         prefix_offsets = []
         read_offsets = []
         stopping_criterias = []
         top_n_tokens = []
         adapter_set = set()
@@ -595,14 +597,15 @@ class FlashCausalLMBatch(Batch):
             requests.append(self.requests[idx])
             # Get length
-            request_input_length = self.postfix_lengths[idx]
+            request_postfix_length = self.postfix_lengths[idx]
             prefix_length = self.prefix_lengths[idx]
-            max_seqlen = max(max_seqlen, request_input_length)
+            max_seqlen = max(max_seqlen, request_postfix_length)
             all_input_ids.append(self.all_input_ids[idx])
             prefix_ids.append(self.prefix_ids[idx])
-            postfix_lengths.append(request_input_length)
+            prompt_lengths.append(self.prompt_lengths[idx])
+            postfix_lengths.append(request_postfix_length)
             prefix_lengths.append(prefix_length)
             prefix_offsets.append(self.prefix_offsets[idx])
             read_offsets.append(self.read_offsets[idx])
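
This is the gather step of filter(): each kept request re-appends its bookkeeping under the clearer request_postfix_length name, and prompt_lengths now travels with the rest. A minimal standalone sketch of the parallel-list gather, with illustrative names and values rather than the real FlashCausalLMBatch attributes:

    # Minimal sketch of the gather pattern used in filter(): rebuild each
    # per-request metadata list for the kept indices. Names and values are
    # illustrative, not the actual FlashCausalLMBatch attributes.
    from typing import List, Sequence

    def gather_lists(indices: Sequence[int], *lists: List) -> List[List]:
        """Select the entries at `indices` from each parallel list."""
        return [[lst[i] for i in indices] for lst in lists]

    prompt_lengths = [10, 7, 12]
    postfix_lengths = [3, 5, 2]
    prefix_lengths = [4, 0, 6]

    kept = [0, 2]  # requests that are still running
    prompt_lengths, postfix_lengths, prefix_lengths = gather_lists(
        kept, prompt_lengths, postfix_lengths, prefix_lengths
    )
    assert prompt_lengths == [10, 12] and postfix_lengths == [3, 2]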
@@ -626,17 +629,17 @@ class FlashCausalLMBatch(Batch):
             start_slots.append(cumulative_max_length)
             # Copy to tensor (CPU)
-            slot_indices[i] = cumulative_max_length + request_input_length - 1
+            slot_indices[i] = cumulative_max_length + request_postfix_length - 1
             # Set slice
             slot_filtering_indices[
                 self.start_slots[idx] : self.start_slots[idx]
-                + request_input_length
+                + request_postfix_length
                 + remaining_tokens
                 - 1
             ] = True
-            cumulative_max_length += request_input_length + remaining_tokens - 1
+            cumulative_max_length += request_postfix_length + remaining_tokens - 1
             max_blocks = max(max_blocks, len(request_block_table))
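
slot_filtering_indices is a boolean mask over the flat slots tensor: each kept request flips its contiguous run of postfix_length + remaining_tokens - 1 slots to True so a single indexing op keeps them all. A small sketch of the pattern, with made-up sizes:

    # Sketch of the slot_filtering_indices idea: each request owns a contiguous
    # run of slots, kept requests flip their run to True, and one boolean index
    # then selects every surviving slot. Sizes are assumptions.
    import torch

    total_slots = 12
    slots = torch.arange(total_slots)
    slot_filtering_indices = torch.zeros(total_slots, dtype=torch.bool)

    # (start_slot, postfix_length, remaining_tokens) per kept request
    kept = [(0, 2, 3), (7, 1, 2)]
    for start, postfix_length, remaining in kept:
        slot_filtering_indices[start : start + postfix_length + remaining - 1] = True

    print(slots[slot_filtering_indices])  # tensor([0, 1, 2, 3, 7, 8])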
@@ -647,6 +650,7 @@ class FlashCausalLMBatch(Batch):
         all_input_ids_tensor = self.all_input_ids_tensor[indices]
         block_tables_tensor = self.block_tables_tensor[indices]
         postfix_lengths_tensor = self.postfix_lengths_tensor[indices]
+        prompt_lengths_tensor = self.prompt_lengths_tensor[indices]
         slots = self.slots[slot_filtering_indices]
         prefix_lengths_tensor = self.prefix_lengths_tensor[indices]
         next_token_chooser = self.next_token_chooser.filter(indices)
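
The tensor-side state is filtered with plain advanced indexing; the commit extends the same one-liner pattern to prompt_lengths_tensor. A hedged sketch with illustrative shapes:

    # Sketch of the tensor-side filter: advanced indexing with the kept request
    # indices selects entries/rows in one shot. Shapes are illustrative.
    import torch

    indices = [0, 2]
    prompt_lengths_tensor = torch.tensor([10, 7, 12])
    all_input_ids_tensor = torch.zeros(3, 16, dtype=torch.long)

    prompt_lengths_tensor = prompt_lengths_tensor[indices]
    all_input_ids_tensor = all_input_ids_tensor[indices]

    assert prompt_lengths_tensor.tolist() == [10, 12]
    assert all_input_ids_tensor.shape == (2, 16)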
@@ -683,6 +687,8 @@ class FlashCausalLMBatch(Batch):
             prefill_head_indices=None,
             prefill_next_token_indices=None,
             prefill_cu_outlens=None,
+            prompt_lengths=prompt_lengths,
+            prompt_lengths_tensor=prompt_lengths_tensor,
             postfix_lengths=postfix_lengths,
             postfix_lengths_tensor=postfix_lengths_tensor,
             prefix_lengths=prefix_lengths,
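
The filtered batch is rebuilt through the constructor, which now receives the two prompt-length fields. A trimmed-down sketch of how that bookkeeping might sit on the container, assuming a dataclass layout; the real FlashCausalLMBatch carries many more fields:

    # Trimmed-down sketch of the length bookkeeping, assuming a dataclass-style
    # container like FlashCausalLMBatch. Field comments are interpretive.
    from dataclasses import dataclass
    from typing import List

    import torch

    @dataclass
    class BatchSketch:
        prompt_lengths: List[int]            # full prompt size per request
        prompt_lengths_tensor: torch.Tensor
        postfix_lengths: List[int]           # tokens still left to process
        postfix_lengths_tensor: torch.Tensor
        prefix_lengths: List[int]            # tokens already in the KV cache
        prefix_lengths_tensor: torch.Tensor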
@@ -732,12 +738,13 @@ class FlashCausalLMBatch(Batch):
             max_length = max(
                 max_length,
                 max(
-                    input_length
+                    prefix_length
+                    + postfix_length
                     + stopping_criteria.max_new_tokens
                     + speculative_length
                     - stopping_criteria.current_tokens
-                    for input_length, stopping_criteria in zip(
-                        b.postfix_lengths, b.stopping_criterias
+                    for prefix_length, postfix_length, stopping_criteria in zip(
+                        b.prefix_lengths, b.postfix_lengths, b.stopping_criterias
                     )
                 ),
             )
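
This is the substantive concat fix: the padded upper bound previously counted only the postfix length, so a request with a cached prefix undersized the merged buffers. A worked example with illustrative numbers:

    # Worked example of the corrected bound. Before the fix only the postfix
    # counted toward max_length. Numbers are illustrative.
    prefix_length = 48       # tokens already in the KV cache
    postfix_length = 16      # tokens left to prefill
    max_new_tokens = 64
    speculative_length = 2
    current_tokens = 0

    max_length = (
        prefix_length
        + postfix_length
        + max_new_tokens
        + speculative_length
        - current_tokens
    )
    print(max_length)  # 130, versus 82 with the old postfix-only formula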
@@ -746,6 +753,9 @@ class FlashCausalLMBatch(Batch):
         position_ids = batches[0].position_ids.new_empty(total_batch_size)
         slots = batches[0].slots.new_empty(total_slots)
         slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
+        prompt_lengths_tensor = batches[0].prompt_lengths_tensor.new_empty(
+            total_batch_size
+        )
         postfix_lengths_tensor = batches[0].postfix_lengths_tensor.new_empty(
             total_batch_size
         )
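
concatenate() preallocates the merged tensors with new_empty on a tensor from the first batch, inheriting its dtype and device; prompt_lengths_tensor now joins that preallocation. A minimal sketch:

    # Sketch of the preallocation pattern: new_empty inherits dtype and device
    # from an existing tensor, so the merged buffer lands on the right device
    # without an explicit device argument. Sizes are made up.
    import torch

    batch0_prompt_lengths = torch.tensor([10, 7])
    total_batch_size = 5  # sum of the batch sizes being merged

    prompt_lengths_tensor = batch0_prompt_lengths.new_empty(total_batch_size)
    assert prompt_lengths_tensor.dtype == batch0_prompt_lengths.dtype
    assert prompt_lengths_tensor.shape == (total_batch_size,)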
@@ -776,6 +786,7 @@ class FlashCausalLMBatch(Batch):
         all_input_ids = []
         prefix_ids = []
+        prompt_lengths = []
         postfix_lengths = []
         prefix_offsets = []
         read_offsets = []
@@ -809,6 +820,7 @@ class FlashCausalLMBatch(Batch):
             input_ids[start_index:end_index] = batch.input_ids
             position_ids[start_index:end_index] = batch.position_ids
             slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
+            prompt_lengths_tensor[start_index:end_index] = batch.prompt_lengths_tensor
             postfix_lengths_tensor[start_index:end_index] = batch.postfix_lengths_tensor
             top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
             slots[slots_start_index:slots_end_index] = batch.slots
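
Each source batch then copies its span into the preallocated buffers at a running offset; prompt_lengths_tensor joins this copy loop. A sketch of the slice-copy merge, with illustrative values:

    # Sketch of the slice-copy merge: each source batch writes its span of the
    # preallocated tensor at a running offset. Values are illustrative.
    import torch

    per_batch = [torch.tensor([10, 7]), torch.tensor([12, 3, 5])]
    total = sum(len(b) for b in per_batch)
    prompt_lengths_tensor = per_batch[0].new_empty(total)

    cumulative_batch_size = 0
    for b in per_batch:
        start, end = cumulative_batch_size, cumulative_batch_size + len(b)
        prompt_lengths_tensor[start:end] = b
        cumulative_batch_size = end

    assert prompt_lengths_tensor.tolist() == [10, 7, 12, 3, 5]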
@@ -845,6 +857,7 @@ class FlashCausalLMBatch(Batch):
             all_input_ids.extend(batch.all_input_ids)
             prefix_ids.extend(batch.prefix_ids)
+            prompt_lengths.extend(batch.prompt_lengths)
             postfix_lengths.extend(batch.postfix_lengths)
             prefix_offsets.extend(batch.prefix_offsets)
             read_offsets.extend(batch.read_offsets)
@@ -898,6 +911,8 @@ class FlashCausalLMBatch(Batch):
             prefill_head_indices=None,
             prefill_next_token_indices=None,
             prefill_cu_outlens=None,
+            prompt_lengths=prompt_lengths,
+            prompt_lengths_tensor=prompt_lengths_tensor,
             postfix_lengths=postfix_lengths,
             postfix_lengths_tensor=postfix_lengths_tensor,
             prefix_offsets=prefix_offsets,