diff --git a/server/text_generation_server/models/vectorized_causal_lm.py b/server/text_generation_server/models/vectorized_causal_lm.py
index 7547623c..5abc066c 100644
--- a/server/text_generation_server/models/vectorized_causal_lm.py
+++ b/server/text_generation_server/models/vectorized_causal_lm.py
@@ -3,7 +3,7 @@ import torch
 from dataclasses import dataclass
 from opentelemetry import trace
 from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
-from typing import Optional, Tuple, List, Type, Dict
+from typing import Optional, Tuple, List, Type, Dict, Union
 from loguru import logger
 
 from text_generation_server.models import Model
@@ -28,7 +28,7 @@ class VectorizedCausalLMBatch(Batch):
     # Decoder values
     attention_mask: torch.Tensor
     position_ids: torch.Tensor
-    past_key_values: Optional[List[Tuple]]
+    past_key_values: Optional[List[Union[torch.Tensor, Tuple[torch.Tensor], List[torch.Tensor]]]]
 
     # All tokens
     input_ids: torch.Tensor
@@ -65,30 +65,14 @@ class VectorizedCausalLMBatch(Batch):
         tokenizer: PreTrainedTokenizerBase,
         device: torch.device,
     ) -> "VectorizedCausalLMBatch":
-        inputs = []
-        stopping_criterias = []
-        offsets = []
-        token_offsets = []
-        requests_idx_mapping = {}
+        inputs = [r.inputs for r in pb.requests]
+        offsets = [None]*len(inputs)
+        token_offsets = [None]*len(inputs)
+        requests_idx_mapping = {r.id:i for i, r in enumerate(pb.requests)}
 
         # Parse batch
-        max_truncation = 0
-        padding_right_offset = 0
-        max_decode_tokens = 0
-        for i, r in enumerate(pb.requests):
-            requests_idx_mapping[r.id] = i
-            inputs.append(r.inputs)
-            offsets.append(None)
-            token_offsets.append(None)
-            stopping_criteria = StoppingCriteria.from_pb(
-                r.stopping_parameters, tokenizer
-            )
-            stopping_criterias.append(stopping_criteria)
-            max_truncation = max(max_truncation, r.truncate)
-            max_decode_tokens += stopping_criteria.max_new_tokens
-            padding_right_offset = max(
-                padding_right_offset, stopping_criteria.max_new_tokens
-            )
+        stopping_criterias = [StoppingCriteria.from_pb(r.stopping_parameters, tokenizer) for r in pb.requests]
+        max_new_tokens=[stopping_criteria.max_new_tokens for stopping_criteria in stopping_criterias]
 
         next_token_chooser=VectorizedNextTokenChooser.from_pb([r.parameters for r in pb.requests], device)
 
@@ -98,13 +82,13 @@ class VectorizedCausalLMBatch(Batch):
             padding=True,
             return_token_type_ids=False,
             truncation=True,
-            max_length=max_truncation,
+            max_length=max(r.truncate for r in pb.requests),
         ).to(device)
 
         input_lengths = tokenized_inputs["attention_mask"].sum(1)
-        max_input_length = input_lengths.max()
+        max_input_length = input_lengths.max().item()
 
-        input_shape=(pb.size, max_input_length + padding_right_offset)
+        input_shape=(pb.size, max_input_length + max(max_new_tokens))
 
         # Allocate maximum attention_mask
         attention_mask = torch.empty(input_shape, dtype=torch.bool, device=device)
@@ -118,7 +102,7 @@ class VectorizedCausalLMBatch(Batch):
         input_ids = torch.empty(input_shape, dtype=torch.int64, device=device)
         input_ids[:, :max_input_length].copy_(tokenized_inputs["input_ids"])
 
-        max_tokens = len(inputs) * max_input_length + max_decode_tokens
+        max_tokens = len(inputs) * max_input_length + sum(max_new_tokens)
 
         return cls(
             batch_id=pb.id,
@@ -155,11 +139,10 @@ class VectorizedCausalLMBatch(Batch):
         self.next_token_chooser=self.next_token_chooser.filter(keep_indices)
         self.stopping_criterias = [self.stopping_criterias[i] for i in keep_indices]
         remaining_decode_tokens=[stopping_criteria.max_new_tokens - stopping_criteria.current_tokens for stopping_criteria in self.stopping_criterias]
-        self.padding_right_offset=max(remaining_decode_tokens)
 
         # Select the remaining indices and remove unnecessary padding
         max_input_length=max(self.input_lengths)
-        sequence_slice=slice(self.max_input_length-max_input_length, self.max_input_length+self.padding_right_offset)
+        sequence_slice=slice(self.max_input_length-max_input_length, self.max_input_length+max(remaining_decode_tokens))
         self.max_input_length=max_input_length
         self.max_tokens = len(self.requests) * self.max_input_length + sum(remaining_decode_tokens)
 
@@ -189,7 +172,109 @@ class VectorizedCausalLMBatch(Batch):
     @classmethod
     @tracer.start_as_current_span("concatenate")
     def concatenate(cls, batches: List["VectorizedCausalLMBatch"]) -> "VectorizedCausalLMBatch":
-        raise NotImplementedError()
+        if len(batches)==0:
+            raise ValueError("Cannot concatenate empty list.")
+        requests=[request for batch in batches for request in batch.requests]
+        batch_sizes=[len(batch.requests) for batch in batches]
+        batch_size=sum(batch_sizes)
+
+        end_indices=torch.tensor(batch_sizes).cumsum(0).tolist()
+        start_indices=[0]+end_indices[:-1]
+
+        input_lengths = [length for batch in batches for length in batch.input_lengths]
+        offsets = [offset for batch in batches for offset in batch.offsets]
+        token_offsets = [token_offset for batch in batches for token_offset in batch.token_offsets]
+        next_token_chooser=VectorizedNextTokenChooser.concatenate([batch.next_token_chooser for batch in batches])
+        stopping_criterias = [stopping_criteria for batch in batches for stopping_criteria in batch.stopping_criterias]
+
+        requests_idx_mapping = {k: v + start_index for batch, start_index in zip(batches, start_indices) for k, v in batch.requests_idx_mapping.items()}
+
+        max_input_length=max(input_lengths)
+        left_indices=[max_input_length-batch.max_input_length for batch in batches]
+
+        input_shape=(batch_size, max_input_length + max(batch.input_ids.size(1)-batch.max_input_length for batch in batches))
+        device=batches[0].input_ids.device
+
+        # Allocate maximum attention_mask
+        attention_mask = torch.empty(input_shape, dtype=torch.bool, device=device)
+        attention_mask[:, :max_input_length].fill_(0)
+        attention_mask[:, max_input_length:].fill_(1)
+
+        input_ids = torch.empty(input_shape, dtype=torch.int64, device=device)
+        # TODO : only needed for prefill
+        input_ids[:, :max_input_length].fill_(0)
+
+        for batch, start_index, end_index, left_index in zip(batches, start_indices, end_indices, left_indices):
+            attention_mask[start_index:end_index, left_index:max_input_length].copy_(batch.attention_mask[:, :batch.max_input_length])
+            input_ids[start_index:end_index, left_index:max_input_length].copy_(batch.input_ids[:, :batch.max_input_length])
+
+        position_ids = attention_mask.cumsum(-1).sub_(1)
+        position_ids[:, :max_input_length].relu_()
+
+        max_tokens = sum(batch.max_tokens + (max_input_length - batch.max_input_length) * len(batch) for batch in batches)
+
+        kv_formats=None
+        for batch in batches:
+            if batch.past_key_values is None:
+                raise ValueError("Only concatenate prefilled batches")
+            if not isinstance(batch.past_key_values, (list, tuple)):
+                raise NotImplementedError(f"Unsupported kv cache type: {type(batch.past_key_values)}")
+            if kv_formats is None:
+                num_layers=len(batch.past_key_values)
+                if num_layers==0:
+                    raise ValueError("Empty KV cache")
+                kv_formats = [0]*num_layers
+            elif len(batch.past_key_values)!=len(kv_formats):
+                raise ValueError("Num layers is not constant")
+            for i, layer_kv in enumerate(batch.past_key_values):
+                if isinstance(layer_kv, (list, tuple)):
+                    kv_format = len(layer_kv)
+                else:
+                    kv_format=None
+                if kv_formats[i]==0:
+                    if kv_format==0:
+                        raise ValueError("Empty KV cache")
+                    kv_formats[i]=kv_format
+                elif kv_formats[i]!=kv_format:
+                    raise ValueError("Incompatible KV cache format.")
+
+        kv_cache_seq_dim=batches[0].kv_cache_seq_dim
+        past_key_values=[]
+        for i, kv_format in enumerate(kv_formats):
+            for j in range(1 if kv_format is None else kv_format):
+                tensors_to_merge=[batch.past_key_values[i] if kv_format is None else batch.past_key_values[i][j] for batch in batches]
+                # Generally `max_input_length`, unless the model allocates more than needed.
+                right_indices=[left_index+tensor.size(kv_cache_seq_dim) for tensor, left_index in zip(tensors_to_merge, left_indices)]
+                combined_shape=[batch_size]+list(tensors_to_merge[0].shape[1:])
+                combined_shape[kv_cache_seq_dim]=max(right_indices)
+                # Set to zero to avoid propagating nans in padded values.
+                kv_cache = torch.zeros(combined_shape, dtype=tensors_to_merge[0].dtype, device=device)
+                for tensor, start_index, end_index, left_index, right_index in zip(tensors_to_merge, start_indices, end_indices, left_indices, right_indices):
+                    kv_cache[[slice(start_index, end_index), *(slice(None) for _ in range(1, kv_cache_seq_dim)), slice(left_index,right_index)]].copy_(tensor)
+                if kv_format is None:
+                    past_key_values.append(kv_cache)
+                elif j==0:
+                    past_key_values.append([kv_cache])
+                else:
+                    past_key_values[-1].append(kv_cache)
+
+        return cls(
+            batch_id=batches[0].batch_id,
+            requests=requests,
+            requests_idx_mapping=requests_idx_mapping,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            input_lengths=input_lengths,
+            offsets=offsets,
+            token_offsets=token_offsets,
+            next_token_chooser=next_token_chooser,
+            stopping_criterias=stopping_criterias,
+            max_input_length=max_input_length,
+            kv_cache_seq_dim=kv_cache_seq_dim,
+            max_tokens=max_tokens,
+        )
 
     def __len__(self):
         return len(self.requests)
@@ -382,6 +467,21 @@ class VectorizedNextTokenChooser:
             device=self.device,
         )
 
+    @classmethod
+    def concatenate(cls, next_token_choosers: List["VectorizedNextTokenChooser"]) -> "VectorizedNextTokenChooser":
+        return cls(
+            batch_size=sum(next_token_chooser.batch_size for next_token_chooser in next_token_choosers),
+            watermark=[x for next_token_chooser in next_token_choosers for x in next_token_chooser.watermark],
+            temperature=[x for next_token_chooser in next_token_choosers for x in next_token_chooser.temperature],
+            repetition_penalty=[x for next_token_chooser in next_token_choosers for x in next_token_chooser.repetition_penalty],
+            top_k=[x for next_token_chooser in next_token_choosers for x in next_token_chooser.top_k],
+            top_p=[x for next_token_chooser in next_token_choosers for x in next_token_chooser.top_p],
+            typical_p=[x for next_token_chooser in next_token_choosers for x in next_token_chooser.typical_p],
+            do_sample=[x for next_token_chooser in next_token_choosers for x in next_token_chooser.do_sample],
+            seeds=[x for next_token_chooser in next_token_choosers for x in next_token_chooser.seeds],
+            device=next_token_choosers[0].device,
+        )
+
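
For readers of the concatenate hunk above, here is a minimal standalone sketch (not part of the patch; dummy shapes and an assumed [batch, heads, seq, head_dim] cache layout) of how per-layer KV caches from two prefilled batches are right-aligned into one zero-filled buffer, mirroring the left_indices / copy_ logic in the new code:

    import torch

    # Hypothetical illustration only: two left-padded KV tensors with different
    # sequence lengths are copied into a shared buffer, aligned on the right.
    kv_cache_seq_dim = 2                      # assumed sequence dimension
    a = torch.randn(2, 4, 3, 8)               # first batch: 2 sequences, KV length 3
    b = torch.randn(1, 4, 5, 8)               # second batch: 1 sequence, KV length 5
    max_input_length = max(a.size(kv_cache_seq_dim), b.size(kv_cache_seq_dim))
    left_indices = [max_input_length - t.size(kv_cache_seq_dim) for t in (a, b)]
    start_indices, end_indices = [0, 2], [2, 3]

    # Zero-filled so padded positions cannot propagate NaNs, as in the patch.
    merged = torch.zeros(3, 4, max_input_length, 8)
    for tensor, start, end, left in zip((a, b), start_indices, end_indices, left_indices):
        merged[start:end, :, left:left + tensor.size(kv_cache_seq_dim)].copy_(tensor)

Right-aligning the shorter caches against max_input_length keeps the padding on the left, so newly generated tokens can be appended on the right for every request in the merged batch.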