re-create slots

OlivierDehaene 2024-10-02 14:10:33 +02:00
parent 7f9abde3f8
commit 4db5e7dde6

@@ -462,12 +462,12 @@ class FlashCausalLMBatch(Batch):
         )
         # Create on CPU to only move to GPU once instead of at every copy
-        slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
+        # slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
         max_postfix_length = 0
         max_current_length = 0

         requests = []
-        start_slots = []
+        # start_slots = []
         block_tables = []
         all_input_ids = []
         prefix_ids = []
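
This hunk and the ones below trade the pre-allocated CPU tensor for plain Python lists that are filled during the filtering loop and turned into a device tensor in one call. A minimal sketch of the two allocation patterns, using a made-up batch and a device fallback (both illustrative, not the module's real values):

    import torch

    request_ids = [0, 1, 2]  # hypothetical batch
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Old pattern (removed above): pre-allocate on CPU, write one value per
    # request, then move everything to the device in a single transfer.
    slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
    for i, _ in enumerate(request_ids):
        slot_indices[i] = i * 4  # placeholder value
    slot_indices = slot_indices.to(device)

    # New pattern (added further down): accumulate a list, then build the
    # tensor directly on the device in a single call.
    slot_indices = [i * 4 for i in range(len(request_ids))]
    slot_indices = torch.tensor(slot_indices, dtype=torch.int64, device=device)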
@@ -491,12 +491,18 @@ class FlashCausalLMBatch(Batch):
         # Cumulative length
         cumulative_max_length = 0

+        start_slots = []
+        slots = []
+        slot_indices = []
+        cumulative_slot_tokens = 0
+
         for i, request_id in enumerate(request_ids):
             idx = self.requests_idx_mapping[request_id]
             indices.append(idx)
             requests_idx_mapping[request_id] = i

-            requests.append(self.requests[idx])
+            request = self.requests[idx]
+            requests.append(request)

             # Prefilling
             request_prefilling = self.prefilling_mask[idx]
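
The four new accumulators build one flat, batch-wide list of KV-cache slots plus the offset at which each request's slots start. A self-contained sketch of that bookkeeping, using invented per-request slot lists (in the real code they are derived from the request's block table, as a later hunk shows):

    start_slots, slots, slot_indices = [], [], []
    cumulative_slot_tokens = 0

    for request_slots in ([10, 11, 12], [20, 21], [30, 31, 32, 33]):
        start_slots.append(cumulative_slot_tokens)   # where this request begins
        slots.extend(request_slots)                  # flat list over the whole batch
        slot_indices.append(cumulative_slot_tokens)  # index of its first kept slot
        cumulative_slot_tokens += len(request_slots)

    assert slots == [10, 11, 12, 20, 21, 30, 31, 32, 33]
    assert start_slots == [0, 3, 5] and slot_indices == [0, 3, 5]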
@@ -508,6 +514,7 @@ class FlashCausalLMBatch(Batch):
                 input_ids.append(self.input_ids[idx])

             # Get length
+            request_prompt_length = self.prompt_lengths[idx]
             request_postfix_length = self.postfix_lengths[idx]
             request_prefix_length = self.prefix_lengths[idx]
             max_postfix_length = max(max_postfix_length, request_postfix_length)
@@ -518,7 +525,7 @@ class FlashCausalLMBatch(Batch):
             all_input_ids.append(self.all_input_ids[idx])
             prefix_ids.append(self.prefix_ids[idx])

-            prompt_lengths.append(self.prompt_lengths[idx])
+            prompt_lengths.append(request_prompt_length)
             postfix_lengths.append(request_postfix_length)
             prefix_lengths.append(request_prefix_length)
             prefix_offsets.append(self.prefix_offsets[idx])
@@ -534,27 +541,45 @@ class FlashCausalLMBatch(Batch):
             adapter_index = ADAPTER_TO_INDEX.get(self.requests[idx].adapter_id, 0)
             adapter_set.add(adapter_index)

-            remaining_tokens = (
-                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
-            )
+            # remaining_tokens = (
+            #     stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
+            # )

             request_block_table = self.block_tables[idx]
             num_blocks += len(request_block_table)
             block_tables.append(request_block_table)
-            start_slots.append(cumulative_max_length)
+            # start_slots.append(cumulative_max_length)

             # Copy to tensor (CPU)
-            slot_indices[i] = cumulative_max_length + request_postfix_length - 1
+            # slot_indices[i] = cumulative_max_length + request_postfix_length - 1

             # Set slice
-            slot_filtering_indices[
-                self.start_slots[idx] : self.start_slots[idx]
-                + request_postfix_length
-                + remaining_tokens
-                - 1
-            ] = True
+            #FIXME
+            # slot_filtering_indices[
+            #     self.start_slots[idx] : self.start_slots[idx]
+            #     + request_postfix_length
+            #     + remaining_tokens
+            #     - 1
+            # ] = True

-            cumulative_max_length += request_postfix_length + remaining_tokens - 1
+            if not self.prefilling:
+                if not request.slots:
+                    request_slots = [
+                        s
+                        for b in request_block_table
+                        for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
+                    ]
+                else:
+                    request_slots = request.slots
+
+                request_slots = request_slots[request_prefix_length:]
+
+                start_slots.append(cumulative_slot_tokens)
+                slots.extend(request_slots)
+                slot_indices.append(cumulative_slot_tokens)
+
+                cumulative_slot_tokens += len(request_slots)
+
+            # cumulative_max_length += request_postfix_length + remaining_tokens - 1

             max_blocks = max(max_blocks, len(request_block_table))
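
This is the slot re-creation the commit title refers to: when a request carries no explicit slots, they are expanded from its block table, and the slots already covered by the reused prefix are dropped. A worked sketch, assuming BLOCK_SIZE = 16 (illustrative here; the real constant comes from this module's imports):

    BLOCK_SIZE = 16  # assumed value for the example

    def rebuild_request_slots(request_block_table, request_prefix_length):
        # Expand each KV-cache block into its BLOCK_SIZE consecutive slots,
        # mirroring the list comprehension in the hunk above.
        request_slots = [
            s
            for b in request_block_table
            for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
        ]
        # Drop the slots already covered by the cached prefix.
        return request_slots[request_prefix_length:]

    # Blocks 3 and 7 own slots 48..63 and 112..127; a 16-token cached prefix
    # skips exactly the first block's worth of slots.
    assert rebuild_request_slots([3, 7], 16) == list(range(112, 128))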
@@ -577,18 +602,21 @@ class FlashCausalLMBatch(Batch):
             postfix_lengths_tensor = None
             adapter_meta = None
         else:
+            slot_indices = torch.tensor(slot_indices, dtype=torch.int64, device=device)
+            slots = torch.tensor(slots, dtype=torch.int64, device=device)
+
             # Index into tensors
             input_ids = self.input_ids[indices]
             position_ids = self.position_ids[indices]
             adapter_indices = self.adapter_meta.adapter_indices[indices]
             postfix_lengths_tensor = self.postfix_lengths_tensor[indices]
-            slots = self.slots[slot_filtering_indices]
+            # slots = self.slots[slot_filtering_indices]
             prefix_lengths_tensor = self.prefix_lengths_tensor[indices]

             start_slots = torch.tensor(start_slots, dtype=torch.int64)

             # Move to GPU now that we have the whole tensor
-            slot_indices = slot_indices.to(device)
+            # slot_indices = slot_indices.to(device)

             adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
             adapter_segments = torch.tensor(
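
With slots re-created per request, filtering no longer indexes the old flat slots tensor through the slot_filtering_indices boolean mask; the fresh lists are materialized on the device in one call instead. A toy comparison of the two paths (values and device fallback are illustrative):

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Old path (commented out above): boolean-mask the previous flat tensor.
    old_slots = torch.arange(10, dtype=torch.int64)
    keep = torch.zeros(10, dtype=torch.bool)
    keep[2:5] = True
    filtered = old_slots[keep]  # tensor([2, 3, 4])

    # New path: the per-request lists built in the loop become a device
    # tensor directly, with no mask over stale state to maintain.
    slots = torch.tensor([2, 3, 4], dtype=torch.int64, device=device)
    assert filtered.tolist() == slots.tolist()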