From 4db5e7dde6a91f0e8182011368afaa05798885ed Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Wed, 2 Oct 2024 14:10:33 +0200
Subject: [PATCH] re-create slots

---
 .../models/flash_causal_lm.py | 64 +++++++++++++------
 1 file changed, 46 insertions(+), 18 deletions(-)

diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index b39fe0ff..cf2b6ea7 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -462,12 +462,12 @@ class FlashCausalLMBatch(Batch):
         )
 
         # Create on CPU to only move to GPU once instead of at every copy
-        slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
+        # slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
         max_postfix_length = 0
         max_current_length = 0
 
         requests = []
-        start_slots = []
+        # start_slots = []
         block_tables = []
         all_input_ids = []
         prefix_ids = []
@@ -491,12 +491,18 @@ class FlashCausalLMBatch(Batch):
         # Cumulative length
         cumulative_max_length = 0
 
+        start_slots = []
+        slots = []
+        slot_indices = []
+        cumulative_slot_tokens = 0
+
         for i, request_id in enumerate(request_ids):
             idx = self.requests_idx_mapping[request_id]
             indices.append(idx)
             requests_idx_mapping[request_id] = i
 
-            requests.append(self.requests[idx])
+            request = self.requests[idx]
+            requests.append(request)
 
             # Prefilling
             request_prefilling = self.prefilling_mask[idx]
@@ -508,6 +514,7 @@ class FlashCausalLMBatch(Batch):
                 input_ids.append(self.input_ids[idx])
 
             # Get length
+            request_prompt_length = self.prompt_lengths[idx]
             request_postfix_length = self.postfix_lengths[idx]
             request_prefix_length = self.prefix_lengths[idx]
             max_postfix_length = max(max_postfix_length, request_postfix_length)
@@ -518,7 +525,7 @@ class FlashCausalLMBatch(Batch):
             all_input_ids.append(self.all_input_ids[idx])
             prefix_ids.append(self.prefix_ids[idx])
 
-            prompt_lengths.append(self.prompt_lengths[idx])
+            prompt_lengths.append(request_prompt_length)
             postfix_lengths.append(request_postfix_length)
             prefix_lengths.append(request_prefix_length)
             prefix_offsets.append(self.prefix_offsets[idx])
@@ -534,27 +541,45 @@ class FlashCausalLMBatch(Batch):
             adapter_index = ADAPTER_TO_INDEX.get(self.requests[idx].adapter_id, 0)
             adapter_set.add(adapter_index)
 
-            remaining_tokens = (
-                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
-            )
+            # remaining_tokens = (
+            #     stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
+            # )
 
             request_block_table = self.block_tables[idx]
             num_blocks += len(request_block_table)
             block_tables.append(request_block_table)
-            start_slots.append(cumulative_max_length)
+            # start_slots.append(cumulative_max_length)
 
             # Copy to tensor (CPU)
-            slot_indices[i] = cumulative_max_length + request_postfix_length - 1
+            # slot_indices[i] = cumulative_max_length + request_postfix_length - 1
 
             # Set slice
-            slot_filtering_indices[
-                self.start_slots[idx] : self.start_slots[idx]
-                + request_postfix_length
-                + remaining_tokens
-                - 1
-            ] = True
+            # FIXME
+            # slot_filtering_indices[
+            #     self.start_slots[idx] : self.start_slots[idx]
+            #     + request_postfix_length
+            #     + remaining_tokens
+            #     - 1
+            # ] = True
 
-            cumulative_max_length += request_postfix_length + remaining_tokens - 1
+            if not self.prefilling:
+                if not request.slots:
+                    request_slots = [
+                        s
+                        for b in request_block_table
+                        for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
+                    ]
+                else:
+                    request_slots = request.slots
+
+                request_slots = request_slots[request_prefix_length:]
+                start_slots.append(cumulative_slot_tokens)
+                slots.extend(request_slots)
+                slot_indices.append(cumulative_slot_tokens)
+
+                cumulative_slot_tokens += len(request_slots)
+
+            # cumulative_max_length += request_postfix_length + remaining_tokens - 1
 
             max_blocks = max(max_blocks, len(request_block_table))
 
@@ -577,18 +602,21 @@ class FlashCausalLMBatch(Batch):
             postfix_lengths_tensor = None
             adapter_meta = None
         else:
+            slot_indices = torch.tensor(slot_indices, dtype=torch.int64, device=device)
+            slots = torch.tensor(slots, dtype=torch.int64, device=device)
+
             # Index into tensors
             input_ids = self.input_ids[indices]
             position_ids = self.position_ids[indices]
             adapter_indices = self.adapter_meta.adapter_indices[indices]
             postfix_lengths_tensor = self.postfix_lengths_tensor[indices]
-            slots = self.slots[slot_filtering_indices]
+            # slots = self.slots[slot_filtering_indices]
             prefix_lengths_tensor = self.prefix_lengths_tensor[indices]
 
             start_slots = torch.tensor(start_slots, dtype=torch.int64)
 
             # Move to GPU now that we have the whole tensor
-            slot_indices = slot_indices.to(device)
+            # slot_indices = slot_indices.to(device)
 
             adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
             adapter_segments = torch.tensor(
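For reference, the heart of this patch: when filtering a batch, the code no longer
slices a persistent `slots` tensor via `slot_filtering_indices`; instead it
re-creates each surviving request's slots from its block table (or reuses
`request.slots` when the request carries them) and drops the slots already
covered by the cached prefix. Below is a minimal, self-contained sketch of that
expansion, assuming a paged KV cache with a fixed BLOCK_SIZE of 16; the helper
name `recreate_slots` and its signature are illustrative, not the upstream API.

BLOCK_SIZE = 16  # assumed block size; upstream imports this as a constant

def recreate_slots(block_table: list[int], prefix_length: int) -> list[int]:
    # Expand each KV-cache block id into its BLOCK_SIZE contiguous slot ids,
    # then skip the slots already filled by the cached prefix.
    slots = [
        s
        for b in block_table
        for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
    ]
    return slots[prefix_length:]

# Example: blocks [2, 5] with a 4-token cached prefix.
# Block 2 covers slots 32..47 and block 5 covers slots 80..95;
# the first 4 prefix slots (32..35) are dropped.
print(recreate_slots([2, 5], 4))  # [36, 37, ..., 47, 80, 81, ..., 95]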