Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-12 04:44:52 +00:00)
fix filter and concat
This commit is contained in:
parent e4f9110e14
commit a85f5ebecd
@@ -573,11 +573,13 @@ class FlashCausalLMBatch(Batch):
|
|||||||
all_input_ids = []
|
all_input_ids = []
|
||||||
prefix_ids = []
|
prefix_ids = []
|
||||||
|
|
||||||
|
prompt_lengths = []
|
||||||
postfix_lengths = []
|
postfix_lengths = []
|
||||||
prefix_lengths = []
|
prefix_lengths = []
|
||||||
prefix_offsets = []
|
prefix_offsets = []
|
||||||
read_offsets = []
|
read_offsets = []
|
||||||
|
|
||||||
|
|
||||||
stopping_criterias = []
|
stopping_criterias = []
|
||||||
top_n_tokens = []
|
top_n_tokens = []
|
||||||
adapter_set = set()
|
adapter_set = set()
|
||||||
@@ -595,14 +597,15 @@ class FlashCausalLMBatch(Batch):
|
|||||||
requests.append(self.requests[idx])
|
requests.append(self.requests[idx])
|
||||||
|
|
||||||
# Get length
|
# Get length
|
||||||
request_input_length = self.postfix_lengths[idx]
|
request_postfix_length = self.postfix_lengths[idx]
|
||||||
prefix_length = self.prefix_lengths[idx]
|
prefix_length = self.prefix_lengths[idx]
|
||||||
max_seqlen = max(max_seqlen, request_input_length)
|
max_seqlen = max(max_seqlen, request_postfix_length)
|
||||||
|
|
||||||
all_input_ids.append(self.all_input_ids[idx])
|
all_input_ids.append(self.all_input_ids[idx])
|
||||||
prefix_ids.append(self.prefix_ids[idx])
|
prefix_ids.append(self.prefix_ids[idx])
|
||||||
|
|
||||||
postfix_lengths.append(request_input_length)
|
prompt_lengths.append(self.prompt_lengths[idx])
|
||||||
|
postfix_lengths.append(request_postfix_length)
|
||||||
prefix_lengths.append(prefix_length)
|
prefix_lengths.append(prefix_length)
|
||||||
prefix_offsets.append(self.prefix_offsets[idx])
|
prefix_offsets.append(self.prefix_offsets[idx])
|
||||||
read_offsets.append(self.read_offsets[idx])
|
read_offsets.append(self.read_offsets[idx])
|
||||||
@@ -626,17 +629,17 @@ class FlashCausalLMBatch(Batch):
|
|||||||
start_slots.append(cumulative_max_length)
|
start_slots.append(cumulative_max_length)
|
||||||
|
|
||||||
# Copy to tensor (CPU)
|
# Copy to tensor (CPU)
|
||||||
slot_indices[i] = cumulative_max_length + request_input_length - 1
|
slot_indices[i] = cumulative_max_length + request_postfix_length - 1
|
||||||
|
|
||||||
# Set slice
|
# Set slice
|
||||||
slot_filtering_indices[
|
slot_filtering_indices[
|
||||||
self.start_slots[idx] : self.start_slots[idx]
|
self.start_slots[idx] : self.start_slots[idx]
|
||||||
+ request_input_length
|
+ request_postfix_length
|
||||||
+ remaining_tokens
|
+ remaining_tokens
|
||||||
- 1
|
- 1
|
||||||
] = True
|
] = True
|
||||||
|
|
||||||
cumulative_max_length += request_input_length + remaining_tokens - 1
|
cumulative_max_length += request_postfix_length + remaining_tokens - 1
|
||||||
|
|
||||||
max_blocks = max(max_blocks, len(request_block_table))
|
max_blocks = max(max_blocks, len(request_block_table))
|
||||||
|
|
||||||
@@ -647,6 +650,7 @@ class FlashCausalLMBatch(Batch):
|
|||||||
all_input_ids_tensor = self.all_input_ids_tensor[indices]
|
all_input_ids_tensor = self.all_input_ids_tensor[indices]
|
||||||
block_tables_tensor = self.block_tables_tensor[indices]
|
block_tables_tensor = self.block_tables_tensor[indices]
|
||||||
postfix_lengths_tensor = self.postfix_lengths_tensor[indices]
|
postfix_lengths_tensor = self.postfix_lengths_tensor[indices]
|
||||||
|
prompt_lengths_tensor = self.prompt_lengths_tensor[indices]
|
||||||
slots = self.slots[slot_filtering_indices]
|
slots = self.slots[slot_filtering_indices]
|
||||||
prefix_lengths_tensor = self.prefix_lengths_tensor[indices]
|
prefix_lengths_tensor = self.prefix_lengths_tensor[indices]
|
||||||
next_token_chooser = self.next_token_chooser.filter(indices)
|
next_token_chooser = self.next_token_chooser.filter(indices)
|
||||||
@@ -683,6 +687,8 @@ class FlashCausalLMBatch(Batch):
|
|||||||
prefill_head_indices=None,
|
prefill_head_indices=None,
|
||||||
prefill_next_token_indices=None,
|
prefill_next_token_indices=None,
|
||||||
prefill_cu_outlens=None,
|
prefill_cu_outlens=None,
|
||||||
|
prompt_lengths=prompt_lengths,
|
||||||
|
prompt_lengths_tensor=prompt_lengths_tensor,
|
||||||
postfix_lengths=postfix_lengths,
|
postfix_lengths=postfix_lengths,
|
||||||
postfix_lengths_tensor=postfix_lengths_tensor,
|
postfix_lengths_tensor=postfix_lengths_tensor,
|
||||||
prefix_lengths=prefix_lengths,
|
prefix_lengths=prefix_lengths,
|
||||||
@@ -732,12 +738,13 @@ class FlashCausalLMBatch(Batch):
|
|||||||
max_length = max(
|
max_length = max(
|
||||||
max_length,
|
max_length,
|
||||||
max(
|
max(
|
||||||
input_length
|
prefix_length
|
||||||
|
+ postfix_length
|
||||||
+ stopping_criteria.max_new_tokens
|
+ stopping_criteria.max_new_tokens
|
||||||
+ speculative_length
|
+ speculative_length
|
||||||
- stopping_criteria.current_tokens
|
- stopping_criteria.current_tokens
|
||||||
for input_length, stopping_criteria in zip(
|
for prefix_length, postfix_length, stopping_criteria in zip(
|
||||||
b.postfix_lengths, b.stopping_criterias
|
b.prefix_lengths, b.postfix_lengths, b.stopping_criterias
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@@ -746,6 +753,9 @@ class FlashCausalLMBatch(Batch):
|
|||||||
position_ids = batches[0].position_ids.new_empty(total_batch_size)
|
position_ids = batches[0].position_ids.new_empty(total_batch_size)
|
||||||
slots = batches[0].slots.new_empty(total_slots)
|
slots = batches[0].slots.new_empty(total_slots)
|
||||||
slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
|
slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
|
||||||
|
prompt_lengths_tensor = batches[0].prompt_lengths_tensor.new_empty(
|
||||||
|
total_batch_size
|
||||||
|
)
|
||||||
postfix_lengths_tensor = batches[0].postfix_lengths_tensor.new_empty(
|
postfix_lengths_tensor = batches[0].postfix_lengths_tensor.new_empty(
|
||||||
total_batch_size
|
total_batch_size
|
||||||
)
|
)
|
||||||
@@ -776,6 +786,7 @@ class FlashCausalLMBatch(Batch):
|
|||||||
all_input_ids = []
|
all_input_ids = []
|
||||||
prefix_ids = []
|
prefix_ids = []
|
||||||
|
|
||||||
|
prompt_lengths = []
|
||||||
postfix_lengths = []
|
postfix_lengths = []
|
||||||
prefix_offsets = []
|
prefix_offsets = []
|
||||||
read_offsets = []
|
read_offsets = []
|
||||||
@@ -809,6 +820,7 @@ class FlashCausalLMBatch(Batch):
|
|||||||
input_ids[start_index:end_index] = batch.input_ids
|
input_ids[start_index:end_index] = batch.input_ids
|
||||||
position_ids[start_index:end_index] = batch.position_ids
|
position_ids[start_index:end_index] = batch.position_ids
|
||||||
slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
|
slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
|
||||||
|
prompt_lengths_tensor[start_index:end_index] = batch.prompt_lengths_tensor
|
||||||
postfix_lengths_tensor[start_index:end_index] = batch.postfix_lengths_tensor
|
postfix_lengths_tensor[start_index:end_index] = batch.postfix_lengths_tensor
|
||||||
top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
|
top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
|
||||||
slots[slots_start_index:slots_end_index] = batch.slots
|
slots[slots_start_index:slots_end_index] = batch.slots
|
||||||
@@ -845,6 +857,7 @@ class FlashCausalLMBatch(Batch):
|
|||||||
all_input_ids.extend(batch.all_input_ids)
|
all_input_ids.extend(batch.all_input_ids)
|
||||||
prefix_ids.extend(batch.prefix_ids)
|
prefix_ids.extend(batch.prefix_ids)
|
||||||
|
|
||||||
|
prompt_lengths.extend(batch.prompt_lengths)
|
||||||
postfix_lengths.extend(batch.postfix_lengths)
|
postfix_lengths.extend(batch.postfix_lengths)
|
||||||
prefix_offsets.extend(batch.prefix_offsets)
|
prefix_offsets.extend(batch.prefix_offsets)
|
||||||
read_offsets.extend(batch.read_offsets)
|
read_offsets.extend(batch.read_offsets)
|
||||||
@@ -898,6 +911,8 @@ class FlashCausalLMBatch(Batch):
|
|||||||
prefill_head_indices=None,
|
prefill_head_indices=None,
|
||||||
prefill_next_token_indices=None,
|
prefill_next_token_indices=None,
|
||||||
prefill_cu_outlens=None,
|
prefill_cu_outlens=None,
|
||||||
|
prompt_lengths=prompt_lengths,
|
||||||
|
prompt_lengths_tensor=prompt_lengths_tensor,
|
||||||
postfix_lengths=postfix_lengths,
|
postfix_lengths=postfix_lengths,
|
||||||
postfix_lengths_tensor=postfix_lengths_tensor,
|
postfix_lengths_tensor=postfix_lengths_tensor,
|
||||||
prefix_offsets=prefix_offsets,
|
prefix_offsets=prefix_offsets,
|
||||||
|
Loading…
Reference in New Issue
Block a user