diff --git a/backends/v3/src/backend.rs b/backends/v3/src/backend.rs
index 183f4e52..84152ff8 100644
--- a/backends/v3/src/backend.rs
+++ b/backends/v3/src/backend.rs
@@ -175,7 +175,6 @@ pub(crate) async fn batching_task(
         let (min_size, max_size, prefill_token_budget) = if support_chunking {
             // Since the next batch will be concatenated with the current batch,
             // the current batch tokens must be subtracted to the prefill budget
-            // In the future, we could concatenate beforehand
             let prefill_token_budget = max_batch_prefill_tokens - current_tokens;
             // We can ignore min_size and max_size
             // Models than rely on max_size cannot support chunking
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index cf2b6ea7..b283a5fb 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -138,9 +138,6 @@ class FlashCausalLMBatch(Batch):
     speculative_ids: Optional[torch.Tensor]
 
     # Set when creating the batch
-    # CPU tensor of length b indicating the start of each sequence in slots
-    # Will be set by `generate_token` and reset after each prefill forward before staying set in decode
-    start_slots: Optional[torch.Tensor]
     # tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode
     # Will be set by `generate_token` and reset after each prefill forward before staying set in decode
     slot_indices: Optional[torch.Tensor]
@@ -417,7 +414,6 @@ class FlashCausalLMBatch(Batch):
             position_ids=None,
             cu_seqlen_prefill=None,
             prefill_cache_indices=None,
-            start_slots=None,
             slot_indices=None,
             slots=None,
             prefill_head_indices=None,
@@ -462,12 +458,11 @@ class FlashCausalLMBatch(Batch):
         )
 
         # Create on CPU to only move to GPU once instead of at every copy
-        # slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
+        slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
         max_postfix_length = 0
         max_current_length = 0
 
         requests = []
-        # start_slots = []
         block_tables = []
         all_input_ids = []
         prefix_ids = []
@@ -491,30 +486,18 @@ class FlashCausalLMBatch(Batch):
         # Cumulative length
         cumulative_max_length = 0
 
-        start_slots = []
-        slots = []
-        slot_indices = []
-        cumulative_slot_tokens = 0
-
         for i, request_id in enumerate(request_ids):
             idx = self.requests_idx_mapping[request_id]
             indices.append(idx)
             requests_idx_mapping[request_id] = i
 
-            request = self.requests[idx]
-            requests.append(request)
+            requests.append(self.requests[idx])
 
             # Prefilling
             request_prefilling = self.prefilling_mask[idx]
             prefilling_mask.append(request_prefilling)
 
-            # Input ids if the request was part of a prefilling batch
-            # If the batch was decoding we can index into the tensor directly later
-            if self.prefilling:
-                input_ids.append(self.input_ids[idx])
-
             # Get length
-            request_prompt_length = self.prompt_lengths[idx]
             request_postfix_length = self.postfix_lengths[idx]
             request_prefix_length = self.prefix_lengths[idx]
             max_postfix_length = max(max_postfix_length, request_postfix_length)
@@ -525,7 +508,7 @@ class FlashCausalLMBatch(Batch):
             all_input_ids.append(self.all_input_ids[idx])
             prefix_ids.append(self.prefix_ids[idx])
 
-            prompt_lengths.append(request_prompt_length)
+            prompt_lengths.append(self.prompt_lengths[idx])
             postfix_lengths.append(request_postfix_length)
             prefix_lengths.append(request_prefix_length)
             prefix_offsets.append(self.prefix_offsets[idx])
@@ -541,45 +524,31 @@ class FlashCausalLMBatch(Batch):
             adapter_index = ADAPTER_TO_INDEX.get(self.requests[idx].adapter_id, 0)
             adapter_set.add(adapter_index)
 
-            # remaining_tokens = (
-            #     stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
-            # )
-
             request_block_table = self.block_tables[idx]
             num_blocks += len(request_block_table)
             block_tables.append(request_block_table)
 
-            # start_slots.append(cumulative_max_length)
-            # Copy to tensor (CPU)
-            # slot_indices[i] = cumulative_max_length + request_postfix_length - 1
+            # Input ids if the request was part of a prefilling batch
+            # If the batch was decoding we can index into the tensor directly later
+            if self.prefilling:
+                input_ids.append(self.input_ids[idx])
+            else:
+                # Copy to tensor (CPU)
+                slot_indices[i] = cumulative_max_length
 
-            # Set slice
-            #FIXME
-            # slot_filtering_indices[
-            #     self.start_slots[idx] : self.start_slots[idx]
-            #     + request_postfix_length
-            #     + remaining_tokens
-            #     - 1
-            # ] = True
+                remaining_tokens = (
+                    stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
+                )
 
-            if not self.prefilling:
-                if not request.slots:
-                    request_slots = [
-                        s
-                        for b in request_block_table
-                        for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
-                    ]
-                else:
-                    request_slots = request.slots
+                # Set slice
+                slot_filtering_indices[
+                    self.slot_indices[idx] : self.slot_indices[idx]
+                    + request_postfix_length
+                    + remaining_tokens
+                    - 1
+                ] = True
 
-                request_slots = request_slots[request_prefix_length:]
-                start_slots.append(cumulative_slot_tokens)
-                slots.extend(request_slots)
-                slot_indices.append(cumulative_slot_tokens)
-
-                cumulative_slot_tokens += len(request_slots)
-
-            # cumulative_max_length += request_postfix_length + remaining_tokens - 1
+                cumulative_max_length += request_postfix_length + remaining_tokens - 1
 
             max_blocks = max(max_blocks, len(request_block_table))
 
@@ -595,28 +564,22 @@ class FlashCausalLMBatch(Batch):
         if self.prefilling:
             # These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
             position_ids = None
-            start_slots = None
             slot_indices = None
             slots = None
            prefix_lengths_tensor = None
             postfix_lengths_tensor = None
             adapter_meta = None
         else:
-            slot_indices = torch.tensor(slot_indices, dtype=torch.int64, device=device)
-            slots = torch.tensor(slots, dtype=torch.int64, device=device)
-
             # Index into tensors
             input_ids = self.input_ids[indices]
             position_ids = self.position_ids[indices]
             adapter_indices = self.adapter_meta.adapter_indices[indices]
             postfix_lengths_tensor = self.postfix_lengths_tensor[indices]
-            # slots = self.slots[slot_filtering_indices]
+            slots = self.slots[slot_filtering_indices]
             prefix_lengths_tensor = self.prefix_lengths_tensor[indices]
 
-            start_slots = torch.tensor(start_slots, dtype=torch.int64)
-
             # Move to GPU now that we have the whole tensor
-            # slot_indices = slot_indices.to(device)
+            slot_indices = slot_indices.to(device)
 
             adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
             adapter_segments = torch.tensor(
@@ -637,7 +600,6 @@ class FlashCausalLMBatch(Batch):
             position_ids=position_ids,
             cu_seqlen_prefill=None,
             prefill_cache_indices=None,
-            start_slots=start_slots,
             slot_indices=slot_indices,
             block_tables=block_tables,
             block_tables_tensor=block_tables_tensor,
@@ -715,7 +677,6 @@ class FlashCausalLMBatch(Batch):
             input_ids = []
             # These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
             position_ids = None
-            start_slots = None
             slots = None
             slot_indices = None
             prefix_lengths_tensor = None
@@ -725,7 +686,6 @@ class FlashCausalLMBatch(Batch):
         else:
             input_ids = batches[0].input_ids.new_empty(total_batch_size)
             position_ids = batches[0].position_ids.new_empty(total_batch_size)
-            start_slots = []
             slots = batches[0].slots.new_empty(total_slots)
             slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
             postfix_lengths_tensor = batches[0].postfix_lengths_tensor.new_empty(
@@ -836,8 +796,6 @@ class FlashCausalLMBatch(Batch):
                     batch.prefix_lengths_tensor
                 )
 
-                start_slots.append(batch.start_slots + cumulative_slots)
-
                 # Update
                 cumulative_slots += len(batch.slots)
             else:
@@ -867,11 +825,6 @@ class FlashCausalLMBatch(Batch):
             # Update
             cumulative_batch_size += len(batch)
 
-        if start_slots is not None:
-            start_slots = torch.concat(start_slots)
-
-        # assert sum(len(b) for b in block_tables) == (block_tables_tensor != 0).sum()
-
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
             next_token_chooser_parameters,
             dtype=batches[0].next_token_chooser.dtype,
@@ -903,7 +856,6 @@ class FlashCausalLMBatch(Batch):
             position_ids=position_ids,
             cu_seqlen_prefill=None,
             prefill_cache_indices=None,
-            start_slots=start_slots,
             slot_indices=slot_indices,
             block_tables=block_tables,
             block_tables_tensor=block_tables_tensor,
@@ -946,7 +898,6 @@ class FlashCausalLMBatch(Batch):
         sliding_window = get_sliding_windows()
         position_ids = []
         cu_seqlen_prefill = [0]
-        start_slots = []
         slot_indices = []
         prefill_cache_indices = []
         all_prefill_logprobs = True
@@ -1041,7 +992,6 @@ class FlashCausalLMBatch(Batch):
                 prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
                 prefill_out_cumulative_length += 1
 
-            start_slots.append(cumulative_slot_tokens)
             slots.extend(request_slots)
             slot_indices.append(request_slot_indices)
 
@@ -1058,7 +1008,6 @@ class FlashCausalLMBatch(Batch):
             cumulative_slot_tokens += len(request_slots)
 
         device = self.block_tables_tensor.device
-        self.start_slots = torch.tensor(start_slots, dtype=torch.int64)
 
         if isinstance(self.input_ids, list):
             if len(self) > 1:
@@ -1762,6 +1711,8 @@ class FlashCausalLM(Model):
         if prefill:
             batch.prepare_for_prefill()
 
+        log_master(logger.info, f"Tokens in this forward: {len(batch.input_ids)}")
+
        prefill_logprobs = batch.prefill_next_token_indices is not None
 
         # Update adapter indices for speculative tokens (if present)
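Context note (not part of the patch): with start_slots removed, FlashCausalLMBatch.filter recovers each kept request's slice of the flat slots tensor from self.slot_indices[idx] and request_postfix_length + remaining_tokens - 1, and the new slot_indices are simply the cumulative offsets into the filtered tensor. Below is a minimal standalone sketch of that indexing logic; the function name filter_slots and the keep argument are illustrative and do not appear in the patch.

import torch


def filter_slots(slots, slot_indices, postfix_lengths, remaining_tokens, keep):
    """Illustrative, standalone version of the slot filtering in the hunks above.

    Assumes, as the patch does, that slot_indices[idx] marks the start of
    request idx's live region inside the flat slots tensor and that the region
    spans postfix_lengths[idx] + remaining_tokens[idx] - 1 slots.
    """
    slot_filtering_indices = torch.zeros(slots.shape[0], dtype=torch.bool)
    new_slot_indices = torch.empty(len(keep), dtype=torch.int64)

    cumulative_max_length = 0
    for i, idx in enumerate(keep):
        # Start of this request inside the *filtered* slots tensor
        new_slot_indices[i] = cumulative_max_length

        length = postfix_lengths[idx] + remaining_tokens[idx] - 1
        start = slot_indices[idx]
        # Keep this request's slice of the old slots tensor
        slot_filtering_indices[start : start + length] = True

        cumulative_max_length += length

    return slots[slot_filtering_indices], new_slot_indices

The real filter does the same bookkeeping inline, then applies slots = self.slots[slot_filtering_indices] and moves slot_indices to the GPU, as shown in the hunk at @@ -595,28 +564,22 @@.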