mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
fix slot_filtering_indices
This commit is contained in:
parent
b49978ff67
commit
ff4155dfea
@ -175,7 +175,6 @@ pub(crate) async fn batching_task(
|
||||
let (min_size, max_size, prefill_token_budget) = if support_chunking {
|
||||
// Since the next batch will be concatenated with the current batch,
|
||||
// the current batch tokens must be subtracted to the prefill budget
|
||||
// In the future, we could concatenate beforehand
|
||||
let prefill_token_budget = max_batch_prefill_tokens - current_tokens;
|
||||
// We can ignore min_size and max_size
|
||||
// Models than rely on max_size cannot support chunking
|
||||
|
@ -138,9 +138,6 @@ class FlashCausalLMBatch(Batch):
|
||||
speculative_ids: Optional[torch.Tensor]
|
||||
|
||||
# Set when creating the batch
|
||||
# CPU tensor of length b indicating the start of each sequence in slots
|
||||
# Will be set by `generate_token` and reset after each prefill forward before staying set in decode
|
||||
start_slots: Optional[torch.Tensor]
|
||||
# tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode
|
||||
# Will be set by `generate_token` and reset after each prefill forward before staying set in decode
|
||||
slot_indices: Optional[torch.Tensor]
|
||||
@ -417,7 +414,6 @@ class FlashCausalLMBatch(Batch):
|
||||
position_ids=None,
|
||||
cu_seqlen_prefill=None,
|
||||
prefill_cache_indices=None,
|
||||
start_slots=None,
|
||||
slot_indices=None,
|
||||
slots=None,
|
||||
prefill_head_indices=None,
|
||||
@ -462,12 +458,11 @@ class FlashCausalLMBatch(Batch):
|
||||
)
|
||||
|
||||
# Create on CPU to only move to GPU once instead of at every copy
|
||||
# slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
|
||||
slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
|
||||
max_postfix_length = 0
|
||||
max_current_length = 0
|
||||
|
||||
requests = []
|
||||
# start_slots = []
|
||||
block_tables = []
|
||||
all_input_ids = []
|
||||
prefix_ids = []
|
||||
@ -491,30 +486,18 @@ class FlashCausalLMBatch(Batch):
|
||||
# Cumulative length
|
||||
cumulative_max_length = 0
|
||||
|
||||
start_slots = []
|
||||
slots = []
|
||||
slot_indices = []
|
||||
cumulative_slot_tokens = 0
|
||||
|
||||
for i, request_id in enumerate(request_ids):
|
||||
idx = self.requests_idx_mapping[request_id]
|
||||
indices.append(idx)
|
||||
requests_idx_mapping[request_id] = i
|
||||
|
||||
request = self.requests[idx]
|
||||
requests.append(request)
|
||||
requests.append(self.requests[idx])
|
||||
|
||||
# Prefilling
|
||||
request_prefilling = self.prefilling_mask[idx]
|
||||
prefilling_mask.append(request_prefilling)
|
||||
|
||||
# Input ids if the request was part of a prefilling batch
|
||||
# If the batch was decoding we can index into the tensor directly later
|
||||
if self.prefilling:
|
||||
input_ids.append(self.input_ids[idx])
|
||||
|
||||
# Get length
|
||||
request_prompt_length = self.prompt_lengths[idx]
|
||||
request_postfix_length = self.postfix_lengths[idx]
|
||||
request_prefix_length = self.prefix_lengths[idx]
|
||||
max_postfix_length = max(max_postfix_length, request_postfix_length)
|
||||
@ -525,7 +508,7 @@ class FlashCausalLMBatch(Batch):
|
||||
all_input_ids.append(self.all_input_ids[idx])
|
||||
prefix_ids.append(self.prefix_ids[idx])
|
||||
|
||||
prompt_lengths.append(request_prompt_length)
|
||||
prompt_lengths.append(self.prompt_lengths[idx])
|
||||
postfix_lengths.append(request_postfix_length)
|
||||
prefix_lengths.append(request_prefix_length)
|
||||
prefix_offsets.append(self.prefix_offsets[idx])
|
||||
@ -541,45 +524,31 @@ class FlashCausalLMBatch(Batch):
|
||||
adapter_index = ADAPTER_TO_INDEX.get(self.requests[idx].adapter_id, 0)
|
||||
adapter_set.add(adapter_index)
|
||||
|
||||
# remaining_tokens = (
|
||||
# stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
|
||||
# )
|
||||
|
||||
request_block_table = self.block_tables[idx]
|
||||
num_blocks += len(request_block_table)
|
||||
block_tables.append(request_block_table)
|
||||
# start_slots.append(cumulative_max_length)
|
||||
|
||||
# Copy to tensor (CPU)
|
||||
# slot_indices[i] = cumulative_max_length + request_postfix_length - 1
|
||||
# Input ids if the request was part of a prefilling batch
|
||||
# If the batch was decoding we can index into the tensor directly later
|
||||
if self.prefilling:
|
||||
input_ids.append(self.input_ids[idx])
|
||||
else:
|
||||
# Copy to tensor (CPU)
|
||||
slot_indices[i] = cumulative_max_length
|
||||
|
||||
# Set slice
|
||||
#FIXME
|
||||
# slot_filtering_indices[
|
||||
# self.start_slots[idx] : self.start_slots[idx]
|
||||
# + request_postfix_length
|
||||
# + remaining_tokens
|
||||
# - 1
|
||||
# ] = True
|
||||
remaining_tokens = (
|
||||
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
|
||||
)
|
||||
|
||||
if not self.prefilling:
|
||||
if not request.slots:
|
||||
request_slots = [
|
||||
s
|
||||
for b in request_block_table
|
||||
for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
|
||||
]
|
||||
else:
|
||||
request_slots = request.slots
|
||||
# Set slice
|
||||
slot_filtering_indices[
|
||||
self.slot_indices[idx] : self.slot_indices[idx]
|
||||
+ request_postfix_length
|
||||
+ remaining_tokens
|
||||
- 1
|
||||
] = True
|
||||
|
||||
request_slots = request_slots[request_prefix_length:]
|
||||
start_slots.append(cumulative_slot_tokens)
|
||||
slots.extend(request_slots)
|
||||
slot_indices.append(cumulative_slot_tokens)
|
||||
|
||||
cumulative_slot_tokens += len(request_slots)
|
||||
|
||||
# cumulative_max_length += request_postfix_length + remaining_tokens - 1
|
||||
cumulative_max_length += request_postfix_length + remaining_tokens - 1
|
||||
|
||||
max_blocks = max(max_blocks, len(request_block_table))
|
||||
|
||||
@ -595,28 +564,22 @@ class FlashCausalLMBatch(Batch):
|
||||
if self.prefilling:
|
||||
# These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
|
||||
position_ids = None
|
||||
start_slots = None
|
||||
slot_indices = None
|
||||
slots = None
|
||||
prefix_lengths_tensor = None
|
||||
postfix_lengths_tensor = None
|
||||
adapter_meta = None
|
||||
else:
|
||||
slot_indices = torch.tensor(slot_indices, dtype=torch.int64, device=device)
|
||||
slots = torch.tensor(slots, dtype=torch.int64, device=device)
|
||||
|
||||
# Index into tensors
|
||||
input_ids = self.input_ids[indices]
|
||||
position_ids = self.position_ids[indices]
|
||||
adapter_indices = self.adapter_meta.adapter_indices[indices]
|
||||
postfix_lengths_tensor = self.postfix_lengths_tensor[indices]
|
||||
# slots = self.slots[slot_filtering_indices]
|
||||
slots = self.slots[slot_filtering_indices]
|
||||
prefix_lengths_tensor = self.prefix_lengths_tensor[indices]
|
||||
|
||||
start_slots = torch.tensor(start_slots, dtype=torch.int64)
|
||||
|
||||
# Move to GPU now that we have the whole tensor
|
||||
# slot_indices = slot_indices.to(device)
|
||||
slot_indices = slot_indices.to(device)
|
||||
|
||||
adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
|
||||
adapter_segments = torch.tensor(
|
||||
@ -637,7 +600,6 @@ class FlashCausalLMBatch(Batch):
|
||||
position_ids=position_ids,
|
||||
cu_seqlen_prefill=None,
|
||||
prefill_cache_indices=None,
|
||||
start_slots=start_slots,
|
||||
slot_indices=slot_indices,
|
||||
block_tables=block_tables,
|
||||
block_tables_tensor=block_tables_tensor,
|
||||
@ -715,7 +677,6 @@ class FlashCausalLMBatch(Batch):
|
||||
input_ids = []
|
||||
# These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
|
||||
position_ids = None
|
||||
start_slots = None
|
||||
slots = None
|
||||
slot_indices = None
|
||||
prefix_lengths_tensor = None
|
||||
@ -725,7 +686,6 @@ class FlashCausalLMBatch(Batch):
|
||||
else:
|
||||
input_ids = batches[0].input_ids.new_empty(total_batch_size)
|
||||
position_ids = batches[0].position_ids.new_empty(total_batch_size)
|
||||
start_slots = []
|
||||
slots = batches[0].slots.new_empty(total_slots)
|
||||
slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
|
||||
postfix_lengths_tensor = batches[0].postfix_lengths_tensor.new_empty(
|
||||
@ -836,8 +796,6 @@ class FlashCausalLMBatch(Batch):
|
||||
batch.prefix_lengths_tensor
|
||||
)
|
||||
|
||||
start_slots.append(batch.start_slots + cumulative_slots)
|
||||
|
||||
# Update
|
||||
cumulative_slots += len(batch.slots)
|
||||
else:
|
||||
@ -867,11 +825,6 @@ class FlashCausalLMBatch(Batch):
|
||||
# Update
|
||||
cumulative_batch_size += len(batch)
|
||||
|
||||
if start_slots is not None:
|
||||
start_slots = torch.concat(start_slots)
|
||||
|
||||
# assert sum(len(b) for b in block_tables) == (block_tables_tensor != 0).sum()
|
||||
|
||||
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
|
||||
next_token_chooser_parameters,
|
||||
dtype=batches[0].next_token_chooser.dtype,
|
||||
@ -903,7 +856,6 @@ class FlashCausalLMBatch(Batch):
|
||||
position_ids=position_ids,
|
||||
cu_seqlen_prefill=None,
|
||||
prefill_cache_indices=None,
|
||||
start_slots=start_slots,
|
||||
slot_indices=slot_indices,
|
||||
block_tables=block_tables,
|
||||
block_tables_tensor=block_tables_tensor,
|
||||
@ -946,7 +898,6 @@ class FlashCausalLMBatch(Batch):
|
||||
sliding_window = get_sliding_windows()
|
||||
position_ids = []
|
||||
cu_seqlen_prefill = [0]
|
||||
start_slots = []
|
||||
slot_indices = []
|
||||
prefill_cache_indices = []
|
||||
all_prefill_logprobs = True
|
||||
@ -1041,7 +992,6 @@ class FlashCausalLMBatch(Batch):
|
||||
prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
|
||||
prefill_out_cumulative_length += 1
|
||||
|
||||
start_slots.append(cumulative_slot_tokens)
|
||||
slots.extend(request_slots)
|
||||
slot_indices.append(request_slot_indices)
|
||||
|
||||
@ -1058,7 +1008,6 @@ class FlashCausalLMBatch(Batch):
|
||||
cumulative_slot_tokens += len(request_slots)
|
||||
|
||||
device = self.block_tables_tensor.device
|
||||
self.start_slots = torch.tensor(start_slots, dtype=torch.int64)
|
||||
|
||||
if isinstance(self.input_ids, list):
|
||||
if len(self) > 1:
|
||||
@ -1762,6 +1711,8 @@ class FlashCausalLM(Model):
|
||||
if prefill:
|
||||
batch.prepare_for_prefill()
|
||||
|
||||
log_master(logger.info, f"Tokens in this forward: {len(batch.input_ids)}")
|
||||
|
||||
prefill_logprobs = batch.prefill_next_token_indices is not None
|
||||
|
||||
# Update adapter indices for speculative tokens (if present)
|
||||
|
Loading…
Reference in New Issue
Block a user