Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-12 04:44:52 +00:00)
fix slot_filtering_indices

commit ff4155dfea
parent b49978ff67
@@ -175,7 +175,6 @@ pub(crate) async fn batching_task(
         let (min_size, max_size, prefill_token_budget) = if support_chunking {
             // Since the next batch will be concatenated with the current batch,
             // the current batch tokens must be subtracted to the prefill budget
-            // In the future, we could concatenate beforehand
             let prefill_token_budget = max_batch_prefill_tokens - current_tokens;
             // We can ignore min_size and max_size
             // Models than rely on max_size cannot support chunking
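The comments above describe how chunked prefill sizes its budget: because the next batch is concatenated onto the batch that is already running, the running batch's tokens are deducted first. A minimal sketch of that arithmetic, in Python for illustration only (the helper name and the clamp to zero are assumptions, not part of the patch):

```python
# Hypothetical helper mirroring the Rust variables above; clamping at zero is
# an assumption for the sketch, not something the patch does.
def prefill_token_budget(max_batch_prefill_tokens: int, current_tokens: int) -> int:
    return max(max_batch_prefill_tokens - current_tokens, 0)

assert prefill_token_budget(4096, 1024) == 3072  # room left for new prefill chunks
assert prefill_token_budget(4096, 4096) == 0     # running batch already fills the budget
```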
@@ -138,9 +138,6 @@ class FlashCausalLMBatch(Batch):
     speculative_ids: Optional[torch.Tensor]

     # Set when creating the batch
-    # CPU tensor of length b indicating the start of each sequence in slots
-    # Will be set by `generate_token` and reset after each prefill forward before staying set in decode
-    start_slots: Optional[torch.Tensor]
     # tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode
     # Will be set by `generate_token` and reset after each prefill forward before staying set in decode
     slot_indices: Optional[torch.Tensor]
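The retained comment encodes a shape convention worth spelling out: during prefill `slot_indices` holds one entry per input token (length \sum_i s_i), while in decode it holds one entry per sequence (length b). A toy illustration with made-up numbers, not taken from the repository:

```python
import torch

# Two sequences with 3 and 2 input tokens: prefill indexes every token's slot...
prefill_slot_indices = torch.tensor([0, 1, 2, 10, 11])  # length = 3 + 2 = 5
# ...while decode only indexes the slot each sequence writes next.
decode_slot_indices = torch.tensor([3, 12])             # length = b = 2
```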
@@ -417,7 +414,6 @@ class FlashCausalLMBatch(Batch):
             position_ids=None,
             cu_seqlen_prefill=None,
             prefill_cache_indices=None,
-            start_slots=None,
             slot_indices=None,
             slots=None,
             prefill_head_indices=None,
@@ -462,12 +458,11 @@ class FlashCausalLMBatch(Batch):
         )

         # Create on CPU to only move to GPU once instead of at every copy
-        # slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
+        slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
         max_postfix_length = 0
         max_current_length = 0

         requests = []
-        # start_slots = []
         block_tables = []
         all_input_ids = []
         prefix_ids = []
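Re-enabling `slot_indices = torch.empty(len(request_ids), dtype=torch.int64)` restores the pattern the comment hints at: build the tensor on the CPU, write one entry per kept request inside the loop, and pay for a single device transfer at the end. A standalone sketch of that pattern with hypothetical names, not repository code:

```python
import torch

def gather_slot_indices(kept_offsets, device="cpu"):
    # Fill on CPU so each per-request write stays a cheap host-side store...
    slot_indices = torch.empty(len(kept_offsets), dtype=torch.int64)
    for i, offset in enumerate(kept_offsets):
        slot_indices[i] = offset
    # ...then move the finished tensor to the target device exactly once.
    return slot_indices.to(device)
```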
@@ -491,30 +486,18 @@ class FlashCausalLMBatch(Batch):
         # Cumulative length
         cumulative_max_length = 0

-        start_slots = []
-        slots = []
-        slot_indices = []
-        cumulative_slot_tokens = 0
-
        for i, request_id in enumerate(request_ids):
             idx = self.requests_idx_mapping[request_id]
             indices.append(idx)
             requests_idx_mapping[request_id] = i

-            request = self.requests[idx]
-            requests.append(request)
+            requests.append(self.requests[idx])

             # Prefilling
             request_prefilling = self.prefilling_mask[idx]
             prefilling_mask.append(request_prefilling)

-            # Input ids if the request was part of a prefilling batch
-            # If the batch was decoding we can index into the tensor directly later
-            if self.prefilling:
-                input_ids.append(self.input_ids[idx])
-
             # Get length
-            request_prompt_length = self.prompt_lengths[idx]
             request_postfix_length = self.postfix_lengths[idx]
             request_prefix_length = self.prefix_lengths[idx]
             max_postfix_length = max(max_postfix_length, request_postfix_length)
@@ -525,7 +508,7 @@ class FlashCausalLMBatch(Batch):
             all_input_ids.append(self.all_input_ids[idx])
             prefix_ids.append(self.prefix_ids[idx])

-            prompt_lengths.append(request_prompt_length)
+            prompt_lengths.append(self.prompt_lengths[idx])
             postfix_lengths.append(request_postfix_length)
             prefix_lengths.append(request_prefix_length)
             prefix_offsets.append(self.prefix_offsets[idx])
@@ -541,45 +524,31 @@ class FlashCausalLMBatch(Batch):
             adapter_index = ADAPTER_TO_INDEX.get(self.requests[idx].adapter_id, 0)
             adapter_set.add(adapter_index)

-            # remaining_tokens = (
-            #     stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
-            # )
-
             request_block_table = self.block_tables[idx]
             num_blocks += len(request_block_table)
             block_tables.append(request_block_table)
-            # start_slots.append(cumulative_max_length)

-            # Copy to tensor (CPU)
-            # slot_indices[i] = cumulative_max_length + request_postfix_length - 1
+            # Input ids if the request was part of a prefilling batch
+            # If the batch was decoding we can index into the tensor directly later
+            if self.prefilling:
+                input_ids.append(self.input_ids[idx])
+            else:
+                # Copy to tensor (CPU)
+                slot_indices[i] = cumulative_max_length

-            # Set slice
-            #FIXME
-            # slot_filtering_indices[
-            #     self.start_slots[idx] : self.start_slots[idx]
-            #     + request_postfix_length
-            #     + remaining_tokens
-            #     - 1
-            # ] = True
+                remaining_tokens = (
+                    stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
+                )

-            if not self.prefilling:
-                if not request.slots:
-                    request_slots = [
-                        s
-                        for b in request_block_table
-                        for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
-                    ]
-                else:
-                    request_slots = request.slots
+                # Set slice
+                slot_filtering_indices[
+                    self.slot_indices[idx] : self.slot_indices[idx]
+                    + request_postfix_length
+                    + remaining_tokens
+                    - 1
+                ] = True

-                request_slots = request_slots[request_prefix_length:]
-                start_slots.append(cumulative_slot_tokens)
-                slots.extend(request_slots)
-                slot_indices.append(cumulative_slot_tokens)
-
-                cumulative_slot_tokens += len(request_slots)
-
-            # cumulative_max_length += request_postfix_length + remaining_tokens - 1
+                cumulative_max_length += request_postfix_length + remaining_tokens - 1

             max_blocks = max(max_blocks, len(request_block_table))

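This hunk is the heart of the fix: for every kept request that is already decoding, the filter marks the span of its slots inside the flattened `self.slots` tensor, starting at `self.slot_indices[idx]` and covering the tokens generated so far plus the tokens it may still generate, while `cumulative_max_length` records where that request's slots will start in the filtered tensor. A self-contained sketch of that bookkeeping (shapes and numbers invented for the example, not repository code):

```python
import torch

slots = torch.arange(60)                                   # flattened slots of the full batch
slot_filtering_indices = torch.zeros(60, dtype=torch.bool)

# Per kept request: start of its slot range in `slots`, tokens generated so
# far (postfix), and tokens it is still allowed to generate.
kept = [
    {"start": 0, "postfix_length": 4, "remaining_tokens": 6},
    {"start": 30, "postfix_length": 2, "remaining_tokens": 3},
]

cumulative_max_length = 0
new_slot_indices = torch.empty(len(kept), dtype=torch.int64)
for i, r in enumerate(kept):
    # Where this request's slots will start once the batch is filtered.
    new_slot_indices[i] = cumulative_max_length
    end = r["start"] + r["postfix_length"] + r["remaining_tokens"] - 1
    slot_filtering_indices[r["start"]:end] = True
    cumulative_max_length += r["postfix_length"] + r["remaining_tokens"] - 1

filtered_slots = slots[slot_filtering_indices]             # only the kept requests' slots
assert len(filtered_slots) == cumulative_max_length
```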
@@ -595,28 +564,22 @@ class FlashCausalLMBatch(Batch):
         if self.prefilling:
             # These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
             position_ids = None
-            start_slots = None
             slot_indices = None
             slots = None
             prefix_lengths_tensor = None
             postfix_lengths_tensor = None
             adapter_meta = None
         else:
-            slot_indices = torch.tensor(slot_indices, dtype=torch.int64, device=device)
-            slots = torch.tensor(slots, dtype=torch.int64, device=device)
-
             # Index into tensors
             input_ids = self.input_ids[indices]
             position_ids = self.position_ids[indices]
             adapter_indices = self.adapter_meta.adapter_indices[indices]
             postfix_lengths_tensor = self.postfix_lengths_tensor[indices]
-            # slots = self.slots[slot_filtering_indices]
+            slots = self.slots[slot_filtering_indices]
            prefix_lengths_tensor = self.prefix_lengths_tensor[indices]

-            start_slots = torch.tensor(start_slots, dtype=torch.int64)
-
             # Move to GPU now that we have the whole tensor
-            # slot_indices = slot_indices.to(device)
+            slot_indices = slot_indices.to(device)

             adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
             adapter_segments = torch.tensor(
@@ -637,7 +600,6 @@ class FlashCausalLMBatch(Batch):
             position_ids=position_ids,
             cu_seqlen_prefill=None,
             prefill_cache_indices=None,
-            start_slots=start_slots,
             slot_indices=slot_indices,
             block_tables=block_tables,
             block_tables_tensor=block_tables_tensor,
@@ -715,7 +677,6 @@ class FlashCausalLMBatch(Batch):
             input_ids = []
             # These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
             position_ids = None
-            start_slots = None
             slots = None
             slot_indices = None
             prefix_lengths_tensor = None
@@ -725,7 +686,6 @@ class FlashCausalLMBatch(Batch):
         else:
             input_ids = batches[0].input_ids.new_empty(total_batch_size)
             position_ids = batches[0].position_ids.new_empty(total_batch_size)
-            start_slots = []
             slots = batches[0].slots.new_empty(total_slots)
             slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
             postfix_lengths_tensor = batches[0].postfix_lengths_tensor.new_empty(
@@ -836,8 +796,6 @@ class FlashCausalLMBatch(Batch):
                 batch.prefix_lengths_tensor
             )

-            start_slots.append(batch.start_slots + cumulative_slots)
-
             # Update
             cumulative_slots += len(batch.slots)
         else:
@@ -867,11 +825,6 @@ class FlashCausalLMBatch(Batch):
         # Update
         cumulative_batch_size += len(batch)

-        if start_slots is not None:
-            start_slots = torch.concat(start_slots)
-
-        # assert sum(len(b) for b in block_tables) == (block_tables_tensor != 0).sum()
-
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
             next_token_chooser_parameters,
             dtype=batches[0].next_token_chooser.dtype,
@@ -903,7 +856,6 @@ class FlashCausalLMBatch(Batch):
             position_ids=position_ids,
             cu_seqlen_prefill=None,
             prefill_cache_indices=None,
-            start_slots=start_slots,
             slot_indices=slot_indices,
             block_tables=block_tables,
             block_tables_tensor=block_tables_tensor,
@@ -946,7 +898,6 @@ class FlashCausalLMBatch(Batch):
         sliding_window = get_sliding_windows()
         position_ids = []
         cu_seqlen_prefill = [0]
-        start_slots = []
         slot_indices = []
         prefill_cache_indices = []
         all_prefill_logprobs = True
@@ -1041,7 +992,6 @@ class FlashCausalLMBatch(Batch):
                 prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
                 prefill_out_cumulative_length += 1

-            start_slots.append(cumulative_slot_tokens)
             slots.extend(request_slots)
             slot_indices.append(request_slot_indices)

@@ -1058,7 +1008,6 @@ class FlashCausalLMBatch(Batch):
             cumulative_slot_tokens += len(request_slots)

         device = self.block_tables_tensor.device
-        self.start_slots = torch.tensor(start_slots, dtype=torch.int64)

         if isinstance(self.input_ids, list):
             if len(self) > 1:
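For contrast with the filter path above, `prepare_for_prefill` still builds `slots` and `slot_indices` by walking each request's block table and accumulating a running offset (`cumulative_slot_tokens`); only the parallel `start_slots` bookkeeping is dropped. A simplified sketch of that accumulation (BLOCK_SIZE, the block tables, and the lengths are invented for the example):

```python
import torch

BLOCK_SIZE = 4
block_tables = [[0, 3], [7]]        # blocks owned by two requests
postfix_lengths = [5, 2]            # tokens each request is prefilling

slots, slot_indices = [], []
cumulative_slot_tokens = 0
for blocks, postfix_length in zip(block_tables, postfix_lengths):
    # Every block contributes BLOCK_SIZE consecutive slots.
    request_slots = [
        s for b in blocks for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
    ]
    # The tokens written by this prefill occupy the first positions of the
    # request's slot range, offset by everything accumulated so far.
    request_slot_indices = torch.arange(
        cumulative_slot_tokens, cumulative_slot_tokens + postfix_length
    )
    slots.extend(request_slots)
    slot_indices.append(request_slot_indices)
    cumulative_slot_tokens += len(request_slots)

slots = torch.tensor(slots, dtype=torch.int64)
slot_indices = torch.cat(slot_indices)
```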
@@ -1762,6 +1711,8 @@ class FlashCausalLM(Model):
         if prefill:
             batch.prepare_for_prefill()

+        log_master(logger.info, f"Tokens in this forward: {len(batch.input_ids)}")
+
         prefill_logprobs = batch.prefill_next_token_indices is not None

         # Update adapter indices for speculative tokens (if present)