fix slot_filtering_indices

OlivierDehaene 2024-10-02 19:16:36 +02:00
parent b49978ff67
commit ff4155dfea
2 changed files with 25 additions and 75 deletions

@@ -175,7 +175,6 @@ pub(crate) async fn batching_task(
let (min_size, max_size, prefill_token_budget) = if support_chunking {
// Since the next batch will be concatenated with the current batch,
// the current batch tokens must be subtracted from the prefill budget
// In the future, we could concatenate beforehand
let prefill_token_budget = max_batch_prefill_tokens - current_tokens;
// We can ignore min_size and max_size
// Models that rely on max_size cannot support chunking
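
As a rough illustration of that budget rule, a sketch with made-up numbers (`max_batch_prefill_tokens` and `current_tokens` stand in for the router's runtime values):

# Hypothetical values, for illustration only.
max_batch_prefill_tokens = 4096  # total prefill token budget for one forward
current_tokens = 1024            # tokens already held by the running batch

# With chunking, the next batch is concatenated onto the current one, so its
# prefill budget is whatever remains after the current batch's tokens.
prefill_token_budget = max_batch_prefill_tokens - current_tokens
assert prefill_token_budget == 3072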


@@ -138,9 +138,6 @@ class FlashCausalLMBatch(Batch):
speculative_ids: Optional[torch.Tensor]
# Set when creating the batch
# CPU tensor of length b indicating the start of each sequence in slots
# Will be set by `generate_token`; reset after each prefill forward, then stays set in decode
start_slots: Optional[torch.Tensor]
# tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode
# Will be set by `generate_token`; reset after each prefill forward, then stays set in decode
slot_indices: Optional[torch.Tensor]
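
A toy example of the length convention in the comment above, assuming a batch of b = 2 sequences with 3 and 2 tokens to process (made-up numbers):

import torch

postfix_lengths = [3, 2]
# Prefill: one slot index per processed token across the batch -> length sum(s_i) = 5.
prefill_slot_indices = torch.arange(sum(postfix_lengths), dtype=torch.int64)
# Decode: a single slot index per sequence -> length b = 2 (here, each request's last used slot).
decode_slot_indices = torch.tensor([2, 4], dtype=torch.int64)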
@@ -417,7 +414,6 @@ class FlashCausalLMBatch(Batch):
position_ids=None,
cu_seqlen_prefill=None,
prefill_cache_indices=None,
start_slots=None,
slot_indices=None,
slots=None,
prefill_head_indices=None,
@@ -462,12 +458,11 @@ class FlashCausalLMBatch(Batch):
)
# Create on CPU to only move to GPU once instead of at every copy
# slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
max_postfix_length = 0
max_current_length = 0
requests = []
# start_slots = []
block_tables = []
all_input_ids = []
prefix_ids = []
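
The "create on CPU" comment refers to a common PyTorch pattern: fill an index tensor element by element on the host and transfer it to the device once, rather than paying a host-to-device copy per assignment. A generic sketch (the loop body is invented for illustration):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
num_requests = 4

slot_indices = torch.empty(num_requests, dtype=torch.int64)  # CPU tensor
offset = 0
for i in range(num_requests):
    slot_indices[i] = offset   # cheap CPU writes inside the loop
    offset += 3                # pretend each request keeps 3 slots
slot_indices = slot_indices.to(device)  # single transfer once the tensor is complete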
@@ -491,30 +486,18 @@ class FlashCausalLMBatch(Batch):
# Cumulative length
cumulative_max_length = 0
start_slots = []
slots = []
slot_indices = []
cumulative_slot_tokens = 0
for i, request_id in enumerate(request_ids):
idx = self.requests_idx_mapping[request_id]
indices.append(idx)
requests_idx_mapping[request_id] = i
request = self.requests[idx]
requests.append(request)
requests.append(self.requests[idx])
# Prefilling
request_prefilling = self.prefilling_mask[idx]
prefilling_mask.append(request_prefilling)
# Input ids if the request was part of a prefilling batch
# If the batch was decoding we can index into the tensor directly later
if self.prefilling:
input_ids.append(self.input_ids[idx])
# Get length
request_prompt_length = self.prompt_lengths[idx]
request_postfix_length = self.postfix_lengths[idx]
request_prefix_length = self.prefix_lengths[idx]
max_postfix_length = max(max_postfix_length, request_postfix_length)
@@ -525,7 +508,7 @@ class FlashCausalLMBatch(Batch):
all_input_ids.append(self.all_input_ids[idx])
prefix_ids.append(self.prefix_ids[idx])
prompt_lengths.append(request_prompt_length)
prompt_lengths.append(self.prompt_lengths[idx])
postfix_lengths.append(request_postfix_length)
prefix_lengths.append(request_prefix_length)
prefix_offsets.append(self.prefix_offsets[idx])
@@ -541,45 +524,31 @@ class FlashCausalLMBatch(Batch):
adapter_index = ADAPTER_TO_INDEX.get(self.requests[idx].adapter_id, 0)
adapter_set.add(adapter_index)
# remaining_tokens = (
# stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
# )
request_block_table = self.block_tables[idx]
num_blocks += len(request_block_table)
block_tables.append(request_block_table)
# start_slots.append(cumulative_max_length)
# Input ids if the request was part of a prefilling batch
# If the batch was decoding we can index into the tensor directly later
if self.prefilling:
input_ids.append(self.input_ids[idx])
else:
# Copy to tensor (CPU)
# slot_indices[i] = cumulative_max_length + request_postfix_length - 1
slot_indices[i] = cumulative_max_length
remaining_tokens = (
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
)
# Set slice
#FIXME
# slot_filtering_indices[
# self.start_slots[idx] : self.start_slots[idx]
# + request_postfix_length
# + remaining_tokens
# - 1
# ] = True
slot_filtering_indices[
self.slot_indices[idx] : self.slot_indices[idx]
+ request_postfix_length
+ remaining_tokens
- 1
] = True
if not self.prefilling:
if not request.slots:
request_slots = [
s
for b in request_block_table
for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
]
else:
request_slots = request.slots
request_slots = request_slots[request_prefix_length:]
start_slots.append(cumulative_slot_tokens)
slots.extend(request_slots)
slot_indices.append(cumulative_slot_tokens)
cumulative_slot_tokens += len(request_slots)
# cumulative_max_length += request_postfix_length + remaining_tokens - 1
cumulative_max_length += request_postfix_length + remaining_tokens - 1
max_blocks = max(max_blocks, len(request_block_table))
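
To make the slot bookkeeping concrete, here is a hedged sketch of the two ideas in this hunk: expanding a request's block table into slot numbers, and marking which positions of the flat slots tensor survive filtering (BLOCK_SIZE, lengths, and offsets are made-up values):

import torch

BLOCK_SIZE = 4
request_block_table = [2, 5]  # toy block ids kept for one surviving request
request_slots = [
    s
    for b in request_block_table
    for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
]
# -> [8, 9, 10, 11, 20, 21, 22, 23]

# Pretend the flat batch `slots` tensor has 16 entries and this request's entries
# begin at flat position 6 (its slot_indices value); it needs
# postfix_length + remaining_tokens - 1 == 5 of them.
slot_filtering_indices = torch.zeros(16, dtype=torch.bool)
start, keep = 6, 5
slot_filtering_indices[start : start + keep] = True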
@@ -595,28 +564,22 @@ class FlashCausalLMBatch(Batch):
if self.prefilling:
# These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
position_ids = None
start_slots = None
slot_indices = None
slots = None
prefix_lengths_tensor = None
postfix_lengths_tensor = None
adapter_meta = None
else:
slot_indices = torch.tensor(slot_indices, dtype=torch.int64, device=device)
slots = torch.tensor(slots, dtype=torch.int64, device=device)
# Index into tensors
input_ids = self.input_ids[indices]
position_ids = self.position_ids[indices]
adapter_indices = self.adapter_meta.adapter_indices[indices]
postfix_lengths_tensor = self.postfix_lengths_tensor[indices]
# slots = self.slots[slot_filtering_indices]
slots = self.slots[slot_filtering_indices]
prefix_lengths_tensor = self.prefix_lengths_tensor[indices]
start_slots = torch.tensor(start_slots, dtype=torch.int64)
# Move to GPU now that we have the whole tensor
# slot_indices = slot_indices.to(device)
slot_indices = slot_indices.to(device)
adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
adapter_segments = torch.tensor(
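
The gather step above then reduces to plain fancy indexing: request-level tensors are sliced with the list of surviving positions, and the flat slots tensor with the boolean mask. A tiny sketch with toy tensors:

import torch

indices = [0, 2]                                   # surviving request positions
input_ids = torch.tensor([101, 102, 103])
filtered_input_ids = input_ids[indices]            # tensor([101, 103])

slots = torch.tensor([8, 9, 10, 11, 20, 21, 22, 23])
slot_filtering_indices = torch.tensor(
    [True, True, True, False, True, True, False, False]
)
filtered_slots = slots[slot_filtering_indices]     # tensor([ 8,  9, 10, 20, 21])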
@@ -637,7 +600,6 @@ class FlashCausalLMBatch(Batch):
position_ids=position_ids,
cu_seqlen_prefill=None,
prefill_cache_indices=None,
start_slots=start_slots,
slot_indices=slot_indices,
block_tables=block_tables,
block_tables_tensor=block_tables_tensor,
@@ -715,7 +677,6 @@ class FlashCausalLMBatch(Batch):
input_ids = []
# These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
position_ids = None
start_slots = None
slots = None
slot_indices = None
prefix_lengths_tensor = None
@@ -725,7 +686,6 @@ class FlashCausalLMBatch(Batch):
else:
input_ids = batches[0].input_ids.new_empty(total_batch_size)
position_ids = batches[0].position_ids.new_empty(total_batch_size)
start_slots = []
slots = batches[0].slots.new_empty(total_slots)
slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
postfix_lengths_tensor = batches[0].postfix_lengths_tensor.new_empty(
@@ -836,8 +796,6 @@ class FlashCausalLMBatch(Batch):
batch.prefix_lengths_tensor
)
start_slots.append(batch.start_slots + cumulative_slots)
# Update
cumulative_slots += len(batch.slots)
else:
@@ -867,11 +825,6 @@ class FlashCausalLMBatch(Batch):
# Update
cumulative_batch_size += len(batch)
if start_slots is not None:
start_slots = torch.concat(start_slots)
# assert sum(len(b) for b in block_tables) == (block_tables_tensor != 0).sum()
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
next_token_chooser_parameters,
dtype=batches[0].next_token_chooser.dtype,
@@ -903,7 +856,6 @@ class FlashCausalLMBatch(Batch):
position_ids=position_ids,
cu_seqlen_prefill=None,
prefill_cache_indices=None,
start_slots=start_slots,
slot_indices=slot_indices,
block_tables=block_tables,
block_tables_tensor=block_tables_tensor,
@@ -946,7 +898,6 @@ class FlashCausalLMBatch(Batch):
sliding_window = get_sliding_windows()
position_ids = []
cu_seqlen_prefill = [0]
start_slots = []
slot_indices = []
prefill_cache_indices = []
all_prefill_logprobs = True
@@ -1041,7 +992,6 @@ class FlashCausalLMBatch(Batch):
prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
prefill_out_cumulative_length += 1
start_slots.append(cumulative_slot_tokens)
slots.extend(request_slots)
slot_indices.append(request_slot_indices)
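
A hedged sketch of how the prefill loop above can build per-request slot indices: each request contributes an arange over its own slots, shifted by the slots already accumulated (toy slot lists; the real loop derives them from block tables):

import torch

slots, slot_indices = [], []
cumulative_slot_tokens = 0
for request_slots in ([8, 9, 10, 11], [20, 21, 22]):  # toy per-request slots
    request_slot_indices = torch.arange(
        cumulative_slot_tokens,
        cumulative_slot_tokens + len(request_slots),
        dtype=torch.int64,
    )
    slots.extend(request_slots)
    slot_indices.append(request_slot_indices)
    cumulative_slot_tokens += len(request_slots)
# slot_indices -> [tensor([0, 1, 2, 3]), tensor([4, 5, 6])]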
@@ -1058,7 +1008,6 @@ class FlashCausalLMBatch(Batch):
cumulative_slot_tokens += len(request_slots)
device = self.block_tables_tensor.device
self.start_slots = torch.tensor(start_slots, dtype=torch.int64)
if isinstance(self.input_ids, list):
if len(self) > 1:
@@ -1762,6 +1711,8 @@ class FlashCausalLM(Model):
if prefill:
batch.prepare_for_prefill()
log_master(logger.info, f"Tokens in this forward: {len(batch.input_ids)}")
prefill_logprobs = batch.prefill_next_token_indices is not None
# Update adapter indices for speculative tokens (if present)