OlivierDehaene 2024-10-01 09:51:34 +02:00
parent 173bc99ab3
commit 34f5dc525e


@@ -16,7 +16,7 @@ from transformers import (
     AutoTokenizer,
     GenerationConfig,
 )
-from typing import Any, ContextManager, Iterable, Optional, Tuple, List, Type, Dict
+from typing import Any, ContextManager, Iterable, Optional, Tuple, List, Type, Dict, Union

 from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata
 from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
@@ -119,7 +119,9 @@ class FlashCausalLMBatch(Batch):
     requests_idx_mapping: Dict[int, int]

     # Decoder values
-    input_ids: torch.Tensor
+    # Can be a list for easy filtering
+    # If `input_ids` is a list, it needs to be materialized to a tensor first
+    input_ids: Union[torch.Tensor, List[List[int]]]
     # Will be set by `generate_token` and reset after each prefill forward before staying set in decode
     position_ids: Optional[torch.Tensor]
     speculative_ids: Optional[torch.Tensor]
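While a batch is still prefilling, `input_ids` stays a plain Python list with one `List[int]` per request; it only becomes a flat tensor once `prepare_for_prefill` materializes it (see the hunk at line 973 below). A minimal sketch of that contract, using illustrative names rather than the real class:

    import numpy as np
    import torch

    def materialize_input_ids(input_ids, device="cpu"):
        # List state: one List[int] of token ids per request, flattened to a
        # single 1D int64 tensor, mirroring `prepare_for_prefill` below.
        if isinstance(input_ids, list):
            if len(input_ids) > 1:
                input_ids = np.concatenate(input_ids, dtype=np.int64)
            else:
                input_ids = input_ids[0]
            return torch.tensor(input_ids, dtype=torch.int64, device=device)
        # Tensor state: already materialized (decode path), nothing to do.
        return input_ids

    # Two requests mid-prefill: list state, cheap to filter and concatenate.
    assert materialize_input_ids([[1, 2, 3], [4, 5]]).tolist() == [1, 2, 3, 4, 5]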
@@ -178,7 +180,7 @@ class FlashCausalLMBatch(Batch):
     # Will be set by `generate_token` and reset after each prefill forward before staying set in decode
     postfix_lengths_tensor: Optional[torch.Tensor]
     prefix_lengths_tensor: Optional[torch.Tensor]
-    prompt_lengths_tensor: Optional[torch.Tensor]
+    prompt_lengths_tensor: torch.Tensor

     prefix_offsets: List[Optional[int]]
     read_offsets: List[Optional[int]]
@@ -350,12 +352,6 @@ class FlashCausalLMBatch(Batch):
             all_input_ids_tensor, dtype=torch.int64, device=device
         )

-        if len(pb.requests) > 1:
-            input_ids = np.concatenate(all_postfix_ids, dtype=np.int64)
-        else:
-            input_ids = all_postfix_ids[0]
-        input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
-
         top_n_tokens_tensor = torch.tensor(
             top_n_tokens, device=device, dtype=torch.int64
         )
@@ -366,12 +362,15 @@ class FlashCausalLMBatch(Batch):
         for i, request_blocks in enumerate(block_tables):
             block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)
         block_tables_tensor = block_tables_tensor.to(device)

+        prompt_lengths_tensor = torch.tensor(
+            prompt_lengths, dtype=torch.int32, device=device
+        )
+
         return cls(
             batch_id=pb.id,
             requests=pb.requests,
             requests_idx_mapping=requests_idx_mapping,
-            input_ids=input_ids,
+            input_ids=all_postfix_ids,
             block_tables=block_tables,
             block_tables_tensor=block_tables_tensor,
@@ -395,6 +394,7 @@ class FlashCausalLMBatch(Batch):
             num_blocks=num_blocks,
             max_blocks=max_blocks,
             speculative_ids=None,
+            prompt_lengths_tensor=prompt_lengths_tensor,
             # These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
             position_ids=None,
@@ -408,7 +408,6 @@ class FlashCausalLMBatch(Batch):
             prefill_cu_outlens=None,
             prefix_lengths_tensor=None,
             postfix_lengths_tensor=None,
-            prompt_lengths_tensor=None,
             adapter_meta=None,
         )
@@ -455,6 +454,7 @@ class FlashCausalLMBatch(Batch):
         block_tables = []
         all_input_ids = []
         prefix_ids = []
+        input_ids = []

         prompt_lengths = []
         postfix_lengths = []
@@ -473,7 +473,6 @@ class FlashCausalLMBatch(Batch):
         max_blocks = 0
         # Cumulative length
         cumulative_max_length = 0
-        prefilling = False

         for i, request_id in enumerate(request_ids):
             idx = self.requests_idx_mapping[request_id]
@@ -484,9 +483,13 @@ class FlashCausalLMBatch(Batch):

             # Prefilling
             request_prefilling = self.prefilling_mask[idx]
-            prefilling = prefilling or request_prefilling
             prefilling_mask.append(request_prefilling)

+            # Input ids if the request was part of a prefilling batch
+            # If the batch was decoding we can index into the tensor directly later
+            if self.prefilling:
+                input_ids.append(self.input_ids[idx])
+
             # Get length
             request_postfix_length = self.postfix_lengths[idx]
             request_prefix_length = self.prefix_lengths[idx]
@@ -538,32 +541,48 @@ class FlashCausalLMBatch(Batch):
             max_blocks = max(max_blocks, len(request_block_table))

-        # Index into tensors
-        input_ids = self.input_ids[indices]
-        position_ids = self.position_ids[indices]
-        adapter_indices = self.adapter_meta.adapter_indices[indices]
         all_input_ids_tensor = self.all_input_ids_tensor[indices]
         block_tables_tensor = self.block_tables_tensor[indices]
-        postfix_lengths_tensor = self.postfix_lengths_tensor[indices]
-        prompt_lengths_tensor = self.prompt_lengths_tensor[indices]
-        slots = self.slots[slot_filtering_indices]
-        prefix_lengths_tensor = self.prefix_lengths_tensor[indices]
         next_token_chooser = self.next_token_chooser.filter(indices)
         top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
         speculative_ids = (
             self.speculative_ids[indices] if self.speculative_ids is not None else None
         )
+        prompt_lengths_tensor = self.prompt_lengths_tensor[indices]

-        start_slots = torch.tensor(start_slots, dtype=torch.int64)
-
-        # Move to GPU now that we have the whole tensor
-        slot_indices = slot_indices.to(device)
-
-        adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
-        adapter_segments = torch.tensor(
-            adapter_segments, dtype=torch.int32, device=device
-        )
-
-        # assert sum(len(b) for b in block_tables) == (block_tables_tensor != 0).sum()
+        if self.prefilling:
+            # These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
+            position_ids = None
+            start_slots = None
+            slot_indices = None
+            slots = None
+            prefix_lengths_tensor = None
+            postfix_lengths_tensor = None
+            adapter_meta = None
+        else:
+            # Index into tensors
+            input_ids = self.input_ids[indices]
+            position_ids = self.position_ids[indices]
+            adapter_indices = self.adapter_meta.adapter_indices[indices]
+            postfix_lengths_tensor = self.postfix_lengths_tensor[indices]
+            slots = self.slots[slot_filtering_indices]
+            prefix_lengths_tensor = self.prefix_lengths_tensor[indices]
+
+            start_slots = torch.tensor(start_slots, dtype=torch.int64)
+
+            # Move to GPU now that we have the whole tensor
+            slot_indices = slot_indices.to(device)
+
+            adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
+            adapter_segments = torch.tensor(
+                adapter_segments, dtype=torch.int32, device=device
+            )
+            adapter_meta = AdapterBatchMetadata(
+                adapter_indices=adapter_indices,
+                adapter_set=adapter_set,
+                adapter_segments=adapter_segments,
+                segment_indices=adapter_segment_indices,
+            )

         return type(self)(
             batch_id=self.batch_id,
@@ -580,7 +599,7 @@ class FlashCausalLMBatch(Batch):
             slots=slots,
             max_postfix_length=max_postfix_length,
             max_current_length=max_current_length,
-            prefilling=prefilling,
+            prefilling=self.prefilling,
             prefilling_mask=prefilling_mask,
             prefill_head_indices=None,
             prefill_next_token_indices=None,
@@ -604,12 +623,7 @@ class FlashCausalLMBatch(Batch):
             num_blocks=num_blocks,
             max_blocks=max_blocks,
             speculative_ids=speculative_ids,
-            adapter_meta=AdapterBatchMetadata(
-                adapter_indices=adapter_indices,
-                adapter_set=adapter_set,
-                adapter_segments=adapter_segments,
-                segment_indices=adapter_segment_indices,
-            ),
+            adapter_meta=adapter_meta,
         )

     @classmethod
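Because a prefilling batch keeps `input_ids` as a list, `filter` can select the surviving requests with plain Python indexing and defer all tensor work to `prepare_for_prefill`. A rough illustration of the two branches, with a hypothetical minimal helper standing in for the real method:

    import torch

    def filter_input_ids(input_ids, indices):
        # Prefilling batch: input_ids is List[List[int]]; keep chosen requests.
        if isinstance(input_ids, list):
            return [input_ids[i] for i in indices]
        # Decoding batch: input_ids is a 1D tensor (one id per request),
        # so we can index into the tensor directly.
        return input_ids[torch.tensor(indices)]

    print(filter_input_ids([[1, 2], [3], [4, 5, 6]], [0, 2]))  # [[1, 2], [4, 5, 6]]
    print(filter_input_ids(torch.tensor([7, 8, 9]), [0, 2]))   # tensor([7, 9])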
@@ -652,38 +666,51 @@ class FlashCausalLMBatch(Batch):
             )
             prefilling = prefilling or b.prefilling

-        input_ids = batches[0].input_ids.new_empty(total_batch_size)
-        position_ids = batches[0].position_ids.new_empty(total_batch_size)
-        slots = batches[0].slots.new_empty(total_slots)
-        slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
+        if prefilling:
+            input_ids = []
+            # These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
+            position_ids = None
+            start_slots = None
+            slots = None
+            slot_indices = None
+            prefix_lengths_tensor = None
+            postfix_lengths_tensor = None
+            adapter_meta = None
+            adapter_segment_builder = None
+        else:
+            input_ids = batches[0].input_ids.new_empty(total_batch_size)
+            position_ids = batches[0].position_ids.new_empty(total_batch_size)
+            start_slots = []
+            slots = batches[0].slots.new_empty(total_slots)
+            slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
+            postfix_lengths_tensor = batches[0].postfix_lengths_tensor.new_empty(
+                total_batch_size
+            )
+            prefix_lengths_tensor = batches[0].prefix_lengths_tensor.new_empty(
+                total_batch_size
+            )
+            total_indices_size = sum(
+                b.adapter_meta.adapter_indices.shape[0] for b in batches
+            )
+            adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty(
+                total_indices_size
+            )
+            adapter_segment_builder = SegmentConcatBuilder()
+            adapter_set = set()
+
         prompt_lengths_tensor = batches[0].prompt_lengths_tensor.new_empty(
             total_batch_size
         )
-        postfix_lengths_tensor = batches[0].postfix_lengths_tensor.new_empty(
-            total_batch_size
-        )
         block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
             (total_batch_size, max_blocks)
         )
-        prefix_lengths_tensor = batches[0].prefix_lengths_tensor.new_empty(
-            total_batch_size
-        )
         all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
             (total_batch_size, max_length)
         )
         top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
             total_batch_size,
         )
-        total_indices_size = sum(
-            b.adapter_meta.adapter_indices.shape[0] for b in batches
-        )
-        adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty(
-            total_indices_size
-        )
-        adapter_set = set()
-        adapter_segment_builder = SegmentConcatBuilder()
-        start_slots = []

         block_tables = []
         prefix_lengths = []
         all_input_ids = []
@@ -723,29 +750,7 @@ class FlashCausalLMBatch(Batch):
             slots_end_index = cumulative_slots + len(batch.slots)

             # Copy tensors (GPU)
-            input_ids[start_index:end_index] = batch.input_ids
-            position_ids[start_index:end_index] = batch.position_ids
-            slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
-            prompt_lengths_tensor[start_index:end_index] = batch.prompt_lengths_tensor
-            postfix_lengths_tensor[start_index:end_index] = batch.postfix_lengths_tensor
             top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
-            slots[slots_start_index:slots_end_index] = batch.slots
-
-            # Copy over adapter indices
-            adapter_start_index = cumulative_adapter_indices_size
-            adapter_end_index = (
-                cumulative_adapter_indices_size
-                + batch.adapter_meta.adapter_indices.shape[0]
-            )
-            adapter_indices[adapter_start_index:adapter_end_index] = (
-                batch.adapter_meta.adapter_indices
-            )
-            cumulative_adapter_indices_size = adapter_end_index
-            adapter_set.update(batch.adapter_meta.adapter_set)
-            adapter_segment_builder.concat(
-                batch.adapter_meta.adapter_segments, batch.adapter_meta.segment_indices
-            )
             all_input_ids_tensor[
                 start_index:end_index, : batch.all_input_ids_tensor.shape[1]
             ] = batch.all_input_ids_tensor[:, :max_length]
@@ -753,12 +758,38 @@ class FlashCausalLMBatch(Batch):
             block_tables_tensor[
                 start_index:end_index, : batch.block_tables_tensor.shape[1]
             ] = batch.block_tables_tensor[:, :max_blocks]
-            prefix_lengths_tensor[start_index:end_index] = batch.prefix_lengths_tensor
+            prompt_lengths_tensor[start_index:end_index] = batch.prompt_lengths_tensor

-            start_slots.append(batch.start_slots + cumulative_slots)
+            if not prefilling:
+                input_ids[start_index:end_index] = batch.input_ids
+                position_ids[start_index:end_index] = batch.position_ids
+                slot_indices[start_index:end_index] = (
+                    batch.slot_indices + cumulative_slots
+                )
+                postfix_lengths_tensor[start_index:end_index] = batch.postfix_lengths_tensor
+                slots[slots_start_index:slots_end_index] = batch.slots

-            prefilling_mask = prefilling_mask.extend(batch.prefilling_mask)
+                # Copy over adapter indices
+                adapter_start_index = cumulative_adapter_indices_size
+                adapter_end_index = (
+                    cumulative_adapter_indices_size
+                    + batch.adapter_meta.adapter_indices.shape[0]
+                )
+                adapter_indices[adapter_start_index:adapter_end_index] = (
+                    batch.adapter_meta.adapter_indices
+                )
+                cumulative_adapter_indices_size = adapter_end_index
+                adapter_set.update(batch.adapter_meta.adapter_set)
+                adapter_segment_builder.concat(
+                    batch.adapter_meta.adapter_segments, batch.adapter_meta.segment_indices
+                )
+                prefix_lengths_tensor[start_index:end_index] = batch.prefix_lengths_tensor
+
+                start_slots.append(batch.start_slots + cumulative_slots)
+            else:
+                if isinstance(batch.input_ids, torch.Tensor):
+                    batch.input_ids = batch.input_ids.view(-1, 1).tolist()
+                input_ids.extend(batch.input_ids)
+
+            prefilling_mask.extend(batch.prefilling_mask)
             block_tables.extend(batch.block_tables)
             prefix_lengths.extend(batch.prefix_lengths)
             all_input_ids.extend(batch.all_input_ids)
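When batches are merged while at least one is still prefilling, a decoding batch's one-token-per-request tensor is converted back to the list representation with `view(-1, 1).tolist()`, so every batch can simply be extended into one list. Sketched in isolation (names here are illustrative):

    import torch

    def concat_prefilling_input_ids(batches):
        # `batches` mixes decoding batches (1D tensor, one id per request)
        # and prefilling batches (List[List[int]]).
        input_ids = []
        for batch_input_ids in batches:
            if isinstance(batch_input_ids, torch.Tensor):
                # One single-id list per request: [[id], [id], ...]
                batch_input_ids = batch_input_ids.view(-1, 1).tolist()
            input_ids.extend(batch_input_ids)
        return input_ids

    decoding = torch.tensor([11, 12])  # two requests, one id each
    prefilling = [[1, 2, 3], [4, 5]]   # two requests mid-prefill
    print(concat_prefilling_input_ids([decoding, prefilling]))
    # [[11], [12], [1, 2, 3], [4, 5]]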
@@ -781,7 +812,8 @@ class FlashCausalLMBatch(Batch):
             cumulative_batch_size += len(batch)
             cumulative_slots += len(batch.slots)

-        start_slots = torch.concat(start_slots)
+        if start_slots is not None:
+            start_slots = torch.concat(start_slots)

         # assert sum(len(b) for b in block_tables) == (block_tables_tensor != 0).sum()
@@ -799,7 +831,14 @@ class FlashCausalLMBatch(Batch):
             else None
         )

-        adapter_segments, adapter_segment_indices = adapter_segment_builder.build()
+        if adapter_segment_builder is not None:
+            adapter_segments, adapter_segment_indices = adapter_segment_builder.build()
+            adapter_meta = AdapterBatchMetadata(
+                adapter_indices=adapter_indices,
+                adapter_set=adapter_set,
+                adapter_segments=adapter_segments,
+                segment_indices=adapter_segment_indices,
+            )

         return cls(
             batch_id=batches[0].batch_id,
@@ -840,12 +879,7 @@ class FlashCausalLMBatch(Batch):
             num_blocks=num_blocks,
             max_blocks=max_blocks,
             speculative_ids=speculative_ids,
-            adapter_meta=AdapterBatchMetadata(
-                adapter_indices=adapter_indices,
-                adapter_set=adapter_set,
-                adapter_segments=adapter_segments,
-                segment_indices=adapter_segment_indices,
-            ),
+            adapter_meta=adapter_meta,
         )

     def prepare_for_prefill(self):
@@ -973,9 +1007,16 @@ class FlashCausalLMBatch(Batch):
             cumulative_length += next_chunk_length
             cumulative_slot_tokens += len(request_slots)

-        device = self.input_ids.device
+        device = self.block_tables_tensor.device
         self.start_slots = torch.tensor(start_slots, dtype=torch.int64)

+        if isinstance(self.input_ids, list):
+            if len(self) > 1:
+                input_ids = np.concatenate(self.input_ids, dtype=np.int64)
+            else:
+                input_ids = self.input_ids[0]
+            self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
+
         if len(self) > 1:
             position_ids = torch.cat(position_ids)
             slot_indices = torch.cat(slot_indices)
@@ -1865,7 +1906,8 @@ class FlashCausalLM(Model):
         batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
         batch.speculative_ids = speculative_ids
         batch.position_ids = next_position_ids + accepted_ids
-        batch.postfix_lengths_tensor += accepted_ids
+        batch.prefix_lengths_tensor += batch.postfix_lengths_tensor
+        batch.postfix_lengths_tensor = accepted_ids
         batch.slot_indices += accepted_ids
         batch.adapter_meta.adapter_indices = next_adapter_indices
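After a speculative decode step, everything just processed rolls into the cached prefix and the postfix shrinks to the freshly accepted tokens. A worked example with hypothetical lengths:

    import torch

    # Two requests: prefixes of 10 and 4 tokens already in the KV cache,
    # postfixes of 5 and 3 tokens processed in this forward pass.
    prefix_lengths_tensor = torch.tensor([10, 4])
    postfix_lengths_tensor = torch.tensor([5, 3])
    # Speculation accepted 2 tokens for the first request, 1 for the second.
    accepted_ids = torch.tensor([2, 1])

    # Everything just processed is now part of the cached prefix...
    prefix_lengths_tensor += postfix_lengths_tensor   # tensor([15, 7])
    # ...and the next forward only needs to process the accepted tokens.
    postfix_lengths_tensor = accepted_ids             # tensor([2, 1])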
@@ -1929,11 +1971,9 @@ class FlashCausalLM(Model):
                 # This request is done prefilling, the new id is the one selected by the sampling method
                 postfix_ids = [next_token_id]

-            all_postfix_ids.extend(postfix_ids)
+            all_postfix_ids.append(postfix_ids)

-        batch.input_ids = batch.input_ids.new_tensor(
-            all_postfix_ids, dtype=torch.int64
-        )
+        batch.input_ids = all_postfix_ids

         start_decode = time.time_ns()
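Switching from `extend` to `append` keeps one inner list per request, which is exactly the list state the batch now expects for `input_ids`. A small illustration (hypothetical values):

    all_postfix_ids = []
    all_postfix_ids.append([42])        # request done prefilling: the sampled id
    all_postfix_ids.append([7, 8, 9])   # request still chunk-prefilling: next chunk
    assert all_postfix_ids == [[42], [7, 8, 9]]
    # extend() would have flattened this to [42, 7, 8, 9], losing the
    # per-request boundaries that the list representation relies on.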
@@ -2014,7 +2054,7 @@ class FlashCausalLM(Model):
                 )
                 prefill_tokens = Tokens(
-                    prefix_ids + prefill_token_ids,
+                    prefill_token_ids,
                     request_prefill_logprobs,
                     prefill_texts,
                     is_special=[],