mirror of https://github.com/huggingface/text-generation-inference.git
concatenate
parent 7a70928b06
commit 46363e1cd7
@@ -3,7 +3,7 @@ import torch
 from dataclasses import dataclass
 from opentelemetry import trace
 from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
-from typing import Optional, Tuple, List, Type, Dict
+from typing import Optional, Tuple, List, Type, Dict, Union
 from loguru import logger
 
 from text_generation_server.models import Model
@@ -28,7 +28,7 @@ class VectorizedCausalLMBatch(Batch):
     # Decoder values
     attention_mask: torch.Tensor
     position_ids: torch.Tensor
-    past_key_values: Optional[List[Tuple]]
+    past_key_values: Optional[List[Union[torch.Tensor, Tuple[torch.Tensor], List[torch.Tensor]]]]
 
     # All tokens
     input_ids: torch.Tensor
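
The widened `past_key_values` annotation above admits either a single tensor per layer or a tuple/list of tensors per layer, which is what the new `concatenate` below inspects. A minimal sketch of the two layouts, assuming a conventional cache shape of (batch, heads, seq, head_dim); the fused single-tensor layout is only an illustration of what the Union allows:

    import torch

    batch, heads, seq, head_dim = 2, 12, 5, 64

    # Tuple-per-layer layout, e.g. (key, value) as most HF decoder models return it:
    layer_kv_tuple = (
        torch.zeros(batch, heads, seq, head_dim),
        torch.zeros(batch, heads, seq, head_dim),
    )

    # Single-tensor-per-layer layout (keys and values fused into one tensor):
    layer_kv_fused = torch.zeros(batch, heads, seq, 2 * head_dim)

    # Both now satisfy the widened annotation.
    past_key_values = [layer_kv_tuple, layer_kv_fused]
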
@@ -65,30 +65,14 @@ class VectorizedCausalLMBatch(Batch):
         tokenizer: PreTrainedTokenizerBase,
         device: torch.device,
     ) -> "VectorizedCausalLMBatch":
-        inputs = []
-        stopping_criterias = []
-        offsets = []
-        token_offsets = []
-        requests_idx_mapping = {}
+        inputs = [r.inputs for r in pb.requests]
+        offsets = [None] * len(inputs)
+        token_offsets = [None] * len(inputs)
+        requests_idx_mapping = {r.id: i for i, r in enumerate(pb.requests)}
 
         # Parse batch
-        max_truncation = 0
-        padding_right_offset = 0
-        max_decode_tokens = 0
-        for i, r in enumerate(pb.requests):
-            requests_idx_mapping[r.id] = i
-            inputs.append(r.inputs)
-            offsets.append(None)
-            token_offsets.append(None)
-            stopping_criteria = StoppingCriteria.from_pb(
-                r.stopping_parameters, tokenizer
-            )
-            stopping_criterias.append(stopping_criteria)
-            max_truncation = max(max_truncation, r.truncate)
-            max_decode_tokens += stopping_criteria.max_new_tokens
-            padding_right_offset = max(
-                padding_right_offset, stopping_criteria.max_new_tokens
-            )
+        stopping_criterias = [StoppingCriteria.from_pb(r.stopping_parameters, tokenizer) for r in pb.requests]
+        max_new_tokens = [stopping_criteria.max_new_tokens for stopping_criteria in stopping_criterias]  # a list, not a generator: it is consumed by both max() and sum() below
 
         next_token_chooser = VectorizedNextTokenChooser.from_pb([r.parameters for r in pb.requests], device)
 
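
The rewritten `from_pb` above derives everything per request with comprehensions and keeps `max_new_tokens` around so the padded batch width can be computed once. A standalone sketch of that shape arithmetic, using a hypothetical `FakeRequest` and made-up token counts in place of the real protobuf requests and tokenizer:

    from collections import namedtuple

    # Stand-in for pb.requests; only the fields used here are modelled.
    FakeRequest = namedtuple("FakeRequest", ["inputs", "truncate", "max_new_tokens"])

    requests = [
        FakeRequest("Hello", truncate=1024, max_new_tokens=20),
        FakeRequest("How are you?", truncate=1024, max_new_tokens=5),
    ]

    input_lengths = [1, 4]  # pretend token counts after tokenization
    max_input_length = max(input_lengths)                  # 4
    max_new_tokens = [r.max_new_tokens for r in requests]  # [20, 5]

    # One rectangular allocation covers every prompt plus the largest possible decode.
    input_shape = (len(requests), max_input_length + max(max_new_tokens))
    max_tokens = len(requests) * max_input_length + sum(max_new_tokens)
    print(input_shape, max_tokens)  # (2, 24) 33
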
@@ -98,13 +82,13 @@ class VectorizedCausalLMBatch(Batch):
             padding=True,
             return_token_type_ids=False,
             truncation=True,
-            max_length=max_truncation,
+            max_length=max(r.truncate for r in pb.requests),
         ).to(device)
 
         input_lengths = tokenized_inputs["attention_mask"].sum(1)
-        max_input_length = input_lengths.max()
+        max_input_length = input_lengths.max().item()
 
-        input_shape=(pb.size, max_input_length + padding_right_offset)
+        input_shape=(pb.size, max_input_length + max(max_new_tokens))
 
         # Allocate maximum attention_mask
         attention_mask = torch.empty(input_shape, dtype=torch.bool, device=device)
@@ -118,7 +102,7 @@ class VectorizedCausalLMBatch(Batch):
         input_ids = torch.empty(input_shape, dtype=torch.int64, device=device)
         input_ids[:, :max_input_length].copy_(tokenized_inputs["input_ids"])
 
-        max_tokens = len(inputs) * max_input_length + max_decode_tokens
+        max_tokens = len(inputs) * max_input_length + sum(max_new_tokens)
 
         return cls(
             batch_id=pb.id,
@@ -155,11 +139,10 @@ class VectorizedCausalLMBatch(Batch):
         self.next_token_chooser=self.next_token_chooser.filter(keep_indices)
         self.stopping_criterias = [self.stopping_criterias[i] for i in keep_indices]
         remaining_decode_tokens=[stopping_criteria.max_new_tokens - stopping_criteria.current_tokens for stopping_criteria in self.stopping_criterias]
-        self.padding_right_offset=max(remaining_decode_tokens)
 
         # Select the remaining indices and remove unnecessary padding
         max_input_length=max(self.input_lengths)
-        sequence_slice=slice(self.max_input_length-max_input_length, self.max_input_length+self.padding_right_offset)
+        sequence_slice=slice(self.max_input_length-max_input_length, self.max_input_length+max(remaining_decode_tokens))
         self.max_input_length=max_input_length
         self.max_tokens = len(self.requests) * self.max_input_length + sum(remaining_decode_tokens)
 
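
In the `filter` change above, the right padding is recomputed from the surviving stopping criteria instead of being cached on the batch, and `sequence_slice` trims both the dead left padding and the excess right padding. A toy illustration of that slice arithmetic (all numbers invented):

    # The surviving requests in a batch previously padded to length 10 have
    # prompts of 6 and 4 tokens, with at most 3 tokens left to decode.
    old_max_input_length = 10
    input_lengths = [6, 4]
    remaining_decode_tokens = [3, 1]

    max_input_length = max(input_lengths)  # 6
    sequence_slice = slice(
        old_max_input_length - max_input_length,              # drop 4 columns of dead left padding
        old_max_input_length + max(remaining_decode_tokens),  # keep room for up to 3 more tokens
    )
    print(sequence_slice)  # slice(4, 13, None)
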
@@ -189,7 +172,109 @@ class VectorizedCausalLMBatch(Batch):
     @classmethod
     @tracer.start_as_current_span("concatenate")
     def concatenate(cls, batches: List["VectorizedCausalLMBatch"]) -> "VectorizedCausalLMBatch":
-        raise NotImplementedError()
+        if len(batches) == 0:
+            raise ValueError("Cannot concatenate empty list.")
+        requests = [request for batch in batches for request in batch.requests]
+        batch_sizes = [len(batch.requests) for batch in batches]
+        batch_size = sum(batch_sizes)
+
+        end_indices = torch.tensor(batch_sizes).cumsum(0).tolist()
+        start_indices = [0] + end_indices[:-1]
+
+        input_lengths = [length for batch in batches for length in batch.input_lengths]
+        offsets = [offset for batch in batches for offset in batch.offsets]
+        token_offsets = [token_offset for batch in batches for token_offset in batch.token_offsets]
+        next_token_chooser = VectorizedNextTokenChooser.concatenate([batch.next_token_chooser for batch in batches])
+        stopping_criterias = [stopping_criteria for batch in batches for stopping_criteria in batch.stopping_criterias]
+
+        requests_idx_mapping = {k: v + start_index for batch, start_index in zip(batches, start_indices) for k, v in batch.requests_idx_mapping.items()}
+
+        max_input_length = max(input_lengths)
+        left_indices = [max_input_length - batch.max_input_length for batch in batches]
+
+        input_shape = (batch_size, max_input_length + max(batch.input_ids.size(1) - batch.max_input_length for batch in batches))
+        device = batches[0].input_ids.device
+
+        # Allocate maximum attention_mask
+        attention_mask = torch.empty(input_shape, dtype=torch.bool, device=device)
+        attention_mask[:, :max_input_length].fill_(0)
+        attention_mask[:, max_input_length:].fill_(1)
+
+        input_ids = torch.empty(input_shape, dtype=torch.int64, device=device)
+        # TODO : only needed for prefill
+        input_ids[:, :max_input_length].fill_(0)
+
+        for batch, start_index, end_index, left_index in zip(batches, start_indices, end_indices, left_indices):
+            attention_mask[start_index:end_index, left_index:max_input_length].copy_(batch.attention_mask[:, :batch.max_input_length])
+            input_ids[start_index:end_index, left_index:max_input_length].copy_(batch.input_ids[:, :batch.max_input_length])
+
+        position_ids = attention_mask.cumsum(-1).sub_(1)
+        position_ids[:, :max_input_length].relu_()
+
+        max_tokens = sum(batch.max_tokens + (max_input_length - batch.max_input_length) * len(batch) for batch in batches)
+
+        kv_formats = None
+        for batch in batches:
+            if batch.past_key_values is None:
+                raise ValueError("Only concatenate prefilled batches")
+            if not isinstance(batch.past_key_values, (list, tuple)):
+                raise NotImplementedError(f"Unsupported kv cache type: {type(batch.past_key_values)}")
+            if kv_formats is None:
+                num_layers = len(batch.past_key_values)
+                if num_layers == 0:
+                    raise ValueError("Empty KV cache")
+                kv_formats = [0] * num_layers
+            elif len(batch.past_key_values) != len(kv_formats):
+                raise ValueError("Num layers is not constant")
+            for i, layer_kv in enumerate(batch.past_key_values):
+                if isinstance(layer_kv, (list, tuple)):
+                    kv_format = len(layer_kv)
+                else:
+                    kv_format = None
+                if kv_formats[i] == 0:
+                    if kv_format == 0:
+                        raise ValueError("Empty KV cache")
+                    kv_formats[i] = kv_format
+                elif kv_formats[i] != kv_format:
+                    raise ValueError("Incompatible KV cache format.")
+
+        kv_cache_seq_dim = batches[0].kv_cache_seq_dim
+        past_key_values = []
+        for i, kv_format in enumerate(kv_formats):
+            for j in range(1 if kv_format is None else kv_format):
+                tensors_to_merge = [batch.past_key_values[i] if kv_format is None else batch.past_key_values[i][j] for batch in batches]  # take the j-th tensor when keys/values are stored separately
+                # Generally `max_input_length`, unless the model allocates more than needed.
+                right_indices = [left_index + tensor.size(kv_cache_seq_dim) for tensor, left_index in zip(tensors_to_merge, left_indices)]
+                combined_shape = [batch_size] + list(tensors_to_merge[0].shape[1:])
+                combined_shape[kv_cache_seq_dim] = max(right_indices)
+                # Set to zero to avoid propagating nans in padded values.
+                kv_cache = torch.zeros(combined_shape, dtype=tensors_to_merge[0].dtype, device=device)
+                for tensor, start_index, end_index, left_index, right_index in zip(tensors_to_merge, start_indices, end_indices, left_indices, right_indices):
+                    kv_cache[[slice(start_index, end_index), *(slice(None) for _ in range(1, kv_cache_seq_dim)), slice(left_index, right_index)]].copy_(tensor)
+                if kv_format is None:
+                    past_key_values.append(kv_cache)
+                elif j == 0:
+                    past_key_values.append([kv_cache])
+                else:
+                    past_key_values[-1].append(kv_cache)
+
+        return cls(
+            batch_id=batches[0].batch_id,
+            requests=requests,
+            requests_idx_mapping=requests_idx_mapping,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            input_lengths=input_lengths,
+            offsets=offsets,
+            token_offsets=token_offsets,
+            next_token_chooser=next_token_chooser,
+            stopping_criterias=stopping_criterias,
+            max_input_length=max_input_length,
+            kv_cache_seq_dim=kv_cache_seq_dim,
+            max_tokens=max_tokens,
+        )
 
     def __len__(self):
         return len(self.requests)
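
The core of the new `concatenate` is the left-alignment step: every source batch keeps its own right edge while shorter batches gain left padding, and the per-layer KV tensors are merged the same way along `kv_cache_seq_dim`. A reduced sketch of just the attention-mask alignment for two toy batches (it ignores the extra columns the real method reserves for future decode tokens):

    import torch

    # Two prefilled toy batches: the first holds 2 requests of length 3, the
    # second holds 1 request of length 5.
    lengths = [3, 5]
    batch_sizes = [2, 1]
    max_input_length = max(lengths)
    batch_size = sum(batch_sizes)

    end_indices = torch.tensor(batch_sizes).cumsum(0).tolist()        # [2, 3]
    start_indices = [0] + end_indices[:-1]                            # [0, 2]
    left_indices = [max_input_length - length for length in lengths]  # [2, 0]

    attention_mask = torch.zeros(batch_size, max_input_length, dtype=torch.bool)
    for length, start, end, left in zip(lengths, start_indices, end_indices, left_indices):
        # Shorter batches gain left padding; every row keeps the same right edge.
        attention_mask[start:end, left:] = True

    print(attention_mask.int())
    # tensor([[0, 0, 1, 1, 1],
    #         [0, 0, 1, 1, 1],
    #         [1, 1, 1, 1, 1]], dtype=torch.int32)
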
@@ -382,6 +467,21 @@ class VectorizedNextTokenChooser:
             device=self.device,
         )
 
+    @classmethod
+    def concatenate(cls, next_token_choosers: List["VectorizedNextTokenChooser"]) -> "VectorizedNextTokenChooser":
+        return cls(
+            batch_size=sum(next_token_chooser.batch_size for next_token_chooser in next_token_choosers),
+            watermark=[x for next_token_chooser in next_token_choosers for x in next_token_chooser.watermark],
+            temperature=[x for next_token_chooser in next_token_choosers for x in next_token_chooser.temperature],
+            repetition_penalty=[x for next_token_chooser in next_token_choosers for x in next_token_chooser.repetition_penalty],
+            top_k=[x for next_token_chooser in next_token_choosers for x in next_token_chooser.top_k],
+            top_p=[x for next_token_chooser in next_token_choosers for x in next_token_chooser.top_p],
+            typical_p=[x for next_token_chooser in next_token_choosers for x in next_token_chooser.typical_p],
+            do_sample=[x for next_token_chooser in next_token_choosers for x in next_token_chooser.do_sample],
+            seeds=[x for next_token_chooser in next_token_choosers for x in next_token_chooser.seeds],
+            device=next_token_choosers[0].device,
+        )
+
 
 
 
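
`VectorizedNextTokenChooser.concatenate` above simply flattens the per-request sampling parameters of each chooser and rebuilds one chooser from the combined lists. A sketch of that flattening with hypothetical stand-in objects (`FakeChooser` models only two of the real fields):

    # Hypothetical stand-ins: only the fields concatenate() touches are modelled.
    class FakeChooser:
        def __init__(self, temperature, top_k):
            self.batch_size = len(temperature)
            self.temperature = temperature
            self.top_k = top_k

    a = FakeChooser(temperature=[0.7, 1.0], top_k=[50, 0])
    b = FakeChooser(temperature=[0.9], top_k=[10])

    merged_batch_size = sum(c.batch_size for c in (a, b))            # 3
    merged_temperature = [x for c in (a, b) for x in c.temperature]  # [0.7, 1.0, 0.9]
    merged_top_k = [x for c in (a, b) for x in c.top_k]              # [50, 0, 10]
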