diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index fb98386f..6afbc7e4 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -18,10 +18,7 @@ from text_generation_server.models.types import ( GeneratedText, ) from text_generation_server.pb import generate_pb2 -from text_generation_server.utils import ( - StoppingCriteria, - HeterogeneousNextTokenChooser -) +from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser tracer = trace.get_tracer(__name__) @@ -71,11 +68,11 @@ class FlashCausalLMBatch(Batch): @classmethod def from_pb( - cls, - pb: generate_pb2.Batch, - tokenizer: PreTrainedTokenizerBase, - dtype: torch.dtype, - device: torch.device, + cls, + pb: generate_pb2.Batch, + tokenizer: PreTrainedTokenizerBase, + dtype: torch.dtype, + device: torch.device, ) -> "FlashCausalLMBatch": position_ids = [] cu_seqlens = [0] @@ -228,7 +225,7 @@ class FlashCausalLMBatch(Batch): # Slice from past past_key_values.append( - self.past_key_values[:, self.cu_seqlens[idx]: self.cu_seqlens[idx + 1]] + self.past_key_values[:, self.cu_seqlens[idx] : self.cu_seqlens[idx + 1]] ) all_input_ids.append(self.all_input_ids[idx]) @@ -242,7 +239,7 @@ class FlashCausalLMBatch(Batch): cumulative_length += request_input_length max_tokens += request_input_length + ( - stopping_criteria.max_new_tokens - stopping_criteria.current_tokens + stopping_criteria.max_new_tokens - stopping_criteria.current_tokens ) if single_request: @@ -395,7 +392,7 @@ class FlashCausalLMBatch(Batch): end_index = cumulative_batch_size + len(batch) all_input_ids_tensor[ - start_index:end_index, : batch.all_input_ids_tensor.shape[1] + start_index:end_index, : batch.all_input_ids_tensor.shape[1] ] = batch.all_input_ids_tensor cumulative_batch_size += len(batch) @@ -481,14 +478,14 @@ class FlashCausalLM(Model): ) def forward( - self, - input_ids: torch.Tensor, - position_ids: torch.Tensor, - cu_seqlens: torch.Tensor, - cu_seqlens_q: Optional[torch.Tensor], - max_s: int, - past_key_values: Optional = None, - pre_allocate_past_size: Optional[int] = None, + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlens: torch.Tensor, + cu_seqlens_q: Optional[torch.Tensor], + max_s: int, + past_key_values: Optional = None, + pre_allocate_past_size: Optional[int] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: # Model Forward return self.model.forward( @@ -503,7 +500,7 @@ class FlashCausalLM(Model): @tracer.start_as_current_span("generate_token") def generate_token( - self, batch: FlashCausalLMBatch + self, batch: FlashCausalLMBatch ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]: prefill = batch.past_key_values is None single_request = len(batch) == 1 @@ -512,7 +509,7 @@ class FlashCausalLM(Model): # Ask to pre-allocate kv to its max size # == number of tokens + max_new_tokens pre_allocate_past_size = ( - batch.input_lengths[0] + batch.stopping_criterias[0].max_new_tokens + batch.input_lengths[0] + batch.stopping_criterias[0].max_new_tokens ) else: pre_allocate_past_size = None @@ -613,9 +610,9 @@ class FlashCausalLM(Model): # For each member of the batch for i, ( - input_length, - stopping_criteria, - all_input_ids, + input_length, + stopping_criteria, + all_input_ids, ) in enumerate(iterator): # Indexing metadata start_index = cumulative_length @@ -630,8 +627,8 @@ class FlashCausalLM(Model): # Copy batch.input_ids to prefill_token_indices if 
len(batch) > 1: prefill_tokens_indices[ - start_index: end_index - 1 - ] = batch.input_ids[start_index + 1: end_index] + start_index : end_index - 1 + ] = batch.input_ids[start_index + 1 : end_index] else: # Set prefill_tokens_indices to the correct slice prefill_tokens_indices = batch.input_ids @@ -717,7 +714,7 @@ class FlashCausalLM(Model): if stop: # Decode generated tokens output_text = self.decode( - all_input_ids[-stopping_criteria.current_tokens:] + all_input_ids[-stopping_criteria.current_tokens :] ) generated_text = GeneratedText( output_text, @@ -732,8 +729,8 @@ class FlashCausalLM(Model): if prefill: # Remove generated token to only have prefill and add nan for first prompt token request_prefill_logprobs = [float("nan")] + prefill_logprobs[ - start_index: end_index - 1 - ] + start_index : end_index - 1 + ] prefill_token_ids = all_input_ids[:-1] prefill_texts = self.tokenizer.batch_decode( prefill_token_ids, diff --git a/server/text_generation_server/utils/__init__.py b/server/text_generation_server/utils/__init__.py index 781e59e3..6a351d66 100644 --- a/server/text_generation_server/utils/__init__.py +++ b/server/text_generation_server/utils/__init__.py @@ -14,8 +14,9 @@ from text_generation_server.utils.tokens import ( StoppingCriteria, StopSequenceCriteria, FinishReason, + Sampling, + Greedy, ) -from text_generation_server.utils.logits_process import Sampling, Greedy __all__ = [ "convert_file", diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py index 3206807f..47f3a33c 100644 --- a/server/text_generation_server/utils/logits_process.py +++ b/server/text_generation_server/utils/logits_process.py @@ -14,25 +14,6 @@ from transformers import ( ) -class Sampling: - def __init__(self, seed: int, device: str = "cpu"): - self.generator = torch.Generator(device) - self.generator.manual_seed(seed) - self.seed = seed - - def __call__(self, logits): - probs = torch.nn.functional.softmax(logits, -1) - # Avoid GPU<->CPU sync done by torch multinomial - # See: https://github.com/pytorch/pytorch/blob/925a3788ec5c06db62ca732a0e9425a26a00916f/aten/src/ATen/native/Distributions.cpp#L631-L637 - q = torch.empty_like(probs).exponential_(1, generator=self.generator) - return probs.div_(q).argmax() - - -class Greedy: - def __call__(self, logits): - return logits.argmax(dim=-1) - - class StaticWarper: def __init__( self, @@ -329,46 +310,3 @@ class HeterogeneousTypicalLogitsWarper(LogitsWarper): def filter(self, indices): self.mass = self.mass[indices] return self - - -class HeterogeneousSampling: - r""" - Mixed greedy and probabilistic sampling. Compute both and pick the right one for each sample. 
- """ - - def __init__(self, do_sample: List[bool], seeds: List[int], device: torch.device): - self.seeds = seeds - - self.greedy_indices = [] - self.sampling_mapping = {} - for i, (sample, seed) in enumerate(zip(do_sample, seeds)): - if sample: - self.sampling_mapping[i] = Sampling(seed, device) - else: - self.greedy_indices.append(i) - - self.greedy = Greedy() - - def __call__(self, logits): - out = torch.empty(logits.shape[0], dtype=torch.int64, device=logits.device) - if self.greedy_indices: - out[self.greedy_indices] = torch.argmax(logits[self.greedy_indices], -1) - - for i, sampling in self.sampling_mapping.items(): - out[i] = sampling(logits[i]) - return out - - def filter(self, indices): - new_greedy_indices = [] - new_sampling_mapping = {} - for i, idx in enumerate(indices): - if idx in self.sampling_mapping: - new_sampling_mapping[i] = self.sampling_mapping[idx] - else: - new_greedy_indices.append(i) - - self.greedy_indices = new_greedy_indices - self.sampling_mapping = new_sampling_mapping - return self - - diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py index e3b28b90..c4e45b72 100644 --- a/server/text_generation_server/utils/tokens.py +++ b/server/text_generation_server/utils/tokens.py @@ -3,31 +3,36 @@ import torch from transformers import ( RepetitionPenaltyLogitsProcessor, - PreTrainedTokenizerBase, LogitsProcessorList, + PreTrainedTokenizerBase, + LogitsProcessorList, ) from typing import List, Tuple, Optional from text_generation_server.pb import generate_pb2 from text_generation_server.pb.generate_pb2 import FinishReason from text_generation_server.utils.watermark import WatermarkLogitsProcessor -from text_generation_server.utils import Sampling, Greedy -from text_generation_server.utils.logits_process import static_warper, HeterogeneousRepetitionPenaltyLogitsProcessor, \ - HeterogeneousTemperatureLogitsWarper, HeterogeneousTopKLogitsWarper, HeterogeneousTopPLogitsWarper, \ - HeterogeneousTypicalLogitsWarper, HeterogeneousSampling +from text_generation_server.utils.logits_process import ( + static_warper, + HeterogeneousRepetitionPenaltyLogitsProcessor, + HeterogeneousTemperatureLogitsWarper, + HeterogeneousTopKLogitsWarper, + HeterogeneousTopPLogitsWarper, + HeterogeneousTypicalLogitsWarper, +) class NextTokenChooser: def __init__( - self, - watermark=False, - temperature=1.0, - repetition_penalty=1.0, - top_k=None, - top_p=None, - typical_p=None, - do_sample=False, - seed=0, - device="cpu", + self, + watermark=False, + temperature=1.0, + repetition_penalty=1.0, + top_k=None, + top_p=None, + typical_p=None, + do_sample=False, + seed=0, + device="cpu", ): self.watermark_processor = ( WatermarkLogitsProcessor(device=device) if watermark else None @@ -39,10 +44,10 @@ class NextTokenChooser: ) has_warpers = ( - (temperature is not None and temperature != 1.0) - or (top_k is not None and top_k != 0) - or (top_p is not None and top_p < 1.0) - or (typical_p is not None and typical_p < 1.0) + (temperature is not None and temperature != 1.0) + or (top_k is not None and top_k != 0) + or (top_p is not None and top_p < 1.0) + or (typical_p is not None and typical_p < 1.0) ) if has_warpers: self.static_warper = static_warper( @@ -71,9 +76,9 @@ class NextTokenChooser: @classmethod def from_pb( - cls, - pb: generate_pb2.NextTokenChooserParameters, - device: torch.device, + cls, + pb: generate_pb2.NextTokenChooserParameters, + device: torch.device, ) -> "NextTokenChooser": return NextTokenChooser( watermark=pb.watermark, @@ 
-101,11 +106,11 @@ class StopSequenceCriteria: class StoppingCriteria: def __init__( - self, - eos_token_id: int, - stop_sequence_criterias: List[StopSequenceCriteria], - max_new_tokens: int = 20, - ignore_eos_token: bool = False, + self, + eos_token_id: int, + stop_sequence_criterias: List[StopSequenceCriteria], + max_new_tokens: int = 20, + ignore_eos_token: bool = False, ): self.eos_token_id = eos_token_id self.stop_sequence_criterias = stop_sequence_criterias @@ -131,9 +136,9 @@ class StoppingCriteria: @classmethod def from_pb( - cls, - pb: generate_pb2.StoppingCriteriaParameters, - tokenizer: PreTrainedTokenizerBase, + cls, + pb: generate_pb2.StoppingCriteriaParameters, + tokenizer: PreTrainedTokenizerBase, ) -> "StoppingCriteria": stop_sequence_criterias = [ StopSequenceCriteria(sequence) for sequence in pb.stop_sequences @@ -148,17 +153,17 @@ class StoppingCriteria: class HeterogeneousNextTokenChooser: def __init__( - self, - dtype: torch.dtype, - device: torch.device, - watermark: List[bool], - temperature: List[float], - repetition_penalty: List[float], - top_k: List[int], - top_p: List[float], - typical_p: List[float], - do_sample: List[bool], - seeds: List[int], + self, + dtype: torch.dtype, + device: torch.device, + watermark: List[bool], + temperature: List[float], + repetition_penalty: List[float], + top_k: List[int], + top_p: List[float], + typical_p: List[float], + do_sample: List[bool], + seeds: List[int], ): warpers = LogitsProcessorList() @@ -223,10 +228,10 @@ class HeterogeneousNextTokenChooser: @classmethod def from_pb( - cls, - pb: List[generate_pb2.NextTokenChooserParameters], - dtype: torch.dtype, - device: torch.device, + cls, + pb: List[generate_pb2.NextTokenChooserParameters], + dtype: torch.dtype, + device: torch.device, ) -> "HeterogeneousNextTokenChooser": return HeterogeneousNextTokenChooser( watermark=[pb_.watermark for pb_ in pb], @@ -240,3 +245,63 @@ class HeterogeneousNextTokenChooser: device=device, dtype=dtype, ) + + +class Sampling: + def __init__(self, seed: int, device: str = "cpu"): + self.generator = torch.Generator(device) + self.generator.manual_seed(seed) + self.seed = seed + + def __call__(self, logits): + probs = torch.nn.functional.softmax(logits, -1) + # Avoid GPU<->CPU sync done by torch multinomial + # See: https://github.com/pytorch/pytorch/blob/925a3788ec5c06db62ca732a0e9425a26a00916f/aten/src/ATen/native/Distributions.cpp#L631-L637 + q = torch.empty_like(probs).exponential_(1, generator=self.generator) + return probs.div_(q).argmax() + + +class Greedy: + def __call__(self, logits): + return logits.argmax(dim=-1) + + +class HeterogeneousSampling: + r""" + Mixed greedy and probabilistic sampling. Compute both and pick the right one for each sample. 
+ """ + + def __init__(self, do_sample: List[bool], seeds: List[int], device: torch.device): + self.seeds = seeds + + self.greedy_indices = [] + self.sampling_mapping = {} + for i, (sample, seed) in enumerate(zip(do_sample, seeds)): + if sample: + self.sampling_mapping[i] = Sampling(seed, device) + else: + self.greedy_indices.append(i) + + self.greedy = Greedy() + + def __call__(self, logits): + out = torch.empty(logits.shape[0], dtype=torch.int64, device=logits.device) + if self.greedy_indices: + out[self.greedy_indices] = torch.argmax(logits[self.greedy_indices], -1) + + for i, sampling in self.sampling_mapping.items(): + out[i] = sampling(logits[i]) + return out + + def filter(self, indices): + new_greedy_indices = [] + new_sampling_mapping = {} + for i, idx in enumerate(indices): + if idx in self.sampling_mapping: + new_sampling_mapping[i] = self.sampling_mapping[idx] + else: + new_greedy_indices.append(i) + + self.greedy_indices = new_greedy_indices + self.sampling_mapping = new_sampling_mapping + return self
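
Note on the sampling code moved in the last hunk: Sampling.__call__ relies on the identity that argmax_i of p_i / E_i, with E_i drawn i.i.d. from Exponential(1), is distributed exactly like a categorical draw from p (a variant of the Gumbel-max trick), which is what lets it avoid the GPU<->CPU synchronisation performed by torch.multinomial. Below is a quick, self-contained check of that equivalence; it is not part of the diff, and the toy distribution and sample count are chosen arbitrarily.

import torch

torch.manual_seed(0)
probs = torch.tensor([0.1, 0.2, 0.3, 0.4])
n = 200_000

# Same trick as Sampling.__call__: divide the probabilities by Exponential(1)
# noise and take the argmax; index i wins with probability probs[i].
q = torch.empty(n, probs.numel()).exponential_(1)
samples = (probs / q).argmax(dim=-1)

freq = torch.bincount(samples, minlength=probs.numel()).float() / n
print(freq)  # ~tensor([0.1, 0.2, 0.3, 0.4]) up to Monte Carlo noise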
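
The same hunks move Sampling, Greedy and HeterogeneousSampling from logits_process.py into tokens.py and re-export them from utils/__init__.py, so tokens.py no longer imports them back through the utils package __init__ (the removed "from text_generation_server.utils import Sampling, Greedy" line). The sketch below is a minimal, illustrative use of HeterogeneousSampling assuming the post-diff layout; the batch composition, seeds and vocabulary size are made up for the example.

import torch
from text_generation_server.utils.tokens import HeterogeneousSampling

choose = HeterogeneousSampling(
    do_sample=[False, True, False],   # request 1 samples, requests 0 and 2 are greedy
    seeds=[0, 42, 0],
    device=torch.device("cpu"),
)
logits = torch.randn(3, 32000)        # one row of vocabulary logits per request
next_ids = choose(logits)             # int64 tensor of shape (3,)

# When request 0 leaves the batch, filter() remaps the survivors:
# request 1 keeps its per-request generator (now row 0), request 2 stays greedy.
choose = choose.filter([1, 2])
next_ids = choose(logits[[1, 2]])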