re-create slots

OlivierDehaene 2024-10-02 14:10:33 +02:00
parent 7f9abde3f8
commit 4db5e7dde6


@@ -462,12 +462,12 @@ class FlashCausalLMBatch(Batch):
        )
        # Create on CPU to only move to GPU once instead of at every copy
        slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
        # slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
        max_postfix_length = 0
        max_current_length = 0
        requests = []
        start_slots = []
        # start_slots = []
        block_tables = []
        all_input_ids = []
        prefix_ids = []
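For context, the preallocate-then-fill pattern being retired in this hunk looks roughly like the sketch below. This is not part of the commit; the batch size and per-request values are made up for illustration. A CPU tensor is created up front, filled by index per request, and moved to the GPU once at the end.

    import torch

    n_requests = 4  # hypothetical batch size
    slot_indices = torch.empty(n_requests, dtype=torch.int64)
    for i in range(n_requests):
        slot_indices[i] = i * 10  # placeholder per-request value
    # Single host-to-device copy instead of one copy per element
    slot_indices = slot_indices.to("cuda" if torch.cuda.is_available() else "cpu")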
@@ -491,12 +491,18 @@ class FlashCausalLMBatch(Batch):
        # Cumulative length
        cumulative_max_length = 0
        start_slots = []
        slots = []
        slot_indices = []
        cumulative_slot_tokens = 0
        for i, request_id in enumerate(request_ids):
            idx = self.requests_idx_mapping[request_id]
            indices.append(idx)
            requests_idx_mapping[request_id] = i
            requests.append(self.requests[idx])
            request = self.requests[idx]
            requests.append(request)
            # Prefilling
            request_prefilling = self.prefilling_mask[idx]
@@ -508,6 +514,7 @@ class FlashCausalLMBatch(Batch):
            input_ids.append(self.input_ids[idx])
            # Get length
            request_prompt_length = self.prompt_lengths[idx]
            request_postfix_length = self.postfix_lengths[idx]
            request_prefix_length = self.prefix_lengths[idx]
            max_postfix_length = max(max_postfix_length, request_postfix_length)
@@ -518,7 +525,7 @@ class FlashCausalLMBatch(Batch):
            all_input_ids.append(self.all_input_ids[idx])
            prefix_ids.append(self.prefix_ids[idx])
            prompt_lengths.append(self.prompt_lengths[idx])
            prompt_lengths.append(request_prompt_length)
            postfix_lengths.append(request_postfix_length)
            prefix_lengths.append(request_prefix_length)
            prefix_offsets.append(self.prefix_offsets[idx])
@@ -534,27 +541,45 @@ class FlashCausalLMBatch(Batch):
            adapter_index = ADAPTER_TO_INDEX.get(self.requests[idx].adapter_id, 0)
            adapter_set.add(adapter_index)
            remaining_tokens = (
                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
            )
            # remaining_tokens = (
            #     stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
            # )
            request_block_table = self.block_tables[idx]
            num_blocks += len(request_block_table)
            block_tables.append(request_block_table)
            start_slots.append(cumulative_max_length)
            # start_slots.append(cumulative_max_length)
            # Copy to tensor (CPU)
            slot_indices[i] = cumulative_max_length + request_postfix_length - 1
            # slot_indices[i] = cumulative_max_length + request_postfix_length - 1
            # Set slice
            slot_filtering_indices[
                self.start_slots[idx] : self.start_slots[idx]
                + request_postfix_length
                + remaining_tokens
                - 1
            ] = True
            # FIXME
            # slot_filtering_indices[
            #     self.start_slots[idx] : self.start_slots[idx]
            #     + request_postfix_length
            #     + remaining_tokens
            #     - 1
            # ] = True
            cumulative_max_length += request_postfix_length + remaining_tokens - 1
            if not self.prefilling:
                if not request.slots:
                    request_slots = [
                        s
                        for b in request_block_table
                        for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
                    ]
                else:
                    request_slots = request.slots
                request_slots = request_slots[request_prefix_length:]
                start_slots.append(cumulative_slot_tokens)
                slots.extend(request_slots)
                slot_indices.append(cumulative_slot_tokens)
                cumulative_slot_tokens += len(request_slots)
            # cumulative_max_length += request_postfix_length + remaining_tokens - 1
            max_blocks = max(max_blocks, len(request_block_table))
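The core of the change is visible in the hunk above: instead of filtering the old `slots` tensor with `slot_filtering_indices`, each request's slots are re-created from its block table (or taken from `request.slots` when the scheduler provided them). A minimal standalone sketch of that expansion, assuming a block size of 16; the helper name and example values are illustrative, not from the commit:

    BLOCK_SIZE = 16  # assumed; defined elsewhere in the real file

    def slots_from_block_table(block_table, prefix_length):
        # Block b covers the contiguous slot range [b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
        slots = [
            s
            for b in block_table
            for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
        ]
        # Drop slots already covered by the cached prefix, as in the diff
        return slots[prefix_length:]

    # Blocks 3 and 7 with a 4-token prefix already in cache:
    print(slots_from_block_table([3, 7], 4))  # [52, 53, ..., 63, 112, ..., 127]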
@@ -577,18 +602,21 @@ class FlashCausalLMBatch(Batch):
            postfix_lengths_tensor = None
            adapter_meta = None
        else:
            slot_indices = torch.tensor(slot_indices, dtype=torch.int64, device=device)
            slots = torch.tensor(slots, dtype=torch.int64, device=device)
            # Index into tensors
            input_ids = self.input_ids[indices]
            position_ids = self.position_ids[indices]
            adapter_indices = self.adapter_meta.adapter_indices[indices]
            postfix_lengths_tensor = self.postfix_lengths_tensor[indices]
            slots = self.slots[slot_filtering_indices]
            # slots = self.slots[slot_filtering_indices]
            prefix_lengths_tensor = self.prefix_lengths_tensor[indices]
            start_slots = torch.tensor(start_slots, dtype=torch.int64)
            # Move to GPU now that we have the whole tensor
            slot_indices = slot_indices.to(device)
            # slot_indices = slot_indices.to(device)
            adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
            adapter_segments = torch.tensor(
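With `slots` and `slot_indices` now accumulated as Python lists during the loop, the non-prefilling branch materializes each of them as a device tensor in a single `torch.tensor(..., device=device)` call, replacing the old `self.slots[slot_filtering_indices]` indexing and the deferred `slot_indices.to(device)` copy. A rough sketch of the new accumulate-then-convert flow; the per-request slot values are made up for illustration:

    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    slots, slot_indices, cumulative_slot_tokens = [], [], 0
    for request_slots in ([52, 53, 54], [112, 113]):  # hypothetical per-request slots
        slot_indices.append(cumulative_slot_tokens)  # start of this request's slot range
        slots.extend(request_slots)
        cumulative_slot_tokens += len(request_slots)

    # One list-to-tensor conversion per field, created directly on the device
    slots = torch.tensor(slots, dtype=torch.int64, device=device)
    slot_indices = torch.tensor(slot_indices, dtype=torch.int64, device=device)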