mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-12 04:44:52 +00:00)

commit 4db5e7dde6 (parent 7f9abde3f8): re-create slots
@@ -462,12 +462,12 @@ class FlashCausalLMBatch(Batch):
         )
 
         # Create on CPU to only move to GPU once instead of at every copy
-        slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
+        # slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
         max_postfix_length = 0
         max_current_length = 0
 
         requests = []
-        start_slots = []
+        # start_slots = []
         block_tables = []
         all_input_ids = []
         prefix_ids = []
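
Annotation (editor's note, not part of the commit): this hunk disables the up-front CPU preallocation of slot_indices. The old layout could size the tensor as one entry per request because each request owned a contiguous run of postfix_length + remaining_tokens - 1 slots; the rewrite below derives slots from each request's block table instead, so per-request counts are only known inside the loop. A rough sizing comparison, with all values invented:

    # Illustration only; numbers invented, BLOCK_SIZE value assumed.
    BLOCK_SIZE = 16
    postfix_lengths = [3, 5]
    remaining_tokens = [10, 2]
    block_tables = [[0, 1], [2]]
    prefix_lengths = [0, 4]

    # Old sizing: one contiguous run per request, from the generation budget.
    old_sizes = [p + r - 1 for p, r in zip(postfix_lengths, remaining_tokens)]
    # New sizing: every slot covered by the block table, minus the cached prefix.
    new_sizes = [
        len(bt) * BLOCK_SIZE - pl for bt, pl in zip(block_tables, prefix_lengths)
    ]
    print(old_sizes, new_sizes)  # [12, 6] [32, 12]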
@@ -491,12 +491,18 @@ class FlashCausalLMBatch(Batch):
         # Cumulative length
         cumulative_max_length = 0
 
+        start_slots = []
+        slots = []
+        slot_indices = []
+        cumulative_slot_tokens = 0
+
         for i, request_id in enumerate(request_ids):
             idx = self.requests_idx_mapping[request_id]
             indices.append(idx)
             requests_idx_mapping[request_id] = i
 
-            requests.append(self.requests[idx])
+            request = self.requests[idx]
+            requests.append(request)
 
             # Prefilling
             request_prefilling = self.prefilling_mask[idx]
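
Annotation: the four new lists replace in-loop tensor writes with plain Python accumulation; the tensors are built once after the loop (see the else branch further down). A minimal self-contained sketch of the pattern, assuming the per-request slot lists are already known:

    import torch

    def accumulate_slots(per_request_slots, device="cpu"):
        start_slots, slots, slot_indices = [], [], []
        cumulative_slot_tokens = 0
        for request_slots in per_request_slots:
            # Each request's slots start where the previous request's ended.
            start_slots.append(cumulative_slot_tokens)
            slot_indices.append(cumulative_slot_tokens)
            slots.extend(request_slots)
            cumulative_slot_tokens += len(request_slots)
        return (
            torch.tensor(start_slots, dtype=torch.int64),
            torch.tensor(slots, dtype=torch.int64, device=device),
            torch.tensor(slot_indices, dtype=torch.int64, device=device),
        )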
@@ -508,6 +514,7 @@ class FlashCausalLMBatch(Batch):
                 input_ids.append(self.input_ids[idx])
 
             # Get length
+            request_prompt_length = self.prompt_lengths[idx]
             request_postfix_length = self.postfix_lengths[idx]
             request_prefix_length = self.prefix_lengths[idx]
             max_postfix_length = max(max_postfix_length, request_postfix_length)
@@ -518,7 +525,7 @@ class FlashCausalLMBatch(Batch):
             all_input_ids.append(self.all_input_ids[idx])
             prefix_ids.append(self.prefix_ids[idx])
 
-            prompt_lengths.append(self.prompt_lengths[idx])
+            prompt_lengths.append(request_prompt_length)
             postfix_lengths.append(request_postfix_length)
             prefix_lengths.append(request_prefix_length)
             prefix_offsets.append(self.prefix_offsets[idx])
@@ -534,27 +541,45 @@ class FlashCausalLMBatch(Batch):
             adapter_index = ADAPTER_TO_INDEX.get(self.requests[idx].adapter_id, 0)
             adapter_set.add(adapter_index)
 
-            remaining_tokens = (
-                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
-            )
+            # remaining_tokens = (
+            #     stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
+            # )
 
             request_block_table = self.block_tables[idx]
             num_blocks += len(request_block_table)
             block_tables.append(request_block_table)
-            start_slots.append(cumulative_max_length)
+            # start_slots.append(cumulative_max_length)
 
             # Copy to tensor (CPU)
-            slot_indices[i] = cumulative_max_length + request_postfix_length - 1
+            # slot_indices[i] = cumulative_max_length + request_postfix_length - 1
 
             # Set slice
-            slot_filtering_indices[
-                self.start_slots[idx] : self.start_slots[idx]
-                + request_postfix_length
-                + remaining_tokens
-                - 1
-            ] = True
+            #FIXME
+            # slot_filtering_indices[
+            #     self.start_slots[idx] : self.start_slots[idx]
+            #     + request_postfix_length
+            #     + remaining_tokens
+            #     - 1
+            # ] = True
 
-            cumulative_max_length += request_postfix_length + remaining_tokens - 1
+            if not self.prefilling:
+                if not request.slots:
+                    request_slots = [
+                        s
+                        for b in request_block_table
+                        for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
+                    ]
+                else:
+                    request_slots = request.slots
+
+                request_slots = request_slots[request_prefix_length:]
+                start_slots.append(cumulative_slot_tokens)
+                slots.extend(request_slots)
+                slot_indices.append(cumulative_slot_tokens)
+
+                cumulative_slot_tokens += len(request_slots)
+
+            # cumulative_max_length += request_postfix_length + remaining_tokens - 1
 
             max_blocks = max(max_blocks, len(request_block_table))
 
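Annotation: this is the re-creation the commit title refers to. When a filtered request carries no explicit slot list, every slot covered by its block table is rebuilt from the block ids, and the slots already consumed by the cached prefix are sliced off. Standalone sketch (BLOCK_SIZE = 16 is an assumption; the real value comes from the attention backend):

    BLOCK_SIZE = 16  # assumed paged-KV block size

    def blocks_to_slots(block_table, prefix_length):
        slots = [
            s
            for b in block_table
            for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
        ]
        # Drop the slots already covered by the cached prefix.
        return slots[prefix_length:]

    # Block 3 covers slots 48..63; a 4-token cached prefix leaves 52..63.
    assert blocks_to_slots([3], 4) == list(range(52, 64))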
@@ -577,18 +602,21 @@ class FlashCausalLMBatch(Batch):
             postfix_lengths_tensor = None
             adapter_meta = None
         else:
+            slot_indices = torch.tensor(slot_indices, dtype=torch.int64, device=device)
+            slots = torch.tensor(slots, dtype=torch.int64, device=device)
+
             # Index into tensors
             input_ids = self.input_ids[indices]
             position_ids = self.position_ids[indices]
             adapter_indices = self.adapter_meta.adapter_indices[indices]
             postfix_lengths_tensor = self.postfix_lengths_tensor[indices]
-            slots = self.slots[slot_filtering_indices]
+            # slots = self.slots[slot_filtering_indices]
             prefix_lengths_tensor = self.prefix_lengths_tensor[indices]
 
         start_slots = torch.tensor(start_slots, dtype=torch.int64)
 
         # Move to GPU now that we have the whole tensor
-        slot_indices = slot_indices.to(device)
+        # slot_indices = slot_indices.to(device)
 
         adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
         adapter_segments = torch.tensor(
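
Annotation: slots and slot_indices are now tensorized directly on the target device inside the non-prefilling branch, replacing both the boolean-mask filter over self.slots and the later blanket slot_indices.to(device) copy. The two moves produce the same data; the new one just builds the tensor on the device in a single step:

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    slot_indices_list = [0, 12]  # invented values

    # Old (removed): build on CPU, then copy to the device afterwards.
    old = torch.tensor(slot_indices_list, dtype=torch.int64).to(device)
    # New (this commit): build on the device in one step.
    new = torch.tensor(slot_indices_list, dtype=torch.int64, device=device)
    assert torch.equal(old, new)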