concatenate

Joel Lamy-Poirier 2023-05-04 17:34:28 -04:00
parent 7a70928b06
commit 46363e1cd7


@@ -3,7 +3,7 @@ import torch
 from dataclasses import dataclass
 from opentelemetry import trace
 from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
-from typing import Optional, Tuple, List, Type, Dict
+from typing import Optional, Tuple, List, Type, Dict, Union
 from loguru import logger
 from text_generation_server.models import Model
@@ -28,7 +28,7 @@ class VectorizedCausalLMBatch(Batch):
     # Decoder values
     attention_mask: torch.Tensor
     position_ids: torch.Tensor
-    past_key_values: Optional[List[Tuple]]
+    past_key_values: Optional[List[Union[torch.Tensor, Tuple[torch.Tensor], List[torch.Tensor]]]]
     # All tokens
     input_ids: torch.Tensor
@@ -65,30 +65,14 @@ class VectorizedCausalLMBatch(Batch):
         tokenizer: PreTrainedTokenizerBase,
         device: torch.device,
     ) -> "VectorizedCausalLMBatch":
-        inputs = []
-        stopping_criterias = []
-        offsets = []
-        token_offsets = []
-        requests_idx_mapping = {}
+        inputs = [r.inputs for r in pb.requests]
+        offsets = [None]*len(inputs)
+        token_offsets = [None]*len(inputs)
+        requests_idx_mapping = {r.id:i for i, r in enumerate(pb.requests)}
-        # Parse batch
-        max_truncation = 0
-        padding_right_offset = 0
-        max_decode_tokens = 0
-        for i, r in enumerate(pb.requests):
-            requests_idx_mapping[r.id] = i
-            inputs.append(r.inputs)
-            offsets.append(None)
-            token_offsets.append(None)
-            stopping_criteria = StoppingCriteria.from_pb(
-                r.stopping_parameters, tokenizer
-            )
-            stopping_criterias.append(stopping_criteria)
-            max_truncation = max(max_truncation, r.truncate)
-            max_decode_tokens += stopping_criteria.max_new_tokens
-            padding_right_offset = max(
-                padding_right_offset, stopping_criteria.max_new_tokens
-            )
+        stopping_criterias = [StoppingCriteria.from_pb(r.stopping_parameters, tokenizer) for r in pb.requests]
+        max_new_tokens=[stopping_criteria.max_new_tokens for stopping_criteria in stopping_criterias]
+        next_token_chooser=VectorizedNextTokenChooser.from_pb([r.parameters for r in pb.requests], device)
@@ -98,13 +82,13 @@ class VectorizedCausalLMBatch(Batch):
             padding=True,
             return_token_type_ids=False,
             truncation=True,
-            max_length=max_truncation,
+            max_length=max(r.truncate for r in pb.requests),
         ).to(device)
         input_lengths = tokenized_inputs["attention_mask"].sum(1)
-        max_input_length = input_lengths.max()
+        max_input_length = input_lengths.max().item()
-        input_shape=(pb.size, max_input_length + padding_right_offset)
+        input_shape=(pb.size, max_input_length + max(max_new_tokens))
         # Allocate maximum attention_mask
         attention_mask = torch.empty(input_shape, dtype=torch.bool, device=device)
@@ -118,7 +102,7 @@ class VectorizedCausalLMBatch(Batch):
         input_ids = torch.empty(input_shape, dtype=torch.int64, device=device)
         input_ids[:, :max_input_length].copy_(tokenized_inputs["input_ids"])
-        max_tokens = len(inputs) * max_input_length + max_decode_tokens
+        max_tokens = len(inputs) * max_input_length + sum(max_new_tokens)
         return cls(
             batch_id=pb.id,
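
The two hunks above size the batch buffers directly from the per-request stopping criteria instead of the old padding_right_offset / max_decode_tokens accumulators. A minimal standalone sketch of that arithmetic, with made-up request sizes (the protobuf batch and tokenizer are not reproduced here):

import torch

# Hypothetical per-request values, standing in for pb.requests.
input_lengths = torch.tensor([5, 3, 7])   # prompt tokens per request after truncation
max_new_tokens = [10, 4, 6]               # from each request's stopping criteria

max_input_length = input_lengths.max().item()   # 7
batch_size = len(max_new_tokens)

# Buffers are allocated once, for the longest prompt plus the largest
# decode budget in the batch, so nothing is reallocated during generation.
input_shape = (batch_size, max_input_length + max(max_new_tokens))   # (3, 17)

# Token budget reported back: every row is padded to max_input_length,
# and each request keeps its own decode budget.
max_tokens = batch_size * max_input_length + sum(max_new_tokens)     # 41

attention_mask = torch.empty(input_shape, dtype=torch.bool)
input_ids = torch.empty(input_shape, dtype=torch.int64)
print(input_shape, max_tokens)
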
@@ -155,11 +139,10 @@ class VectorizedCausalLMBatch(Batch):
         self.next_token_chooser=self.next_token_chooser.filter(keep_indices)
         self.stopping_criterias = [self.stopping_criterias[i] for i in keep_indices]
         remaining_decode_tokens=[stopping_criteria.max_new_tokens - stopping_criteria.current_tokens for stopping_criteria in self.stopping_criterias]
-        self.padding_right_offset=max(remaining_decode_tokens)
         # Select the remaining indices and remove unnecessary padding
         max_input_length=max(self.input_lengths)
-        sequence_slice=slice(self.max_input_length-max_input_length, self.max_input_length+self.padding_right_offset)
+        sequence_slice=slice(self.max_input_length-max_input_length, self.max_input_length+max(remaining_decode_tokens))
         self.max_input_length=max_input_length
         self.max_tokens = len(self.requests) * self.max_input_length + sum(remaining_decode_tokens)
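
In filter, the kept rows are re-trimmed with a single slice over the sequence dimension now that padding_right_offset is gone. A small sketch of the slice arithmetic, with hypothetical lengths for the surviving requests:

# Hypothetical state after dropping finished requests from a batch.
old_max_input_length = 12          # self.max_input_length before filtering
input_lengths = [9, 7]             # lengths of the kept requests
remaining_decode_tokens = [3, 5]   # max_new_tokens - current_tokens per kept request

max_input_length = max(input_lengths)   # 9

# Keep only the columns still needed: the last max_input_length prompt
# positions plus the largest remaining decode budget.
sequence_slice = slice(
    old_max_input_length - max_input_length,               # 3
    old_max_input_length + max(remaining_decode_tokens),   # 17
)
print(sequence_slice)   # slice(3, 17, None)
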
@@ -189,7 +172,109 @@
     @classmethod
     @tracer.start_as_current_span("concatenate")
     def concatenate(cls, batches: List["VectorizedCausalLMBatch"]) -> "VectorizedCausalLMBatch":
-        raise NotImplementedError()
+        if len(batches)==0:
+            raise ValueError("Cannot concatenate empty list.")
+        requests=[request for batch in batches for request in batch.requests]
+        batch_sizes=[len(batch.requests) for batch in batches]
+        batch_size=sum(batch_sizes)
+        end_indices=torch.tensor(batch_sizes).cumsum(0).tolist()
+        start_indices=[0]+end_indices[:-1]
+
+        input_lengths = [length for batch in batches for length in batch.input_lengths]
+        offsets = [offset for batch in batches for offset in batch.offsets]
+        token_offsets = [token_offset for batch in batches for token_offset in batch.token_offsets]
+        next_token_chooser=VectorizedNextTokenChooser.concatenate([batch.next_token_chooser for batch in batches])
+        stopping_criterias = [stopping_criteria for batch in batches for stopping_criteria in batch.stopping_criterias]
+        requests_idx_mapping = {k: v + start_index for batch, start_index in zip(batches, start_indices) for k, v in batch.requests_idx_mapping.items()}
+
+        max_input_length=max(input_lengths)
+        left_indices=[max_input_length-batch.max_input_length for batch in batches]
+
+        input_shape=(batch_size, max_input_length + max(batch.input_ids.size(1)-batch.max_input_length for batch in batches))
+        device=batches[0].input_ids.device
+
+        # Allocate maximum attention_mask
+        attention_mask = torch.empty(input_shape, dtype=torch.bool, device=device)
+        attention_mask[:, :max_input_length].fill_(0)
+        attention_mask[:, max_input_length:].fill_(1)
+
+        input_ids = torch.empty(input_shape, dtype=torch.int64, device=device)
+        # TODO : only needed for prefill
+        input_ids[:, :max_input_length].fill_(0)
+
+        for batch, start_index, end_index, left_index in zip(batches, start_indices, end_indices, left_indices):
+            attention_mask[start_index:end_index, left_index:max_input_length].copy_(batch.attention_mask[:, :batch.max_input_length])
+            input_ids[start_index:end_index, left_index:max_input_length].copy_(batch.input_ids[:, :batch.max_input_length])
+
+        position_ids = attention_mask.cumsum(-1).sub_(1)
+        position_ids[:, :max_input_length].relu_()
+
+        max_tokens = sum(batch.max_tokens + (max_input_length - batch.max_input_length) * len(batch) for batch in batches)
+
+        kv_formats=None
+        for batch in batches:
+            if batch.past_key_values is None:
+                raise ValueError("Only concatenate prefilled batches")
+            if not isinstance(batch.past_key_values, (list, tuple)):
+                raise NotImplementedError(f"Unsupported kv cache type: {type(batch.past_key_values)}")
+            if kv_formats is None:
+                num_layers=len(batch.past_key_values)
+                if num_layers==0:
+                    raise ValueError("Empty KV cache")
+                kv_formats = [0]*num_layers
+            elif len(batch.past_key_values)!=len(kv_formats):
+                raise ValueError("Num layers is not constant")
+            for i, layer_kv in enumerate(batch.past_key_values):
+                if isinstance(layer_kv, (list, tuple)):
+                    kv_format = len(layer_kv)
+                else:
+                    kv_format=None
+                if kv_formats[i]==0:
+                    if kv_format==0:
+                        raise ValueError("Empty KV cache")
+                    kv_formats[i]=kv_format
+                elif kv_formats[i]!=kv_format:
+                    raise ValueError("Incompatible KV cache format.")
+
+        kv_cache_seq_dim=batches[0].kv_cache_seq_dim
+        past_key_values=[]
+        for i, kv_format in enumerate(kv_formats):
+            for j in range(1 if kv_format is None else kv_format):
+                tensors_to_merge=[batch.past_key_values[i] if kv_format is None else batch.past_key_values[i][j] for batch in batches]
+                # Generally `max_input_length`, unless the model allocates more than needed.
+                right_indices=[left_index+tensor.size(kv_cache_seq_dim) for tensor, left_index in zip(tensors_to_merge, left_indices)]
+                combined_shape=[batch_size]+list(tensors_to_merge[0].shape[1:])
+                combined_shape[kv_cache_seq_dim]=max(right_indices)
+                # Set to zero to avoid propagating nans in padded values.
+                kv_cache = torch.zeros(combined_shape, dtype=tensors_to_merge[0].dtype, device=device)
+                for tensor, start_index, end_index, left_index, right_index in zip(tensors_to_merge, start_indices, end_indices, left_indices, right_indices):
+                    kv_cache[[slice(start_index, end_index), *(slice(None) for _ in range(1, kv_cache_seq_dim)), slice(left_index,right_index)]].copy_(tensor)
+                if kv_format is None:
+                    past_key_values.append(kv_cache)
+                elif j==0:
+                    past_key_values.append([kv_cache])
+                else:
+                    past_key_values[-1].append(kv_cache)
+
+        return cls(
+            batch_id=batches[0].batch_id,
+            requests=requests,
+            requests_idx_mapping=requests_idx_mapping,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            input_lengths=input_lengths,
+            offsets=offsets,
+            token_offsets=token_offsets,
+            next_token_chooser=next_token_chooser,
+            stopping_criterias=stopping_criterias,
+            max_input_length=max_input_length,
+            kv_cache_seq_dim=kv_cache_seq_dim,
+            max_tokens=max_tokens,
+        )

     def __len__(self):
         return len(self.requests)
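
The new concatenate left-aligns every batch inside a freshly allocated buffer sized for the longest prompt, rebuilds position_ids from the merged attention mask, and merges each batch's KV cache along its sequence dimension. A self-contained toy version of that alignment, assuming two already-prefilled batches, a single layer, and a sequence dimension of 2; every shape and value below is invented for illustration:

import torch

# Two hypothetical prefilled batches: prompts of length 2 (batch A, 2 requests)
# and length 4 (batch B, 1 request).
ids_a = torch.tensor([[11, 12], [13, 14]])
ids_b = torch.tensor([[21, 22, 23, 24]])
mask_a = torch.ones(2, 2, dtype=torch.bool)
mask_b = torch.ones(1, 4, dtype=torch.bool)

batch_size = 3
max_input_length = 4                                          # max over both batches
start_indices, end_indices = [0, 2], [2, 3]                   # row ranges per batch
left_indices = [max_input_length - 2, max_input_length - 4]   # left padding per batch

# Allocate merged buffers and left-align each batch inside them.
attention_mask = torch.zeros(batch_size, max_input_length, dtype=torch.bool)
input_ids = torch.zeros(batch_size, max_input_length, dtype=torch.int64)
for (ids, mask), s, e, l in zip([(ids_a, mask_a), (ids_b, mask_b)], start_indices, end_indices, left_indices):
    attention_mask[s:e, l:max_input_length].copy_(mask)
    input_ids[s:e, l:max_input_length].copy_(ids)

# Position ids come back from the merged mask; padded positions clamp to 0.
position_ids = attention_mask.cumsum(-1).sub_(1)
position_ids.relu_()

# The KV cache merges the same way along the sequence dimension (dim 2 here),
# into a zero-filled buffer so padding never propagates NaNs.
kv_a = torch.randn(2, 1, 2, 8)            # (batch, heads, seq, head_dim)
kv_b = torch.randn(1, 1, 4, 8)
kv_cache = torch.zeros(batch_size, 1, max_input_length, 8)
for kv, s, e, l in zip([kv_a, kv_b], start_indices, end_indices, left_indices):
    kv_cache[s:e, :, l:max_input_length, :].copy_(kv)

print(position_ids)
# tensor([[0, 0, 0, 1],
#         [0, 0, 0, 1],
#         [0, 1, 2, 3]])
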
@@ -382,6 +467,21 @@ class VectorizedNextTokenChooser:
             device=self.device,
         )
+
+    @classmethod
+    def concatenate(cls, next_token_choosers: List["VectorizedNextTokenChooser"]) -> "VectorizedNextTokenChooser":
+        return cls(
+            batch_size=sum(next_token_chooser.batch_size for next_token_chooser in next_token_choosers),
+            watermark=[x for next_token_chooser in next_token_choosers for x in next_token_chooser.watermark],
+            temperature=[x for next_token_chooser in next_token_choosers for x in next_token_chooser.temperature],
+            repetition_penalty=[x for next_token_chooser in next_token_choosers for x in next_token_chooser.repetition_penalty],
+            top_k=[x for next_token_chooser in next_token_choosers for x in next_token_chooser.top_k],
+            top_p=[x for next_token_chooser in next_token_choosers for x in next_token_chooser.top_p],
+            typical_p=[x for next_token_chooser in next_token_choosers for x in next_token_chooser.typical_p],
+            do_sample=[x for next_token_chooser in next_token_choosers for x in next_token_chooser.do_sample],
+            seeds=[x for next_token_chooser in next_token_choosers for x in next_token_chooser.seeds],
+            device=next_token_choosers[0].device,
+        )