fix slot_filtering_indices

OlivierDehaene 2024-10-02 19:16:36 +02:00
parent b49978ff67
commit ff4155dfea
2 changed files with 25 additions and 75 deletions

View File

@@ -175,7 +175,6 @@ pub(crate) async fn batching_task(
             let (min_size, max_size, prefill_token_budget) = if support_chunking {
                 // Since the next batch will be concatenated with the current batch,
                 // the current batch tokens must be subtracted to the prefill budget
-                // In the future, we could concatenate beforehand
                 let prefill_token_budget = max_batch_prefill_tokens - current_tokens;
                 // We can ignore min_size and max_size
                 // Models than rely on max_size cannot support chunking
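
To make the budget rule in the hunk above concrete: with chunking enabled, the scheduler only keeps a prefill token budget, obtained by subtracting the tokens already held by the running batch from the global prefill budget. A minimal sketch of that arithmetic (Python, with hypothetical numbers; the actual logic lives in the Rust router):

# Minimal sketch of the chunking budget rule, with hypothetical numbers.
max_batch_prefill_tokens = 4096   # global prefill budget
current_tokens = 1536             # tokens already held by the running batch

# The next batch is concatenated with the current one, so the tokens already
# scheduled must be subtracted from the prefill budget.
prefill_token_budget = max_batch_prefill_tokens - current_tokens
assert prefill_token_budget == 2560

# With chunking, min_size/max_size are ignored: a prompt that does not fit in
# the remaining budget is simply split across several forward passes.
min_size, max_size = None, None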

View File

@@ -138,9 +138,6 @@ class FlashCausalLMBatch(Batch):
     speculative_ids: Optional[torch.Tensor]

     # Set when creating the batch
-    # CPU tensor of length b indicating the start of each sequence in slots
-    # Will be set by `generate_token` and reset after each prefill forward before staying set in decode
-    start_slots: Optional[torch.Tensor]
     # tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode
     # Will be set by `generate_token` and reset after each prefill forward before staying set in decode
     slot_indices: Optional[torch.Tensor]
@@ -417,7 +414,6 @@ class FlashCausalLMBatch(Batch):
             position_ids=None,
             cu_seqlen_prefill=None,
             prefill_cache_indices=None,
-            start_slots=None,
             slot_indices=None,
             slots=None,
             prefill_head_indices=None,
@@ -462,12 +458,11 @@ class FlashCausalLMBatch(Batch):
         )

         # Create on CPU to only move to GPU once instead of at every copy
-        # slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
+        slot_indices = torch.empty(len(request_ids), dtype=torch.int64)

         max_postfix_length = 0
         max_current_length = 0

         requests = []
-        # start_slots = []
         block_tables = []
         all_input_ids = []
         prefix_ids = []
@@ -491,30 +486,18 @@ class FlashCausalLMBatch(Batch):
         # Cumulative length
         cumulative_max_length = 0

-        start_slots = []
-        slots = []
-        slot_indices = []
-        cumulative_slot_tokens = 0
-
         for i, request_id in enumerate(request_ids):
             idx = self.requests_idx_mapping[request_id]
             indices.append(idx)
             requests_idx_mapping[request_id] = i

-            request = self.requests[idx]
-            requests.append(request)
+            requests.append(self.requests[idx])

             # Prefilling
             request_prefilling = self.prefilling_mask[idx]
             prefilling_mask.append(request_prefilling)

-            # Input ids if the request was part of a prefilling batch
-            # If the batch was decoding we can index into the tensor directly later
-            if self.prefilling:
-                input_ids.append(self.input_ids[idx])
-
             # Get length
-            request_prompt_length = self.prompt_lengths[idx]
             request_postfix_length = self.postfix_lengths[idx]
             request_prefix_length = self.prefix_lengths[idx]
             max_postfix_length = max(max_postfix_length, request_postfix_length)
@@ -525,7 +508,7 @@ class FlashCausalLMBatch(Batch):
             all_input_ids.append(self.all_input_ids[idx])
             prefix_ids.append(self.prefix_ids[idx])

-            prompt_lengths.append(request_prompt_length)
+            prompt_lengths.append(self.prompt_lengths[idx])
             postfix_lengths.append(request_postfix_length)
             prefix_lengths.append(request_prefix_length)
             prefix_offsets.append(self.prefix_offsets[idx])
@@ -541,45 +524,31 @@ class FlashCausalLMBatch(Batch):
             adapter_index = ADAPTER_TO_INDEX.get(self.requests[idx].adapter_id, 0)
             adapter_set.add(adapter_index)

-            # remaining_tokens = (
-            #     stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
-            # )
-
             request_block_table = self.block_tables[idx]
             num_blocks += len(request_block_table)
             block_tables.append(request_block_table)
-            # start_slots.append(cumulative_max_length)

-            # Copy to tensor (CPU)
-            # slot_indices[i] = cumulative_max_length + request_postfix_length - 1
-
-            # Set slice
-            #FIXME
-            # slot_filtering_indices[
-            #     self.start_slots[idx] : self.start_slots[idx]
-            #     + request_postfix_length
-            #     + remaining_tokens
-            #     - 1
-            # ] = True
-
-            if not self.prefilling:
-                if not request.slots:
-                    request_slots = [
-                        s
-                        for b in request_block_table
-                        for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
-                    ]
-                else:
-                    request_slots = request.slots
-
-                request_slots = request_slots[request_prefix_length:]
-
-                start_slots.append(cumulative_slot_tokens)
-                slots.extend(request_slots)
-                slot_indices.append(cumulative_slot_tokens)
-                cumulative_slot_tokens += len(request_slots)
-
-            # cumulative_max_length += request_postfix_length + remaining_tokens - 1
+            # Input ids if the request was part of a prefilling batch
+            # If the batch was decoding we can index into the tensor directly later
+            if self.prefilling:
+                input_ids.append(self.input_ids[idx])
+            else:
+                # Copy to tensor (CPU)
+                slot_indices[i] = cumulative_max_length
+
+                remaining_tokens = (
+                    stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
+                )
+
+                # Set slice
+                slot_filtering_indices[
+                    self.slot_indices[idx] : self.slot_indices[idx]
+                    + request_postfix_length
+                    + remaining_tokens
+                    - 1
+                ] = True
+
+                cumulative_max_length += request_postfix_length + remaining_tokens - 1

             max_blocks = max(max_blocks, len(request_block_table))
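
The hunk above is the core of the fix: instead of rebuilding each request's slots from its block table, filtering marks the span of slots a surviving request still owns (its current postfix plus the tokens it may still generate) in a boolean slot_filtering_indices mask, and records where that span will start in the filtered tensor. A self-contained sketch of the technique, with made-up sizes and simplified names rather than the actual batch fields:

import torch

# Hypothetical flattened slots for a 3-request batch (values are KV-cache slot ids).
slots = torch.arange(12)
# Hypothetical start offset of each request's span inside `slots`.
slot_indices = torch.tensor([0, 5, 9])
# Slots each request may still use: current postfix + remaining new tokens - 1.
span_lengths = [3, 2, 3]

# Keep only requests 0 and 2, mimicking slot_filtering_indices in filter().
keep = [0, 2]
slot_filtering_indices = torch.zeros(len(slots), dtype=torch.bool)
new_slot_indices = torch.empty(len(keep), dtype=torch.int64)

cumulative = 0
for i, idx in enumerate(keep):
    start = slot_indices[idx].item()
    slot_filtering_indices[start : start + span_lengths[idx]] = True
    # The request's span starts at `cumulative` in the filtered tensor.
    new_slot_indices[i] = cumulative
    cumulative += span_lengths[idx]

filtered_slots = slots[slot_filtering_indices]
print(filtered_slots)      # tensor([ 0,  1,  2,  9, 10, 11])
print(new_slot_indices)    # tensor([0, 3])
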
@ -595,28 +564,22 @@ class FlashCausalLMBatch(Batch):
if self.prefilling: if self.prefilling:
# These values will be set by `FlashCausalLMBatch.prepare_for_prefill` # These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
position_ids = None position_ids = None
start_slots = None
slot_indices = None slot_indices = None
slots = None slots = None
prefix_lengths_tensor = None prefix_lengths_tensor = None
postfix_lengths_tensor = None postfix_lengths_tensor = None
adapter_meta = None adapter_meta = None
else: else:
slot_indices = torch.tensor(slot_indices, dtype=torch.int64, device=device)
slots = torch.tensor(slots, dtype=torch.int64, device=device)
# Index into tensors # Index into tensors
input_ids = self.input_ids[indices] input_ids = self.input_ids[indices]
position_ids = self.position_ids[indices] position_ids = self.position_ids[indices]
adapter_indices = self.adapter_meta.adapter_indices[indices] adapter_indices = self.adapter_meta.adapter_indices[indices]
postfix_lengths_tensor = self.postfix_lengths_tensor[indices] postfix_lengths_tensor = self.postfix_lengths_tensor[indices]
# slots = self.slots[slot_filtering_indices] slots = self.slots[slot_filtering_indices]
prefix_lengths_tensor = self.prefix_lengths_tensor[indices] prefix_lengths_tensor = self.prefix_lengths_tensor[indices]
start_slots = torch.tensor(start_slots, dtype=torch.int64)
# Move to GPU now that we have the whole tensor # Move to GPU now that we have the whole tensor
# slot_indices = slot_indices.to(device) slot_indices = slot_indices.to(device)
adapter_segments, adapter_segment_indices = find_segments(adapter_indices) adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
adapter_segments = torch.tensor( adapter_segments = torch.tensor(
@@ -637,7 +600,6 @@ class FlashCausalLMBatch(Batch):
             position_ids=position_ids,
             cu_seqlen_prefill=None,
             prefill_cache_indices=None,
-            start_slots=start_slots,
             slot_indices=slot_indices,
             block_tables=block_tables,
             block_tables_tensor=block_tables_tensor,
@@ -715,7 +677,6 @@ class FlashCausalLMBatch(Batch):
             input_ids = []
             # These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
             position_ids = None
-            start_slots = None
             slots = None
             slot_indices = None
             prefix_lengths_tensor = None
@@ -725,7 +686,6 @@ class FlashCausalLMBatch(Batch):
         else:
             input_ids = batches[0].input_ids.new_empty(total_batch_size)
             position_ids = batches[0].position_ids.new_empty(total_batch_size)
-            start_slots = []
             slots = batches[0].slots.new_empty(total_slots)
             slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
             postfix_lengths_tensor = batches[0].postfix_lengths_tensor.new_empty(
@@ -836,8 +796,6 @@ class FlashCausalLMBatch(Batch):
                 batch.prefix_lengths_tensor
             )

-            start_slots.append(batch.start_slots + cumulative_slots)
-
             # Update
             cumulative_slots += len(batch.slots)
         else:
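
When batches are concatenated, the flattened slots tensors are copied one after another, so each batch's slot indices have to be shifted by the number of slots already copied, which is what cumulative_slots tracks above. A small sketch of that bookkeeping (hypothetical tensors; the actual copy into the preallocated tensors is elided from this hunk):

import torch

# Two hypothetical batches, each with a flattened slots tensor and per-request
# slot indices that point into *their own* slots tensor.
slots_a, slot_indices_a = torch.tensor([10, 11, 12, 13]), torch.tensor([0, 2])
slots_b, slot_indices_b = torch.tensor([40, 41, 42]), torch.tensor([0])

cumulative_slots = 0
merged_slots = []
merged_indices = []
for slots, slot_indices in [(slots_a, slot_indices_a), (slots_b, slot_indices_b)]:
    merged_slots.append(slots)
    # Shift by the slots already copied so the indices stay valid after the merge.
    merged_indices.append(slot_indices + cumulative_slots)
    # Update, as in the hunk above
    cumulative_slots += len(slots)

merged_slots = torch.cat(merged_slots)
merged_indices = torch.cat(merged_indices)
print(merged_indices)                # tensor([0, 2, 4])
print(merged_slots[merged_indices])  # tensor([10, 12, 40]), same slots as before the merge
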
@@ -867,11 +825,6 @@ class FlashCausalLMBatch(Batch):
             # Update
             cumulative_batch_size += len(batch)

-        if start_slots is not None:
-            start_slots = torch.concat(start_slots)
-
-        # assert sum(len(b) for b in block_tables) == (block_tables_tensor != 0).sum()
-
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
             next_token_chooser_parameters,
             dtype=batches[0].next_token_chooser.dtype,
@@ -903,7 +856,6 @@ class FlashCausalLMBatch(Batch):
             position_ids=position_ids,
             cu_seqlen_prefill=None,
             prefill_cache_indices=None,
-            start_slots=start_slots,
             slot_indices=slot_indices,
             block_tables=block_tables,
             block_tables_tensor=block_tables_tensor,
@@ -946,7 +898,6 @@ class FlashCausalLMBatch(Batch):
         sliding_window = get_sliding_windows()
         position_ids = []
         cu_seqlen_prefill = [0]
-        start_slots = []
         slot_indices = []
         prefill_cache_indices = []
         all_prefill_logprobs = True
@@ -1041,7 +992,6 @@ class FlashCausalLMBatch(Batch):
                 prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
                 prefill_out_cumulative_length += 1

-            start_slots.append(cumulative_slot_tokens)
             slots.extend(request_slots)
             slot_indices.append(request_slot_indices)
@@ -1058,7 +1008,6 @@ class FlashCausalLMBatch(Batch):
             cumulative_slot_tokens += len(request_slots)

         device = self.block_tables_tensor.device
-        self.start_slots = torch.tensor(start_slots, dtype=torch.int64)

         if isinstance(self.input_ids, list):
             if len(self) > 1:
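
For context on the prefill path touched above: prepare_for_prefill collects every slot a request will write during prefill, so slot_indices ends up with one entry per prefilled token (length sum of s_i, as the field docstring says), offset by the slots accumulated so far via cumulative_slot_tokens. A simplified sketch with made-up slot ids, ignoring prefix caching and the exact field names; the arange construction is an assumption for illustration, not the verbatim implementation:

import torch

# Hypothetical per-request slot lists for a prefill batch (KV-cache slot ids).
request_slots_per_req = [[3, 4, 5, 6], [20, 21, 22]]
# Tokens to prefill for each request (here: all of its slots).
prefill_lengths = [4, 3]

slots = []
slot_indices = []
cumulative_slot_tokens = 0
for request_slots, length in zip(request_slots_per_req, prefill_lengths):
    # One index per prefilled token, offset by the slots gathered so far.
    request_slot_indices = torch.arange(
        cumulative_slot_tokens, cumulative_slot_tokens + length, dtype=torch.int64
    )
    slots.extend(request_slots)
    slot_indices.append(request_slot_indices)
    cumulative_slot_tokens += len(request_slots)

slots = torch.tensor(slots, dtype=torch.int64)
slot_indices = torch.cat(slot_indices)
print(slots[slot_indices])  # tensor([ 3,  4,  5,  6, 20, 21, 22])
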
@@ -1762,6 +1711,8 @@ class FlashCausalLM(Model):
         if prefill:
             batch.prepare_for_prefill()

+            log_master(logger.info, f"Tokens in this forward: {len(batch.input_ids)}")
+
         prefill_logprobs = batch.prefill_next_token_indices is not None

         # Update adapter indices for speculative tokens (if present)