Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-12 04:44:52 +00:00
re-create slots
commit 4db5e7dde6
parent 7f9abde3f8
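
The diff below stops tracking filtered slots through slot_filtering_indices and a pre-sized slot_indices tensor; instead, when a batch is filtered, each surviving request's slots are re-created from request.slots (or expanded from its block table when no explicit slots are present) and the slots already covered by the cached prefix are skipped. A minimal sketch of that per-request logic follows; the helper name and the BLOCK_SIZE value of 16 are assumptions for illustration, not part of the commit.

    from typing import List, Optional

    BLOCK_SIZE = 16  # assumed paged-attention block size


    def recreate_request_slots(
        block_table: List[int],
        explicit_slots: Optional[List[int]],
        prefix_length: int,
    ) -> List[int]:
        # No explicit slots: expand the block table, block b owning slots
        # [b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE), as in the diff below.
        if not explicit_slots:
            slots = [
                s
                for b in block_table
                for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
            ]
        else:
            slots = list(explicit_slots)
        # Drop the slots already covered by the cached prefix.
        return slots[prefix_length:]


    # Example: blocks [3, 7], no explicit slots, 4 prefix tokens already cached.
    # Block 3 covers slots 48..63, block 7 covers 112..127; the first 4 are skipped.
    slots = recreate_request_slots([3, 7], None, 4)
    assert slots[0] == 52 and len(slots) == 28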
@@ -462,12 +462,12 @@ class FlashCausalLMBatch(Batch):
         )
 
         # Create on CPU to only move to GPU once instead of at every copy
-        slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
+        # slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
         max_postfix_length = 0
         max_current_length = 0
 
         requests = []
-        start_slots = []
+        # start_slots = []
         block_tables = []
         all_input_ids = []
         prefix_ids = []
@@ -491,12 +491,18 @@ class FlashCausalLMBatch(Batch):
         # Cumulative length
         cumulative_max_length = 0
 
+        start_slots = []
+        slots = []
+        slot_indices = []
+        cumulative_slot_tokens = 0
+
         for i, request_id in enumerate(request_ids):
             idx = self.requests_idx_mapping[request_id]
             indices.append(idx)
             requests_idx_mapping[request_id] = i
 
-            requests.append(self.requests[idx])
+            request = self.requests[idx]
+            requests.append(request)
 
             # Prefilling
             request_prefilling = self.prefilling_mask[idx]
@@ -508,6 +514,7 @@ class FlashCausalLMBatch(Batch):
             input_ids.append(self.input_ids[idx])
 
             # Get length
+            request_prompt_length = self.prompt_lengths[idx]
             request_postfix_length = self.postfix_lengths[idx]
             request_prefix_length = self.prefix_lengths[idx]
             max_postfix_length = max(max_postfix_length, request_postfix_length)
@@ -518,7 +525,7 @@ class FlashCausalLMBatch(Batch):
             all_input_ids.append(self.all_input_ids[idx])
             prefix_ids.append(self.prefix_ids[idx])
 
-            prompt_lengths.append(self.prompt_lengths[idx])
+            prompt_lengths.append(request_prompt_length)
             postfix_lengths.append(request_postfix_length)
             prefix_lengths.append(request_prefix_length)
             prefix_offsets.append(self.prefix_offsets[idx])
@@ -534,27 +541,45 @@ class FlashCausalLMBatch(Batch):
             adapter_index = ADAPTER_TO_INDEX.get(self.requests[idx].adapter_id, 0)
             adapter_set.add(adapter_index)
 
-            remaining_tokens = (
-                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
-            )
+            # remaining_tokens = (
+            #     stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
+            # )
 
             request_block_table = self.block_tables[idx]
             num_blocks += len(request_block_table)
             block_tables.append(request_block_table)
-            start_slots.append(cumulative_max_length)
+            # start_slots.append(cumulative_max_length)
 
             # Copy to tensor (CPU)
-            slot_indices[i] = cumulative_max_length + request_postfix_length - 1
+            # slot_indices[i] = cumulative_max_length + request_postfix_length - 1
 
             # Set slice
-            slot_filtering_indices[
-                self.start_slots[idx] : self.start_slots[idx]
-                + request_postfix_length
-                + remaining_tokens
-                - 1
-            ] = True
+            # FIXME
+            # slot_filtering_indices[
+            #     self.start_slots[idx] : self.start_slots[idx]
+            #     + request_postfix_length
+            #     + remaining_tokens
+            #     - 1
+            # ] = True
 
-            cumulative_max_length += request_postfix_length + remaining_tokens - 1
+            if not self.prefilling:
+                if not request.slots:
+                    request_slots = [
+                        s
+                        for b in request_block_table
+                        for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
+                    ]
+                else:
+                    request_slots = request.slots
+
+                request_slots = request_slots[request_prefix_length:]
+                start_slots.append(cumulative_slot_tokens)
+                slots.extend(request_slots)
+                slot_indices.append(cumulative_slot_tokens)
+
+                cumulative_slot_tokens += len(request_slots)
+
+            # cumulative_max_length += request_postfix_length + remaining_tokens - 1
 
             max_blocks = max(max_blocks, len(request_block_table))
 
@@ -577,18 +602,21 @@ class FlashCausalLMBatch(Batch):
             postfix_lengths_tensor = None
             adapter_meta = None
         else:
+            slot_indices = torch.tensor(slot_indices, dtype=torch.int64, device=device)
+            slots = torch.tensor(slots, dtype=torch.int64, device=device)
+
             # Index into tensors
             input_ids = self.input_ids[indices]
             position_ids = self.position_ids[indices]
             adapter_indices = self.adapter_meta.adapter_indices[indices]
             postfix_lengths_tensor = self.postfix_lengths_tensor[indices]
-            slots = self.slots[slot_filtering_indices]
+            # slots = self.slots[slot_filtering_indices]
             prefix_lengths_tensor = self.prefix_lengths_tensor[indices]
 
             start_slots = torch.tensor(start_slots, dtype=torch.int64)
 
             # Move to GPU now that we have the whole tensor
-            slot_indices = slot_indices.to(device)
+            # slot_indices = slot_indices.to(device)
 
             adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
             adapter_segments = torch.tensor(