Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-09-11 20:34:54 +00:00

Commit 34f5dc525e ("working")
Parent 173bc99ab3

Changed file: server/text_generation_server/models/flash_causal_lm.py
@@ -16,7 +16,7 @@ from transformers import (
     AutoTokenizer,
     GenerationConfig,
 )
-from typing import Any, ContextManager, Iterable, Optional, Tuple, List, Type, Dict
+from typing import Any, ContextManager, Iterable, Optional, Tuple, List, Type, Dict, Union
 
 from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata
 from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
@@ -119,7 +119,9 @@ class FlashCausalLMBatch(Batch):
     requests_idx_mapping: Dict[int, int]
 
     # Decoder values
-    input_ids: torch.Tensor
+    # Can be a list for easy filtering
+    # If `input_ids` is a list, it needs to be materialized to a tensor first
+    input_ids: Union[torch.Tensor, List[List[int]]]
     # Will be set by `generate_token` and reset after each prefill forward before staying set in decode
     position_ids: Optional[torch.Tensor]
     speculative_ids: Optional[torch.Tensor]
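The `Union` type above is the heart of the change: while a batch is still prefilling, `input_ids` stays a plain list of per-request token lists. A minimal standalone sketch (illustrative only, not the repository code) of why that representation is cheaper to filter:

import torch

# Hypothetical batch of three requests, prefill not yet run.
input_ids = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]

# Dropping finished requests is plain list indexing, no tensor reshuffling.
kept = [input_ids[i] for i in (0, 2)]

# Materialization to a flat tensor happens only once prefill is scheduled.
flat = torch.tensor([t for req in kept for t in req], dtype=torch.int64)
print(flat)  # tensor([1, 2, 3, 6, 7, 8, 9])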
@@ -178,7 +180,7 @@ class FlashCausalLMBatch(Batch):
     # Will be set by `generate_token` and reset after each prefill forward before staying set in decode
     postfix_lengths_tensor: Optional[torch.Tensor]
     prefix_lengths_tensor: Optional[torch.Tensor]
-    prompt_lengths_tensor: Optional[torch.Tensor]
+    prompt_lengths_tensor: torch.Tensor
 
     prefix_offsets: List[Optional[int]]
     read_offsets: List[Optional[int]]
@@ -350,12 +352,6 @@ class FlashCausalLMBatch(Batch):
             all_input_ids_tensor, dtype=torch.int64, device=device
         )
 
-        if len(pb.requests) > 1:
-            input_ids = np.concatenate(all_postfix_ids, dtype=np.int64)
-        else:
-            input_ids = all_postfix_ids[0]
-        input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
-
         top_n_tokens_tensor = torch.tensor(
             top_n_tokens, device=device, dtype=torch.int64
         )
@@ -366,12 +362,15 @@ class FlashCausalLMBatch(Batch):
         for i, request_blocks in enumerate(block_tables):
             block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)
         block_tables_tensor = block_tables_tensor.to(device)
+        prompt_lengths_tensor = torch.tensor(
+            prompt_lengths, dtype=torch.int32, device=device
+        )
 
         return cls(
             batch_id=pb.id,
             requests=pb.requests,
             requests_idx_mapping=requests_idx_mapping,
-            input_ids=input_ids,
+            input_ids=all_postfix_ids,
 
             block_tables=block_tables,
             block_tables_tensor=block_tables_tensor,
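`from_pb` now passes the raw `all_postfix_ids` lists straight through instead of building the tensor eagerly; the materialization logic deleted above reappears in `prepare_for_prefill` (see the `@@ -973,9 +1007,16 @@` hunk below). A sketch of that deferred conversion, keeping the same single-request fast path:

import numpy as np
import torch

def materialize(all_postfix_ids, device="cpu"):
    # Mirrors the moved block: np.concatenate only pays off for
    # multi-request batches; a single request skips it entirely.
    if len(all_postfix_ids) > 1:
        input_ids = np.concatenate(all_postfix_ids, dtype=np.int64)
    else:
        input_ids = all_postfix_ids[0]
    return torch.tensor(input_ids, dtype=torch.int64, device=device)

print(materialize([[1, 2], [3]]))  # tensor([1, 2, 3])
print(materialize([[7, 8, 9]]))    # tensor([7, 8, 9])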
@@ -395,6 +394,7 @@ class FlashCausalLMBatch(Batch):
             num_blocks=num_blocks,
             max_blocks=max_blocks,
             speculative_ids=None,
+            prompt_lengths_tensor=prompt_lengths_tensor,
 
             # These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
             position_ids=None,
@@ -408,7 +408,6 @@ class FlashCausalLMBatch(Batch):
             prefill_cu_outlens=None,
             prefix_lengths_tensor=None,
             postfix_lengths_tensor=None,
-            prompt_lengths_tensor=None,
             adapter_meta=None,
         )
 
@@ -455,6 +454,7 @@ class FlashCausalLMBatch(Batch):
         block_tables = []
         all_input_ids = []
         prefix_ids = []
+        input_ids = []
 
         prompt_lengths = []
         postfix_lengths = []
@@ -473,7 +473,6 @@ class FlashCausalLMBatch(Batch):
         max_blocks = 0
         # Cumulative length
         cumulative_max_length = 0
-        prefilling = False
 
         for i, request_id in enumerate(request_ids):
             idx = self.requests_idx_mapping[request_id]
@@ -484,9 +483,13 @@ class FlashCausalLMBatch(Batch):
 
             # Prefilling
             request_prefilling = self.prefilling_mask[idx]
-            prefilling = prefilling or request_prefilling
             prefilling_mask.append(request_prefilling)
 
+            # Input ids if the request was part of a prefilling batch
+            # If the batch was decoding we can index into the tensor directly later
+            if self.prefilling:
+                input_ids.append(self.input_ids[idx])
+
             # Get length
             request_postfix_length = self.postfix_lengths[idx]
             request_prefix_length = self.prefix_lengths[idx]
@@ -538,21 +541,32 @@ class FlashCausalLMBatch(Batch):
 
             max_blocks = max(max_blocks, len(request_block_table))
 
-        # Index into tensors
-        input_ids = self.input_ids[indices]
-        position_ids = self.position_ids[indices]
-        adapter_indices = self.adapter_meta.adapter_indices[indices]
         all_input_ids_tensor = self.all_input_ids_tensor[indices]
         block_tables_tensor = self.block_tables_tensor[indices]
-        postfix_lengths_tensor = self.postfix_lengths_tensor[indices]
-        prompt_lengths_tensor = self.prompt_lengths_tensor[indices]
-        slots = self.slots[slot_filtering_indices]
-        prefix_lengths_tensor = self.prefix_lengths_tensor[indices]
         next_token_chooser = self.next_token_chooser.filter(indices)
         top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
         speculative_ids = (
             self.speculative_ids[indices] if self.speculative_ids is not None else None
         )
+        prompt_lengths_tensor = self.prompt_lengths_tensor[indices]
+
+        if self.prefilling:
+            # These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
+            position_ids = None
+            start_slots = None
+            slot_indices = None
+            slots = None
+            prefix_lengths_tensor = None
+            postfix_lengths_tensor = None
+            adapter_meta = None
+        else:
+            # Index into tensors
+            input_ids = self.input_ids[indices]
+            position_ids = self.position_ids[indices]
+            adapter_indices = self.adapter_meta.adapter_indices[indices]
+            postfix_lengths_tensor = self.postfix_lengths_tensor[indices]
+            slots = self.slots[slot_filtering_indices]
+            prefix_lengths_tensor = self.prefix_lengths_tensor[indices]
 
         start_slots = torch.tensor(start_slots, dtype=torch.int64)
 
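The split in `filter` follows one rule: during prefill, per-request state lives in Python lists and the tensor state is rebuilt later by `prepare_for_prefill`; during decode, everything is already a tensor and a single fancy-indexing call suffices. A standalone sketch of the two paths (hypothetical names, not the repository API):

import torch

def filter_input_ids(input_ids, keep, prefilling):
    if prefilling:
        # List path: cheap Python indexing; tensors are rebuilt later.
        return [input_ids[i] for i in keep]
    # Decode path: one indexed gather on the existing tensor.
    return input_ids[torch.tensor(keep, dtype=torch.int64)]

print(filter_input_ids([[1, 2], [3], [4, 5]], [0, 2], prefilling=True))
print(filter_input_ids(torch.tensor([10, 11, 12]), [0, 2], prefilling=False))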
@@ -563,7 +577,12 @@ class FlashCausalLMBatch(Batch):
             adapter_segments = torch.tensor(
                 adapter_segments, dtype=torch.int32, device=device
             )
-        # assert sum(len(b) for b in block_tables) == (block_tables_tensor != 0).sum()
+            adapter_meta = AdapterBatchMetadata(
+                adapter_indices=adapter_indices,
+                adapter_set=adapter_set,
+                adapter_segments=adapter_segments,
+                segment_indices=adapter_segment_indices,
+            )
 
         return type(self)(
             batch_id=self.batch_id,
@@ -580,7 +599,7 @@ class FlashCausalLMBatch(Batch):
             slots=slots,
             max_postfix_length=max_postfix_length,
             max_current_length=max_current_length,
-            prefilling=prefilling,
+            prefilling=self.prefilling,
             prefilling_mask=prefilling_mask,
             prefill_head_indices=None,
             prefill_next_token_indices=None,
@@ -604,12 +623,7 @@ class FlashCausalLMBatch(Batch):
             num_blocks=num_blocks,
             max_blocks=max_blocks,
             speculative_ids=speculative_ids,
-            adapter_meta=AdapterBatchMetadata(
-                adapter_indices=adapter_indices,
-                adapter_set=adapter_set,
-                adapter_segments=adapter_segments,
-                segment_indices=adapter_segment_indices,
-            ),
+            adapter_meta=adapter_meta,
         )
 
     @classmethod
@@ -652,38 +666,51 @@ class FlashCausalLMBatch(Batch):
         )
         prefilling = prefilling or b.prefilling
 
+        if prefilling:
+            input_ids = []
+            # These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
+            position_ids = None
+            start_slots = None
+            slots = None
+            slot_indices = None
+            prefix_lengths_tensor = None
+            postfix_lengths_tensor = None
+            adapter_meta = None
+            adapter_segment_builder = None
+        else:
             input_ids = batches[0].input_ids.new_empty(total_batch_size)
             position_ids = batches[0].position_ids.new_empty(total_batch_size)
+            start_slots = []
             slots = batches[0].slots.new_empty(total_slots)
             slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
-        prompt_lengths_tensor = batches[0].prompt_lengths_tensor.new_empty(
-            total_batch_size
-        )
             postfix_lengths_tensor = batches[0].postfix_lengths_tensor.new_empty(
                 total_batch_size
            )
-        block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
-            (total_batch_size, max_blocks)
-        )
             prefix_lengths_tensor = batches[0].prefix_lengths_tensor.new_empty(
                 total_batch_size
             )
-        all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
-            (total_batch_size, max_length)
-        )
-        top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
-            total_batch_size,
-        )
             total_indices_size = sum(
                 b.adapter_meta.adapter_indices.shape[0] for b in batches
             )
             adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty(
                 total_indices_size
             )
+            adapter_set = set()
             adapter_segment_builder = SegmentConcatBuilder()
-        adapter_set = set()
+
+        prompt_lengths_tensor = batches[0].prompt_lengths_tensor.new_empty(
+            total_batch_size
+        )
+        block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
+            (total_batch_size, max_blocks)
+        )
+        all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
+            (total_batch_size, max_length)
+        )
+        top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
+            total_batch_size,
+        )
 
-        start_slots = []
         block_tables = []
         prefix_lengths = []
         all_input_ids = []
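`concatenate` applies the same rule at allocation time: decode-only buffers are not allocated at all when any source batch is still prefilling, and `input_ids` is merged as plain lists. A rough sketch of that shape (simplified, with an assumed helper name):

import torch

def concat_input_ids(per_batch_ids, prefilling):
    if prefilling:
        merged = []
        for ids in per_batch_ids:
            # A decoding batch may carry a tensor; normalize it to lists
            # first, as the `isinstance` branch further below does.
            if isinstance(ids, torch.Tensor):
                ids = ids.view(-1, 1).tolist()
            merged.extend(ids)
        return merged
    # Decode path: preallocate once, then copy each batch slice in.
    total = sum(len(ids) for ids in per_batch_ids)
    out = per_batch_ids[0].new_empty(total)
    offset = 0
    for ids in per_batch_ids:
        out[offset : offset + len(ids)] = ids
        offset += len(ids)
    return out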
@@ -723,12 +750,21 @@ class FlashCausalLMBatch(Batch):
             slots_end_index = cumulative_slots + len(batch.slots)
 
             # Copy tensors (GPU)
+            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
+            all_input_ids_tensor[
+                start_index:end_index, : batch.all_input_ids_tensor.shape[1]
+            ] = batch.all_input_ids_tensor[:, :max_length]
+
+            block_tables_tensor[
+                start_index:end_index, : batch.block_tables_tensor.shape[1]
+            ] = batch.block_tables_tensor[:, :max_blocks]
+            prompt_lengths_tensor[start_index:end_index] = batch.prompt_lengths_tensor
+
+            if not prefilling:
                 input_ids[start_index:end_index] = batch.input_ids
                 position_ids[start_index:end_index] = batch.position_ids
                 slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
-            prompt_lengths_tensor[start_index:end_index] = batch.prompt_lengths_tensor
                 postfix_lengths_tensor[start_index:end_index] = batch.postfix_lengths_tensor
-            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
                 slots[slots_start_index:slots_end_index] = batch.slots
 
             # Copy over adapter indices
@@ -745,20 +781,15 @@ class FlashCausalLMBatch(Batch):
                 adapter_segment_builder.concat(
                     batch.adapter_meta.adapter_segments, batch.adapter_meta.segment_indices
                 )
 
-            all_input_ids_tensor[
-                start_index:end_index, : batch.all_input_ids_tensor.shape[1]
-            ] = batch.all_input_ids_tensor[:, :max_length]
-
-            block_tables_tensor[
-                start_index:end_index, : batch.block_tables_tensor.shape[1]
-            ] = batch.block_tables_tensor[:, :max_blocks]
-
                 prefix_lengths_tensor[start_index:end_index] = batch.prefix_lengths_tensor
-
                 start_slots.append(batch.start_slots + cumulative_slots)
+            else:
+                if isinstance(batch.input_ids, torch.Tensor):
+                    batch.input_ids = batch.input_ids.view(-1, 1).tolist()
+                input_ids.extend(batch.input_ids)
 
-            prefilling_mask = prefilling_mask.extend(batch.prefilling_mask)
+            prefilling_mask.extend(batch.prefilling_mask)
             block_tables.extend(batch.block_tables)
             prefix_lengths.extend(batch.prefix_lengths)
             all_input_ids.extend(batch.all_input_ids)
@@ -781,6 +812,7 @@ class FlashCausalLMBatch(Batch):
             cumulative_batch_size += len(batch)
             cumulative_slots += len(batch.slots)
 
+        if start_slots is not None:
             start_slots = torch.concat(start_slots)
 
         # assert sum(len(b) for b in block_tables) == (block_tables_tensor != 0).sum()
@@ -799,7 +831,14 @@ class FlashCausalLMBatch(Batch):
             else None
         )
 
+        if adapter_segment_builder is not None:
             adapter_segments, adapter_segment_indices = adapter_segment_builder.build()
+            adapter_meta = AdapterBatchMetadata(
+                adapter_indices=adapter_indices,
+                adapter_set=adapter_set,
+                adapter_segments=adapter_segments,
+                segment_indices=adapter_segment_indices,
+            )
 
         return cls(
             batch_id=batches[0].batch_id,
@@ -840,12 +879,7 @@ class FlashCausalLMBatch(Batch):
             num_blocks=num_blocks,
             max_blocks=max_blocks,
             speculative_ids=speculative_ids,
-            adapter_meta=AdapterBatchMetadata(
-                adapter_indices=adapter_indices,
-                adapter_set=adapter_set,
-                adapter_segments=adapter_segments,
-                segment_indices=adapter_segment_indices,
-            ),
+            adapter_meta=adapter_meta,
        )
 
     def prepare_for_prefill(self):
@@ -973,9 +1007,16 @@ class FlashCausalLMBatch(Batch):
             cumulative_length += next_chunk_length
             cumulative_slot_tokens += len(request_slots)
 
-        device = self.input_ids.device
+        device = self.block_tables_tensor.device
         self.start_slots = torch.tensor(start_slots, dtype=torch.int64)
 
+        if isinstance(self.input_ids, list):
+            if len(self) > 1:
+                input_ids = np.concatenate(self.input_ids, dtype=np.int64)
+            else:
+                input_ids = self.input_ids[0]
+            self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
+
         if len(self) > 1:
             position_ids = torch.cat(position_ids)
             slot_indices = torch.cat(slot_indices)
@@ -1865,7 +1906,8 @@ class FlashCausalLM(Model):
             batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
             batch.speculative_ids = speculative_ids
             batch.position_ids = next_position_ids + accepted_ids
-            batch.postfix_lengths_tensor += accepted_ids
+            batch.prefix_lengths_tensor += batch.postfix_lengths_tensor
+            batch.postfix_lengths_tensor = accepted_ids
             batch.slot_indices += accepted_ids
             batch.adapter_meta.adapter_indices = next_adapter_indices
 
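The two-line replacement above changes what the length tensors mean after a speculative step: the tokens just run through the model are folded into the prefix (already in the KV cache), and the postfix restarts at the number of accepted tokens instead of growing monotonically. A worked example under that reading (values invented):

import torch

prefix_lengths = torch.tensor([100, 40])  # tokens already in the KV cache
postfix_lengths = torch.tensor([5, 5])    # tokens processed this step
accepted_ids = torch.tensor([3, 1])       # tokens the verifier kept

prefix_lengths += postfix_lengths         # -> tensor([105, 45])
postfix_lengths = accepted_ids            # -> tensor([3, 1])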
@@ -1929,11 +1971,9 @@ class FlashCausalLM(Model):
                 # This request is done prefilling, the new id is the one selected the sampling method
                 postfix_ids = [next_token_id]
 
-            all_postfix_ids.extend(postfix_ids)
+            all_postfix_ids.append(postfix_ids)
 
-        batch.input_ids = batch.input_ids.new_tensor(
-            all_postfix_ids, dtype=torch.int64
-        )
+        batch.input_ids = all_postfix_ids
 
         start_decode = time.time_ns()
 
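`append` instead of `extend` keeps one entry per request, so `batch.input_ids` leaves this loop as a `List[List[int]]` (matching the new field type) rather than being flattened and re-tensorized. A toy illustration:

all_postfix_ids = []
for postfix_ids in ([1, 2, 3], [7]):      # one id-list per request
    all_postfix_ids.append(postfix_ids)
print(all_postfix_ids)  # [[1, 2, 3], [7]] -- extend() would flatten to [1, 2, 3, 7]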
@@ -2014,7 +2054,7 @@ class FlashCausalLM(Model):
                 )
 
                 prefill_tokens = Tokens(
-                    prefix_ids + prefill_token_ids,
+                    prefill_token_ids,
                     request_prefill_logprobs,
                     prefill_texts,
                     is_special=[],