mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
fix filter and concat
This commit is contained in:
parent
e4f9110e14
commit
a85f5ebecd
@ -573,11 +573,13 @@ class FlashCausalLMBatch(Batch):
|
||||
all_input_ids = []
|
||||
prefix_ids = []
|
||||
|
||||
prompt_lengths = []
|
||||
postfix_lengths = []
|
||||
prefix_lengths = []
|
||||
prefix_offsets = []
|
||||
read_offsets = []
|
||||
|
||||
|
||||
stopping_criterias = []
|
||||
top_n_tokens = []
|
||||
adapter_set = set()
|
||||
@ -595,14 +597,15 @@ class FlashCausalLMBatch(Batch):
|
||||
requests.append(self.requests[idx])
|
||||
|
||||
# Get length
|
||||
request_input_length = self.postfix_lengths[idx]
|
||||
request_postfix_length = self.postfix_lengths[idx]
|
||||
prefix_length = self.prefix_lengths[idx]
|
||||
max_seqlen = max(max_seqlen, request_input_length)
|
||||
max_seqlen = max(max_seqlen, request_postfix_length)
|
||||
|
||||
all_input_ids.append(self.all_input_ids[idx])
|
||||
prefix_ids.append(self.prefix_ids[idx])
|
||||
|
||||
postfix_lengths.append(request_input_length)
|
||||
prompt_lengths.append(self.prompt_lengths[idx])
|
||||
postfix_lengths.append(request_postfix_length)
|
||||
prefix_lengths.append(prefix_length)
|
||||
prefix_offsets.append(self.prefix_offsets[idx])
|
||||
read_offsets.append(self.read_offsets[idx])
|
||||
@ -626,17 +629,17 @@ class FlashCausalLMBatch(Batch):
|
||||
start_slots.append(cumulative_max_length)
|
||||
|
||||
# Copy to tensor (CPU)
|
||||
slot_indices[i] = cumulative_max_length + request_input_length - 1
|
||||
slot_indices[i] = cumulative_max_length + request_postfix_length - 1
|
||||
|
||||
# Set slice
|
||||
slot_filtering_indices[
|
||||
self.start_slots[idx] : self.start_slots[idx]
|
||||
+ request_input_length
|
||||
+ request_postfix_length
|
||||
+ remaining_tokens
|
||||
- 1
|
||||
] = True
|
||||
|
||||
cumulative_max_length += request_input_length + remaining_tokens - 1
|
||||
cumulative_max_length += request_postfix_length + remaining_tokens - 1
|
||||
|
||||
max_blocks = max(max_blocks, len(request_block_table))
|
||||
|
||||
@ -647,6 +650,7 @@ class FlashCausalLMBatch(Batch):
|
||||
all_input_ids_tensor = self.all_input_ids_tensor[indices]
|
||||
block_tables_tensor = self.block_tables_tensor[indices]
|
||||
postfix_lengths_tensor = self.postfix_lengths_tensor[indices]
|
||||
prompt_lengths_tensor = self.prompt_lengths_tensor[indices]
|
||||
slots = self.slots[slot_filtering_indices]
|
||||
prefix_lengths_tensor = self.prefix_lengths_tensor[indices]
|
||||
next_token_chooser = self.next_token_chooser.filter(indices)
|
||||
@ -683,6 +687,8 @@ class FlashCausalLMBatch(Batch):
|
||||
prefill_head_indices=None,
|
||||
prefill_next_token_indices=None,
|
||||
prefill_cu_outlens=None,
|
||||
prompt_lengths=prompt_lengths,
|
||||
prompt_lengths_tensor=prompt_lengths_tensor,
|
||||
postfix_lengths=postfix_lengths,
|
||||
postfix_lengths_tensor=postfix_lengths_tensor,
|
||||
prefix_lengths=prefix_lengths,
|
||||
@ -732,12 +738,13 @@ class FlashCausalLMBatch(Batch):
|
||||
max_length = max(
|
||||
max_length,
|
||||
max(
|
||||
input_length
|
||||
prefix_length
|
||||
+ postfix_length
|
||||
+ stopping_criteria.max_new_tokens
|
||||
+ speculative_length
|
||||
- stopping_criteria.current_tokens
|
||||
for input_length, stopping_criteria in zip(
|
||||
b.postfix_lengths, b.stopping_criterias
|
||||
for prefix_length, postfix_length, stopping_criteria in zip(
|
||||
b.prefix_lengths, b.postfix_lengths, b.stopping_criterias
|
||||
)
|
||||
),
|
||||
)
|
||||
@ -746,6 +753,9 @@ class FlashCausalLMBatch(Batch):
|
||||
position_ids = batches[0].position_ids.new_empty(total_batch_size)
|
||||
slots = batches[0].slots.new_empty(total_slots)
|
||||
slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
|
||||
prompt_lengths_tensor = batches[0].prompt_lengths_tensor.new_empty(
|
||||
total_batch_size
|
||||
)
|
||||
postfix_lengths_tensor = batches[0].postfix_lengths_tensor.new_empty(
|
||||
total_batch_size
|
||||
)
|
||||
@ -776,6 +786,7 @@ class FlashCausalLMBatch(Batch):
|
||||
all_input_ids = []
|
||||
prefix_ids = []
|
||||
|
||||
prompt_lengths = []
|
||||
postfix_lengths = []
|
||||
prefix_offsets = []
|
||||
read_offsets = []
|
||||
@ -809,6 +820,7 @@ class FlashCausalLMBatch(Batch):
|
||||
input_ids[start_index:end_index] = batch.input_ids
|
||||
position_ids[start_index:end_index] = batch.position_ids
|
||||
slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
|
||||
prompt_lengths_tensor[start_index:end_index] = batch.prompt_lengths_tensor
|
||||
postfix_lengths_tensor[start_index:end_index] = batch.postfix_lengths_tensor
|
||||
top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
|
||||
slots[slots_start_index:slots_end_index] = batch.slots
|
||||
@ -845,6 +857,7 @@ class FlashCausalLMBatch(Batch):
|
||||
all_input_ids.extend(batch.all_input_ids)
|
||||
prefix_ids.extend(batch.prefix_ids)
|
||||
|
||||
prompt_lengths.extend(batch.prompt_lengths)
|
||||
postfix_lengths.extend(batch.postfix_lengths)
|
||||
prefix_offsets.extend(batch.prefix_offsets)
|
||||
read_offsets.extend(batch.read_offsets)
|
||||
@ -898,6 +911,8 @@ class FlashCausalLMBatch(Batch):
|
||||
prefill_head_indices=None,
|
||||
prefill_next_token_indices=None,
|
||||
prefill_cu_outlens=None,
|
||||
prompt_lengths=prompt_lengths,
|
||||
prompt_lengths_tensor=prompt_lengths_tensor,
|
||||
postfix_lengths=postfix_lengths,
|
||||
postfix_lengths_tensor=postfix_lengths_tensor,
|
||||
prefix_offsets=prefix_offsets,
|
||||
|
Loading…
Reference in New Issue
Block a user