diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py index 5eddc8cf..088a1457 100644 --- a/server/text_generation_server/models/bloom.py +++ b/server/text_generation_server/models/bloom.py @@ -39,10 +39,11 @@ class BloomCausalLMBatch(CausalLMBatch): cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, + dtype: torch.dtype, device: torch.device, ) -> "CausalLMBatch": batch = super(BloomCausalLMBatch, cls).from_pb( - pb=pb, tokenizer=tokenizer, device=device + pb=pb, tokenizer=tokenizer, dtype=dtype, device=device ) batch.keys_head_dim_last = False return batch diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index 81a5e75e..a20a6143 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -66,6 +66,7 @@ class CausalLMBatch(Batch): cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, + dtype: torch.dtype, device: torch.device, ) -> "CausalLMBatch": inputs = [] diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index baa6cd7f..fb98386f 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -19,9 +19,8 @@ from text_generation_server.models.types import ( ) from text_generation_server.pb import generate_pb2 from text_generation_server.utils import ( - NextTokenChooser, StoppingCriteria, - Sampling, + HeterogeneousNextTokenChooser ) tracer = trace.get_tracer(__name__) @@ -48,7 +47,7 @@ class FlashCausalLMBatch(Batch): # All tokens all_input_ids: List[List[int]] - all_input_ids_tensor: List[torch.Tensor] + all_input_ids_tensor: torch.Tensor # Lengths of all generations present in the batch input_lengths: List[int] @@ -56,7 +55,7 @@ class FlashCausalLMBatch(Batch): read_offsets: List[Optional[int]] # Generation helpers - next_token_choosers: List[NextTokenChooser] + next_token_chooser: HeterogeneousNextTokenChooser stopping_criterias: List[StoppingCriteria] # Maximum number of tokens this batch will grow to @@ -72,10 +71,11 @@ class FlashCausalLMBatch(Batch): @classmethod def from_pb( - cls, - pb: generate_pb2.Batch, - tokenizer: PreTrainedTokenizerBase, - device: torch.device, + cls, + pb: generate_pb2.Batch, + tokenizer: PreTrainedTokenizerBase, + dtype: torch.dtype, + device: torch.device, ) -> "FlashCausalLMBatch": position_ids = [] cu_seqlens = [0] @@ -87,13 +87,14 @@ class FlashCausalLMBatch(Batch): all_input_ids = [] requests_idx_mapping = {} - next_token_choosers = [] + next_token_chooser_parameters = [] stopping_criterias = [] # Cumulative length cumulative_length = 0 max_tokens = 0 + max_length = 0 # Parse batch for i, r in enumerate(pb.requests): @@ -119,7 +120,7 @@ class FlashCausalLMBatch(Batch): # Add cumulative lengths of all previous inputs cu_seqlens.append(cumulative_length + input_length) - next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device)) + next_token_chooser_parameters.append(r.parameters) stopping_criteria = StoppingCriteria.from_pb( r.stopping_parameters, tokenizer @@ -130,11 +131,26 @@ class FlashCausalLMBatch(Batch): # Update cumulative_length += input_length max_tokens += input_length + max_new_tokens + max_length = max(max_length, input_length + max_new_tokens) + + next_token_chooser = HeterogeneousNextTokenChooser.from_pb( + next_token_chooser_parameters, dtype, device + ) + + # Padded 
all_input_ids_tensor + all_input_ids_tensor = np.zeros( + (len(all_input_ids), max_length), dtype=np.int64 + ) + for i, input_ids in enumerate(all_input_ids): + all_input_ids_tensor[i, : len(input_ids)] = input_ids # Create tensors on device input_ids = torch.tensor( np.concatenate(all_input_ids), dtype=torch.int64, device=device ) + all_input_ids_tensor = torch.tensor( + all_input_ids_tensor, dtype=torch.int64, device=device + ) position_ids = torch.tensor( np.concatenate(position_ids), dtype=torch.int32, device=device ) @@ -154,8 +170,8 @@ class FlashCausalLMBatch(Batch): prefix_offsets=prefix_offsets, read_offsets=read_offsets, all_input_ids=all_input_ids, - all_input_ids_tensor=[], - next_token_choosers=next_token_choosers, + all_input_ids_tensor=all_input_ids_tensor, + next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, max_tokens=max_tokens, ) @@ -176,31 +192,29 @@ class FlashCausalLMBatch(Batch): # New values after filtering requests_idx_mapping = {} - input_ids = self.input_ids.new_empty(len(request_ids)) - position_ids = self.position_ids.new_empty(len(request_ids)) + # Used to index into tensors + indices = [] + # Create on CPU to only move to GPU once instead of at every copy cu_seqlens = torch.zeros(len(request_ids) + 1, dtype=torch.int32) - cu_seqlens_q = torch.arange( - 0, len(request_ids) + 1, device=self.cu_seqlens_q.device, dtype=torch.int32 - ) + cu_seqlens_q = self.cu_seqlens_q[: len(request_ids) + 1] max_seqlen = 0 past_key_values = [] requests = [] all_input_ids = [] - all_input_ids_tensor = [] input_lengths = [] prefix_offsets = [] read_offsets = [] - next_token_choosers = [] stopping_criterias = [] max_tokens = 0 for i, request_id in enumerate(request_ids): idx = self.requests_idx_mapping[request_id] + indices.append(idx) requests_idx_mapping[request_id] = i requests.append(self.requests[idx]) @@ -208,34 +222,27 @@ class FlashCausalLMBatch(Batch): # Get length request_input_length = self.input_lengths[idx] - # Copy tensors (GPU) - input_ids[i] = self.input_ids[idx] - position_ids[i] = self.position_ids[idx] - # Copy to tensor (CPU) cu_seqlens[i + 1] = cumulative_length + request_input_length max_seqlen = max(max_seqlen, request_input_length) # Slice from past past_key_values.append( - self.past_key_values[:, self.cu_seqlens[idx] : self.cu_seqlens[idx + 1]] + self.past_key_values[:, self.cu_seqlens[idx]: self.cu_seqlens[idx + 1]] ) all_input_ids.append(self.all_input_ids[idx]) - all_input_ids_tensor.append(self.all_input_ids_tensor[idx]) input_lengths.append(request_input_length) prefix_offsets.append(self.prefix_offsets[idx]) read_offsets.append(self.read_offsets[idx]) - next_token_choosers.append(self.next_token_choosers[idx]) - stopping_criteria = self.stopping_criterias[idx] stopping_criterias.append(stopping_criteria) cumulative_length += request_input_length max_tokens += request_input_length + ( - stopping_criteria.max_new_tokens - stopping_criteria.current_tokens + stopping_criteria.max_new_tokens - stopping_criteria.current_tokens ) if single_request: @@ -258,6 +265,12 @@ class FlashCausalLMBatch(Batch): # Cat all past past_key_values = torch.cat(past_key_values, dim=1) + # Index into tensors + input_ids = self.input_ids[indices] + position_ids = self.position_ids[indices] + all_input_ids_tensor = self.all_input_ids_tensor[indices] + next_token_chooser = self.next_token_chooser.filter(indices) + # Move to GPU now that we have the whole tensor cu_seqlens = cu_seqlens.to(self.cu_seqlens.device) @@ -276,7 +289,7 @@ class 
FlashCausalLMBatch(Batch): read_offsets=read_offsets, all_input_ids=all_input_ids, all_input_ids_tensor=all_input_ids_tensor, - next_token_choosers=next_token_choosers, + next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, max_tokens=max_tokens, ) @@ -290,6 +303,7 @@ class FlashCausalLMBatch(Batch): total_batch_size = sum([len(b) for b in batches]) + dtype = batches[0].past_key_values.dtype device = batches[0].input_ids.device input_ids = batches[0].input_ids.new_empty(total_batch_size) @@ -302,19 +316,19 @@ class FlashCausalLMBatch(Batch): past_key_values = [] all_input_ids = [] - all_input_ids_tensor = [] input_lengths = [] prefix_offsets = [] read_offsets = [] - next_token_choosers = [] + next_token_chooser_parameters = [] stopping_criterias = [] # Cumulative length cumulative_batch_size = 0 cumulative_length = 0 max_tokens = 0 + max_length = 0 for i, batch in enumerate(batches): requests.extend(batch.requests) @@ -347,25 +361,54 @@ class FlashCausalLMBatch(Batch): ) all_input_ids.extend(batch.all_input_ids) - all_input_ids_tensor.extend(batch.all_input_ids_tensor) input_lengths.extend(batch.input_lengths) prefix_offsets.extend(batch.prefix_offsets) read_offsets.extend(batch.read_offsets) - next_token_choosers.extend(batch.next_token_choosers) + next_token_chooser_parameters.extend([r.parameters for r in batch.requests]) stopping_criterias.extend(batch.stopping_criterias) # Update cumulative_length += batch.cu_seqlens[-1] cumulative_batch_size += len(batch) max_tokens += batch.max_tokens + max_length = max( + max_length, + max( + input_length + + stopping_criteria.max_new_tokens + - stopping_criteria.current_tokens + for input_length, stopping_criteria in zip( + batch.input_lengths, batch.stopping_criterias + ) + ), + ) + + all_input_ids_tensor = torch.zeros( + (total_batch_size, max_length), dtype=torch.int64, device=device + ) + + cumulative_batch_size = 0 + for i, batch in enumerate(batches): + start_index = cumulative_batch_size + end_index = cumulative_batch_size + len(batch) + + all_input_ids_tensor[ + start_index:end_index, : batch.all_input_ids_tensor.shape[1] + ] = batch.all_input_ids_tensor + + cumulative_batch_size += len(batch) # Cat past past_key_values = torch.cat(past_key_values, dim=1) # Create final tensor on GPU cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) + next_token_chooser = HeterogeneousNextTokenChooser.from_pb( + next_token_chooser_parameters, dtype=dtype, device=device + ) + return FlashCausalLMBatch( batch_id=batches[0].batch_id, requests=requests, @@ -381,7 +424,7 @@ class FlashCausalLMBatch(Batch): read_offsets=read_offsets, all_input_ids=all_input_ids, all_input_ids_tensor=all_input_ids_tensor, - next_token_choosers=next_token_choosers, + next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, max_tokens=max_tokens, ) @@ -438,14 +481,14 @@ class FlashCausalLM(Model): ) def forward( - self, - input_ids: torch.Tensor, - position_ids: torch.Tensor, - cu_seqlens: torch.Tensor, - cu_seqlens_q: Optional[torch.Tensor], - max_s: int, - past_key_values: Optional = None, - pre_allocate_past_size: Optional[int] = None, + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlens: torch.Tensor, + cu_seqlens_q: Optional[torch.Tensor], + max_s: int, + past_key_values: Optional = None, + pre_allocate_past_size: Optional[int] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: # Model Forward return self.model.forward( @@ -460,15 +503,16 @@ class FlashCausalLM(Model): 
@tracer.start_as_current_span("generate_token") def generate_token( - self, batch: FlashCausalLMBatch + self, batch: FlashCausalLMBatch ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]: prefill = batch.past_key_values is None + single_request = len(batch) == 1 if prefill and len(batch) == 1: # Ask to pre-allocate kv to its max size # == number of tokens + max_new_tokens pre_allocate_past_size = ( - batch.input_lengths[0] + batch.stopping_criterias[0].max_new_tokens + batch.input_lengths[0] + batch.stopping_criterias[0].max_new_tokens ) else: pre_allocate_past_size = None @@ -483,6 +527,17 @@ class FlashCausalLM(Model): pre_allocate_past_size, ) + if prefill: + next_token_logits = ( + out[-1:] if single_request else out[batch.cu_seqlens[1:] - 1] + ) + else: + next_token_logits = out + + next_input_ids, next_token_logprobs = batch.next_token_chooser( + batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits + ) + if prefill: if len(batch) > 1: # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs @@ -493,15 +548,11 @@ class FlashCausalLM(Model): batch.cu_seqlens_q = torch.arange( 0, len(batch) + 1, device=self.device, dtype=torch.int32 ) - next_input_ids = batch.input_ids.new_empty(len(batch)) next_position_ids = batch.position_ids.new_empty(len(batch)) else: prefill_logprobs = None - next_input_ids = batch.input_ids next_position_ids = batch.position_ids - next_token_logprobs = out.new_empty(len(batch)) - # Prepare past for next decode if len(batch) > 1: # Used to slice next batch past @@ -552,7 +603,6 @@ class FlashCausalLM(Model): # Zipped iterator iterator = zip( batch.input_lengths, - batch.next_token_choosers, batch.stopping_criterias, batch.all_input_ids, ) @@ -563,31 +613,15 @@ class FlashCausalLM(Model): # For each member of the batch for i, ( - input_length, - next_token_chooser, - stopping_criteria, - all_input_ids, + input_length, + stopping_criteria, + all_input_ids, ) in enumerate(iterator): # Indexing metadata start_index = cumulative_length end_index = cumulative_length + input_length if prefill: - # Prefill mode - # out is of shape [cumulative_sequence_lengths, vocab_size] - # only take last token logit - logits = out[end_index - 1 : end_index] - - # Create all_input_ids_tensor that will be used by token warpers (for example, RepetitionPenalty) - all_input_ids_tensor = batch.input_ids.new_empty( - input_length + stopping_criteria.max_new_tokens - ) - # Copy from batch.input_ids to all_input_ids_tensor - all_input_ids_tensor[:input_length] = batch.input_ids[ - start_index:end_index - ] - batch.all_input_ids_tensor.append(all_input_ids_tensor) - # Initialize position_ids # In decode, we do not need this as we can just increment position ids next_position_ids[i] = batch.position_ids[end_index - 1] @@ -596,32 +630,13 @@ class FlashCausalLM(Model): # Copy batch.input_ids to prefill_token_indices if len(batch) > 1: prefill_tokens_indices[ - start_index : end_index - 1 - ] = batch.input_ids[start_index + 1 : end_index] + start_index: end_index - 1 + ] = batch.input_ids[start_index + 1: end_index] else: # Set prefill_tokens_indices to the correct slice - prefill_tokens_indices = batch.input_ids[ - start_index + 1 : end_index - ] - else: - # Decode mode - # out is of shape [batch_size, vocab_size] - logits = out[i].view(1, -1) + prefill_tokens_indices = batch.input_ids - all_input_ids_tensor = batch.all_input_ids_tensor[i] - - # Select next token - next_token_id, logprob = next_token_chooser( - all_input_ids_tensor[None, 
:input_length], logits - ) - - # Add to all_input_ids_tensor - next_token_id_squeezed = next_token_id.view(1) - all_input_ids_tensor[input_length] = next_token_id_squeezed - - # Set values - next_input_ids[i] = next_token_id_squeezed - next_token_logprobs[i] = logprob[-1, next_token_id].view(1) + batch.all_input_ids_tensor[i, input_length] = next_input_ids[i] cumulative_length += input_length @@ -651,10 +666,11 @@ class FlashCausalLM(Model): batch.input_lengths, batch.prefix_offsets, batch.read_offsets, - batch.next_token_choosers, batch.stopping_criterias, batch.all_input_ids, batch.all_input_ids_tensor, + batch.next_token_chooser.do_sample, + batch.next_token_chooser.seeds, next_token_ids, next_token_logprobs, ) @@ -665,10 +681,11 @@ class FlashCausalLM(Model): input_length, prefix_offset, read_offset, - next_token_chooser, stopping_criteria, all_input_ids, all_input_ids_tensor, + do_sample, + seed, next_token_id, next_token_logprob, ) in enumerate(iterator): @@ -700,16 +717,13 @@ class FlashCausalLM(Model): if stop: # Decode generated tokens output_text = self.decode( - all_input_ids[-stopping_criteria.current_tokens :] + all_input_ids[-stopping_criteria.current_tokens:] ) - # Get seed - if isinstance(next_token_chooser.choice, Sampling): - seed = next_token_chooser.choice.seed - else: - seed = None - generated_text = GeneratedText( - output_text, stopping_criteria.current_tokens, reason, seed + output_text, + stopping_criteria.current_tokens, + reason, + seed if do_sample else None, ) else: generated_text = None @@ -718,8 +732,8 @@ class FlashCausalLM(Model): if prefill: # Remove generated token to only have prefill and add nan for first prompt token request_prefill_logprobs = [float("nan")] + prefill_logprobs[ - start_index : end_index - 1 - ] + start_index: end_index - 1 + ] prefill_token_ids = all_input_ids[:-1] prefill_texts = self.tokenizer.batch_decode( prefill_token_ids, @@ -751,8 +765,9 @@ class FlashCausalLM(Model): batch.prefix_offsets[i] = prefix_offset batch.read_offsets[i] = read_offset batch.all_input_ids[i] = all_input_ids - batch.max_seqlen = batch.max_seqlen + 1 cumulative_length += input_length + batch.max_seqlen = batch.max_seqlen + 1 + # No need to return a batch if we know that all requests stopped return generations, batch if not stopped else None diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py index 482e0f54..9f837ced 100644 --- a/server/text_generation_server/models/flash_santacoder.py +++ b/server/text_generation_server/models/flash_santacoder.py @@ -231,6 +231,7 @@ class FlashSantacoderSharded(FlashSantacoder): device=device, rank=rank, world_size=world_size, + decode_buffer=1, ) @staticmethod diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py index bc3096c6..0a3f341b 100644 --- a/server/text_generation_server/models/galactica.py +++ b/server/text_generation_server/models/galactica.py @@ -89,6 +89,7 @@ class GalacticaCausalLMBatch(CausalLMBatch): cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, + dtype: torch.dtype, device: torch.device, ) -> "GalacticaCausalLMBatch": inputs = [] diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index 2abb87ae..68e59dc3 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -71,6 +71,7 @@ class Seq2SeqLMBatch(Batch): cls, pb: 
generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, + dtype: torch.dtype, device: torch.device, ) -> "Seq2SeqLMBatch": """Convert a text_generation_server.v1.Batch protobuf to a Seq2SeqLMBatch""" diff --git a/server/text_generation_server/models/types.py b/server/text_generation_server/models/types.py index 66a8c212..28ca8147 100644 --- a/server/text_generation_server/models/types.py +++ b/server/text_generation_server/models/types.py @@ -21,6 +21,7 @@ class Batch(ABC): cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, + dtype: torch.dtype, device: torch.device, ) -> "Batch": raise NotImplementedError diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index e47fd049..e1bd8412 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -55,7 +55,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): async def Prefill(self, request, context): batch = self.model.batch_type.from_pb( - request.batch, self.model.tokenizer, self.model.device + request.batch, self.model.tokenizer, self.model.dtype, self.model.device ) generations, next_batch = self.model.generate_token(batch) diff --git a/server/text_generation_server/utils/__init__.py b/server/text_generation_server/utils/__init__.py index 50d64518..781e59e3 100644 --- a/server/text_generation_server/utils/__init__.py +++ b/server/text_generation_server/utils/__init__.py @@ -9,13 +9,13 @@ from text_generation_server.utils.hub import ( RevisionNotFoundError, ) from text_generation_server.utils.tokens import ( - Greedy, NextTokenChooser, - Sampling, + HeterogeneousNextTokenChooser, StoppingCriteria, StopSequenceCriteria, FinishReason, ) +from text_generation_server.utils.logits_process import Sampling, Greedy __all__ = [ "convert_file", @@ -25,6 +25,7 @@ __all__ = [ "weight_hub_files", "download_weights", "EntryNotFoundError", + "HeterogeneousNextTokenChooser", "LocalEntryNotFoundError", "RevisionNotFoundError", "Greedy", diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py new file mode 100644 index 00000000..3206807f --- /dev/null +++ b/server/text_generation_server/utils/logits_process.py @@ -0,0 +1,374 @@ +import math +import torch + +from functools import lru_cache +from typing import Optional, List + +from transformers import ( + LogitsWarper, + LogitsProcessor, + TemperatureLogitsWarper, + TopKLogitsWarper, + TopPLogitsWarper, + TypicalLogitsWarper, +) + + +class Sampling: + def __init__(self, seed: int, device: str = "cpu"): + self.generator = torch.Generator(device) + self.generator.manual_seed(seed) + self.seed = seed + + def __call__(self, logits): + probs = torch.nn.functional.softmax(logits, -1) + # Avoid GPU<->CPU sync done by torch multinomial + # See: https://github.com/pytorch/pytorch/blob/925a3788ec5c06db62ca732a0e9425a26a00916f/aten/src/ATen/native/Distributions.cpp#L631-L637 + q = torch.empty_like(probs).exponential_(1, generator=self.generator) + return probs.div_(q).argmax() + + +class Greedy: + def __call__(self, logits): + return logits.argmax(dim=-1) + + +class StaticWarper: + def __init__( + self, + temperature=1.0, + top_k=None, + top_p=None, + typical_p=None, + ): + self.warpers = [] + + if temperature is not None and temperature != 1.0: + temperature = float(temperature) + self.warpers.append(TemperatureLogitsWarper(temperature)) + if top_k is not None and top_k != 0: + self.warpers.append(TopKLogitsWarper(top_k=top_k)) 
+ if top_p is not None and top_p < 1.0: + self.warpers.append(TopPLogitsWarper(top_p=top_p)) + if typical_p is not None and typical_p < 1.0: + self.warpers.append(TypicalLogitsWarper(mass=typical_p)) + + self.cuda_graph = None + self.static_scores = None + self.static_warped_scores = None + self.static_next_logprob = None + + def __call__(self, scores): + if self.cuda_graph is None: + self.static_scores = scores + self.cuda_graph = torch.cuda.CUDAGraph() + + with torch.cuda.graph(self.cuda_graph): + for warper in self.warpers: + self.static_warped_scores = warper(None, self.static_scores) + + # Compute logprobs + self.static_next_logprob = torch.log_softmax( + self.static_warped_scores, -1 + ) + + self.static_scores.copy_(scores) + self.cuda_graph.replay() + + return self.static_warped_scores, self.static_next_logprob + + +@lru_cache(10) +def static_warper( + temperature: Optional[float], + top_k: Optional[int], + top_p: Optional[float], + typical_p: Optional[float], +) -> StaticWarper: + return StaticWarper( + temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p + ) + + +class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor): + r""" + [`LogitsProcessor`] enforcing an exponential penalty on repeated sequences. + This version allows for a separate value for each sample and runs inplace when possible. + It doesn't validate inputs. + + Args: + repetition_penalty (`List[float]`): + The parameter for repetition penalty. 1.0 means no penalty. See [this + paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. + """ + + def __init__(self, penalty: List[float], dtype: torch.dtype, device: torch.device): + self.penalty = torch.tensor(penalty, dtype=dtype, device=device).unsqueeze(1) + + def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: + score = torch.gather(scores, 1, input_ids) + + # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability + score = torch.where(score < 0, score * self.penalty, score / self.penalty) + + scores.scatter_(1, input_ids, score) + return scores + + def filter(self, indices): + self.penalty = self.penalty[indices] + return self + + +class HeterogeneousTemperatureLogitsWarper: + r""" + [`LogitsWarper`] for temperature (exponential scaling output probability distribution). + This version allows for a separate value for each sample and runs inplace when possible. + It doesn't validate inputs. + + Args: + temperature (`float`): + The value used to module the logits distribution. + """ + + def __init__( + self, temperature: List[float], dtype: torch.dtype, device: torch.device + ): + self.temperature = torch.tensor( + temperature, dtype=dtype, device=device + ).unsqueeze(1) + + def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: + scores.div_(self.temperature) + return scores + + def filter(self, indices): + self.temperature = self.temperature[indices] + return self + + +class HeterogeneousTopPLogitsWarper(LogitsWarper): + """ + [`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off. + This version allows for a separate value for each sample and runs inplace when possible. + It doesn't validate inputs. + + Args: + top_p (`float`): + If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or + higher are kept for generation. 
+ filter_value (`float`, *optional*, defaults to `-float("Inf")`): + All filtered values will be set to this float value. + min_tokens_to_keep (`int`, *optional*, defaults to 1): + Minimum number of tokens that cannot be filtered. + """ + + def __init__( + self, + top_p: List[float], + dtype: torch.dtype, + device: torch.device, + filter_value: float = -math.inf, + min_tokens_to_keep: int = 1, + ): + self.top_p_opposite = 1 - torch.tensor( + top_p, dtype=dtype, device=device + ).unsqueeze(1) + self.filter_value = filter_value + self.min_tokens_to_keep = min_tokens_to_keep + + def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: + sorted_logits, sorted_indices = torch.sort(scores, descending=False) + cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) + + # Remove tokens with cumulative top_p above the threshold (token with 0 are kept) + sorted_indices_to_remove = cumulative_probs <= self.top_p_opposite + if self.min_tokens_to_keep > 1: + # Keep at least min_tokens_to_keep + sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0 + + # scatter sorted tensors to original indexing + indices_to_remove = sorted_indices_to_remove.scatter( + 1, sorted_indices, sorted_indices_to_remove + ) + scores.masked_fill_(indices_to_remove, self.filter_value) + return scores + + def filter(self, indices): + self.top_p_opposite = self.top_p_opposite[indices] + return self + + +class HeterogeneousTopKLogitsWarper(LogitsWarper): + r""" + [`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements. + This version allows for a separate value for each sample and runs inplace when possible. + It doesn't validate inputs. + + Args: + top_k (`int`): + The number of highest probability vocabulary tokens to keep for top-k-filtering. + filter_value (`float`, *optional*, defaults to `-float("Inf")`): + All filtered values will be set to this float value. + min_tokens_to_keep (`int`, *optional*, defaults to 1): + Minimum number of tokens that cannot be filtered. 
+ """ + + def __init__( + self, + top_k: List[int], + dtype: torch.dtype, + device: torch.device, + filter_value: float = -math.inf, + min_tokens_to_keep: int = 1, + ): + self.top_k = top_k + self.max_top_k = max(top_k) + # value - 1 as we will use top_k to index and python uses 0 based numbering + self.top_k_tensor = torch.tensor( + [max(x - 1, min_tokens_to_keep - 1) for x in top_k], + dtype=torch.int64, + device=device, + ).unsqueeze(1) + + # 0 is a special value that disables top_k warping for this member of the batch + disabled = [x == 0 for x in top_k] + + if any(disabled): + self.top_k_disabled_mask = torch.tensor( + disabled, dtype=torch.bool, device=device + ) + else: + self.top_k_disabled_mask = None + + self.filter_value = filter_value + + def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: + # If max_top_k is superior to the vocab, we need to clamp or the warper will fail + if scores.size(-1) < self.max_top_k: + max_top_k = scores.size(-1) + top_k = torch.clamp_max(self.top_k_tensor, max_top_k) + else: + max_top_k = self.max_top_k + top_k = self.top_k_tensor + + # Get the kth score for each member of the batch + kth_scores = torch.gather(torch.topk(scores, max_top_k)[0], 1, top_k) + + # Mask member of kth_scores that do not want to use top_k warping + if self.top_k_disabled_mask is not None: + kth_scores.masked_fill_(self.top_k_disabled_mask, self.filter_value) + + # Remove all tokens with a probability less than the last token of the top-k + indices_to_remove = scores < kth_scores + scores.masked_fill_(indices_to_remove, self.filter_value) + return scores + + def filter(self, indices): + self.top_k_tensor = self.top_k_tensor[indices] + self.top_k = [self.top_k[i] for i in indices] + self.max_top_k = max(self.top_k) + return self + + +class HeterogeneousTypicalLogitsWarper(LogitsWarper): + r""" + [`LogitsWarper`] that performs typical decoding. See [Typical Decoding for Natural Language + Generation](https://arxiv.org/abs/2202.00666) for more information. + This version allows for a separate value for each sample and runs inplace when possible. + It doesn't validate inputs. + + Args: + mass (`float`): + Value of typical_p between 0 and 1 inclusive, defaults to 0.9. + filter_value (`float`, *optional*, defaults to `-float("Inf")`): + All filtered values will be set to this float value. + min_tokens_to_keep (`int`, *optional*, defaults to 1): + Minimum number of tokens that cannot be filtered. 
+ """ + + def __init__( + self, + mass: List[float], + dtype: torch.dtype, + device: torch.device, + filter_value: float = -math.inf, + min_tokens_to_keep: int = 1, + ): + self.filter_value = filter_value + self.mass = torch.tensor(mass, dtype=dtype, device=device).unsqueeze(1) + self.min_tokens_to_keep = min_tokens_to_keep + + def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: + # calculate entropy + normalized = torch.nn.functional.log_softmax(scores, dim=-1) + p = torch.exp(normalized) + ent = -(normalized * p).nansum(-1, keepdim=True) + + # shift and sort + shifted_scores = torch.abs((-normalized) - ent) + sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False) + sorted_logits = scores.gather(-1, sorted_indices) + cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) + + # Remove tokens with cumulative mass above the threshold + last_ind = (cumulative_probs < self.mass).sum(dim=1) + last_ind[last_ind < 0] = 0 + sorted_indices_to_remove = sorted_scores > sorted_scores.gather( + 1, last_ind.view(-1, 1) + ) + if self.min_tokens_to_keep > 1: + # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) + sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0 + indices_to_remove = sorted_indices_to_remove.scatter( + 1, sorted_indices, sorted_indices_to_remove + ) + + scores = scores.masked_fill_(indices_to_remove, self.filter_value) + return scores + + def filter(self, indices): + self.mass = self.mass[indices] + return self + + +class HeterogeneousSampling: + r""" + Mixed greedy and probabilistic sampling. Compute both and pick the right one for each sample. + """ + + def __init__(self, do_sample: List[bool], seeds: List[int], device: torch.device): + self.seeds = seeds + + self.greedy_indices = [] + self.sampling_mapping = {} + for i, (sample, seed) in enumerate(zip(do_sample, seeds)): + if sample: + self.sampling_mapping[i] = Sampling(seed, device) + else: + self.greedy_indices.append(i) + + self.greedy = Greedy() + + def __call__(self, logits): + out = torch.empty(logits.shape[0], dtype=torch.int64, device=logits.device) + if self.greedy_indices: + out[self.greedy_indices] = torch.argmax(logits[self.greedy_indices], -1) + + for i, sampling in self.sampling_mapping.items(): + out[i] = sampling(logits[i]) + return out + + def filter(self, indices): + new_greedy_indices = [] + new_sampling_mapping = {} + for i, idx in enumerate(indices): + if idx in self.sampling_mapping: + new_sampling_mapping[i] = self.sampling_mapping[idx] + else: + new_greedy_indices.append(i) + + self.greedy_indices = new_greedy_indices + self.sampling_mapping = new_sampling_mapping + return self + + diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py index e9fb96b0..e3b28b90 100644 --- a/server/text_generation_server/utils/tokens.py +++ b/server/text_generation_server/utils/tokens.py @@ -1,110 +1,33 @@ import re import torch -from functools import lru_cache from transformers import ( - TemperatureLogitsWarper, - TopKLogitsWarper, - TopPLogitsWarper, - TypicalLogitsWarper, RepetitionPenaltyLogitsProcessor, - PreTrainedTokenizerBase, + PreTrainedTokenizerBase, LogitsProcessorList, ) from typing import List, Tuple, Optional from text_generation_server.pb import generate_pb2 from text_generation_server.pb.generate_pb2 import FinishReason from text_generation_server.utils.watermark import WatermarkLogitsProcessor - - -class Sampling: - def __init__(self, seed: 
int, device: str = "cpu"): - self.generator = torch.Generator(device) - self.generator.manual_seed(seed) - self.seed = seed - - def __call__(self, logits): - probs = torch.nn.functional.softmax(logits, -1) - # Avoid GPU<->CPU sync done by torch multinomial - # See: https://github.com/pytorch/pytorch/blob/925a3788ec5c06db62ca732a0e9425a26a00916f/aten/src/ATen/native/Distributions.cpp#L631-L637 - q = torch.empty_like(probs).exponential_(1, generator=self.generator) - return probs.div_(q).argmax() - - -class Greedy: - def __call__(self, logits): - return logits.argmax() - - -class StaticWarper: - def __init__( - self, - temperature=1.0, - top_k=None, - top_p=None, - typical_p=None, - ): - self.warpers = [] - - if temperature is not None and temperature != 1.0: - temperature = float(temperature) - self.warpers.append(TemperatureLogitsWarper(temperature)) - if top_k is not None and top_k != 0: - self.warpers.append(TopKLogitsWarper(top_k=top_k)) - if top_p is not None and top_p < 1.0: - self.warpers.append(TopPLogitsWarper(top_p=top_p)) - if typical_p is not None and typical_p < 1.0: - self.warpers.append(TypicalLogitsWarper(mass=typical_p)) - - self.cuda_graph = None - self.static_scores = None - self.static_warped_scores = None - self.static_next_logprob = None - - def __call__(self, scores): - if self.cuda_graph is None: - self.static_scores = scores - self.cuda_graph = torch.cuda.CUDAGraph() - - with torch.cuda.graph(self.cuda_graph): - for warper in self.warpers: - self.static_warped_scores = warper(None, self.static_scores) - - # Compute logprobs - self.static_next_logprob = torch.log_softmax( - self.static_warped_scores, -1 - ) - - self.static_scores.copy_(scores) - self.cuda_graph.replay() - - return self.static_warped_scores, self.static_next_logprob - - -@lru_cache(10) -def static_warper( - temperature: Optional[float], - top_k: Optional[int], - top_p: Optional[float], - typical_p: Optional[float], -) -> StaticWarper: - return StaticWarper( - temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p - ) +from text_generation_server.utils import Sampling, Greedy +from text_generation_server.utils.logits_process import static_warper, HeterogeneousRepetitionPenaltyLogitsProcessor, \ + HeterogeneousTemperatureLogitsWarper, HeterogeneousTopKLogitsWarper, HeterogeneousTopPLogitsWarper, \ + HeterogeneousTypicalLogitsWarper, HeterogeneousSampling class NextTokenChooser: def __init__( - self, - watermark=False, - temperature=1.0, - repetition_penalty=1.0, - top_k=None, - top_p=None, - typical_p=None, - do_sample=False, - seed=0, - device="cpu", + self, + watermark=False, + temperature=1.0, + repetition_penalty=1.0, + top_k=None, + top_p=None, + typical_p=None, + do_sample=False, + seed=0, + device="cpu", ): self.watermark_processor = ( WatermarkLogitsProcessor(device=device) if watermark else None @@ -116,10 +39,10 @@ class NextTokenChooser: ) has_warpers = ( - (temperature is not None and temperature != 1.0) - or (top_k is not None and top_k != 0) - or (top_p is not None and top_p < 1.0) - or (typical_p is not None and typical_p < 1.0) + (temperature is not None and temperature != 1.0) + or (top_k is not None and top_k != 0) + or (top_p is not None and top_p < 1.0) + or (typical_p is not None and typical_p < 1.0) ) if has_warpers: self.static_warper = static_warper( @@ -148,9 +71,9 @@ class NextTokenChooser: @classmethod def from_pb( - cls, - pb: generate_pb2.NextTokenChooserParameters, - device: torch.device, + cls, + pb: generate_pb2.NextTokenChooserParameters, + device: 
torch.device, ) -> "NextTokenChooser": return NextTokenChooser( watermark=pb.watermark, @@ -178,11 +101,11 @@ class StopSequenceCriteria: class StoppingCriteria: def __init__( - self, - eos_token_id: int, - stop_sequence_criterias: List[StopSequenceCriteria], - max_new_tokens: int = 20, - ignore_eos_token: bool = False, + self, + eos_token_id: int, + stop_sequence_criterias: List[StopSequenceCriteria], + max_new_tokens: int = 20, + ignore_eos_token: bool = False, ): self.eos_token_id = eos_token_id self.stop_sequence_criterias = stop_sequence_criterias @@ -208,9 +131,9 @@ class StoppingCriteria: @classmethod def from_pb( - cls, - pb: generate_pb2.StoppingCriteriaParameters, - tokenizer: PreTrainedTokenizerBase, + cls, + pb: generate_pb2.StoppingCriteriaParameters, + tokenizer: PreTrainedTokenizerBase, ) -> "StoppingCriteria": stop_sequence_criterias = [ StopSequenceCriteria(sequence) for sequence in pb.stop_sequences @@ -221,3 +144,99 @@ class StoppingCriteria: pb.max_new_tokens, pb.ignore_eos_token, ) + + +class HeterogeneousNextTokenChooser: + def __init__( + self, + dtype: torch.dtype, + device: torch.device, + watermark: List[bool], + temperature: List[float], + repetition_penalty: List[float], + top_k: List[int], + top_p: List[float], + typical_p: List[float], + do_sample: List[bool], + seeds: List[int], + ): + warpers = LogitsProcessorList() + + if any(watermark): + raise NotImplementedError("Watermarking not implemented") + + if any([x != 1.0 for x in repetition_penalty]): + warpers.append( + HeterogeneousRepetitionPenaltyLogitsProcessor( + repetition_penalty, dtype, device + ) + ) + + if any([x != 1.0 for x in temperature]): + do_sample = [ + sample or x != 1.0 for x, sample in zip(temperature, do_sample) + ] + warpers.append( + HeterogeneousTemperatureLogitsWarper(temperature, dtype, device) + ) + + if any([x != 0 for x in top_k]): + do_sample = [sample or x != 0 for x, sample in zip(top_k, do_sample)] + warpers.append(HeterogeneousTopKLogitsWarper(top_k, dtype, device)) + + if any([x < 1.0 for x in top_p]): + do_sample = [sample or x < 1.0 for x, sample in zip(top_p, do_sample)] + warpers.append(HeterogeneousTopPLogitsWarper(top_p, dtype, device)) + + if any([x < 1.0 for x in typical_p]): + do_sample = [sample or x < 1.0 for x, sample in zip(typical_p, do_sample)] + warpers.append(HeterogeneousTypicalLogitsWarper(typical_p, dtype, device)) + + self.warpers = warpers + + num_do_sample = sum(do_sample) + if num_do_sample == 0: + self.choice = Greedy() + else: + self.choice = HeterogeneousSampling(do_sample, seeds, device) + + self.seeds = seeds + self.do_sample = do_sample + + def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor): + last_token_scores = self.warpers(input_ids, scores) + next_ids = self.choice(last_token_scores) + next_logprobs = torch.gather( + torch.log_softmax(last_token_scores, -1), 1, next_ids.view(-1, 1) + ).view(-1) + + return next_ids, next_logprobs + + def filter(self, indices): + for warper in self.warpers: + warper.filter(indices) + if isinstance(self.choice, HeterogeneousSampling): + self.choice.filter(indices) + self.seeds = [self.seeds[i] for i in indices] + self.do_sample = [self.do_sample[i] for i in indices] + return self + + @classmethod + def from_pb( + cls, + pb: List[generate_pb2.NextTokenChooserParameters], + dtype: torch.dtype, + device: torch.device, + ) -> "HeterogeneousNextTokenChooser": + return HeterogeneousNextTokenChooser( + watermark=[pb_.watermark for pb_ in pb], + temperature=[pb_.temperature for pb_ in pb], + 
repetition_penalty=[pb_.repetition_penalty for pb_ in pb], + top_k=[pb_.top_k for pb_ in pb], + top_p=[pb_.top_p for pb_ in pb], + typical_p=[pb_.typical_p for pb_ in pb], + do_sample=[pb_.do_sample for pb_ in pb], + seeds=[pb_.seed for pb_ in pb], + device=device, + dtype=dtype, + )
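
For context on how the vectorized chooser introduced above is driven, here is a minimal usage sketch. It is not part of the diff: the parameter values, the toy vocabulary size, and the padded-input shape are made up for illustration, but the parameter field names are the ones read by HeterogeneousNextTokenChooser.from_pb. One chooser is built per batch from the per-request NextTokenChooserParameters, applied to the padded all_input_ids_tensor plus the latest logits, and filtered in lock-step with the batch when requests finish, mirroring generate_token and FlashCausalLMBatch.filter in the patch.

# Sketch only (not part of the diff): hypothetical usage of the vectorized chooser.
import torch

from text_generation_server.pb import generate_pb2
from text_generation_server.utils import HeterogeneousNextTokenChooser

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float16 if device.type == "cuda" else torch.float32

# Two requests with different sampling settings, batched together.
parameters = [
    generate_pb2.NextTokenChooserParameters(
        temperature=0.7, repetition_penalty=1.1, top_k=50, top_p=0.9,
        typical_p=1.0, do_sample=True, seed=42,
    ),
    generate_pb2.NextTokenChooserParameters(
        temperature=1.0, repetition_penalty=1.0, top_k=0, top_p=1.0,
        typical_p=1.0, do_sample=False, seed=0,
    ),
]
chooser = HeterogeneousNextTokenChooser.from_pb(parameters, dtype=dtype, device=device)

# Padded token ids [batch, max_length] and last-token logits [batch, vocab]
# (toy vocabulary size; real values come from the model forward pass).
all_input_ids = torch.zeros((2, 8), dtype=torch.int64, device=device)
logits = torch.randn(2, 32000, dtype=dtype, device=device)

next_ids, next_logprobs = chooser(all_input_ids, logits)

# When request 1 finishes, the warpers, seeds and do_sample flags are
# filtered together with the batch tensors.
chooser = chooser.filter([0])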
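
During prefill the model output `out` is packed as [sum(input_lengths), vocab_size], which is why the patch gathers each request's last-prompt-token logits with out[batch.cu_seqlens[1:] - 1] (or out[-1:] for a single request). A toy check of that indexing, assuming illustrative prompt lengths of 3 and 2 and a 5-token vocabulary; the .long() cast is only for portability across torch versions.

# Toy check (not part of the diff) of the packed prefill indexing.
import torch

input_lengths = [3, 2]
vocab_size = 5
cu_seqlens = torch.tensor([0, 3, 5], dtype=torch.int32)  # cumulative lengths

# Packed prefill output: one row per prompt token across the whole batch.
out = torch.arange(sum(input_lengths) * vocab_size, dtype=torch.float32).view(-1, vocab_size)

# Same gather as `out[batch.cu_seqlens[1:] - 1]` in the patch.
next_token_logits = out[(cu_seqlens[1:] - 1).long()]

assert next_token_logits.shape == (len(input_lengths), vocab_size)
assert torch.equal(next_token_logits[0], out[2])  # last token of prompt 0
assert torch.equal(next_token_logits[1], out[4])  # last token of prompt 1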
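
The Sampling helper moved into logits_process.py avoids torch.multinomial (and its GPU<->CPU sync) by dividing the probabilities by i.i.d. Exponential(1) noise and taking the argmax, an exponential-race variant of the Gumbel-max trick. A quick self-contained sanity check of that equivalence, using an arbitrary 4-way distribution and the default generator rather than the class's seeded one:

# Sanity check (not part of the diff): probs / Exp(1) noise, then argmax,
# matches torch.multinomial in distribution.
import torch

torch.manual_seed(0)
probs = torch.tensor([0.1, 0.2, 0.3, 0.4])
n = 200_000

# Exponential-race draw (the same trick Sampling.__call__ applies per row).
q = torch.empty(n, probs.numel()).exponential_(1)
race_samples = (probs / q).argmax(dim=-1)

multinomial_samples = torch.multinomial(probs, n, replacement=True)

print(torch.bincount(race_samples, minlength=4) / n)         # ~ tensor([0.1, 0.2, 0.3, 0.4])
print(torch.bincount(multinomial_samples, minlength=4) / n)  # ~ tensor([0.1, 0.2, 0.3, 0.4])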