From e2727387aa565b9e5f0220d715029f3bb8bc1608 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Tue, 9 May 2023 16:30:19 +0200 Subject: [PATCH] add cuda graphs to token warping --- server/text_generation_server/utils/tokens.py | 142 ++++++++++++------ 1 file changed, 94 insertions(+), 48 deletions(-) diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py index d5a77170..38941067 100644 --- a/server/text_generation_server/utils/tokens.py +++ b/server/text_generation_server/utils/tokens.py @@ -1,8 +1,8 @@ import re import torch +from functools import lru_cache from transformers import ( - LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper, @@ -34,62 +34,108 @@ class Greedy: return logits.argmax() -class NextTokenChooser: +class StaticWarper: def __init__( - self, - watermark=False, - temperature=1.0, - repetition_penalty=1.0, - top_k=None, - top_p=None, - typical_p=None, - do_sample=False, - seed=0, - device="cpu", + self, + temperature=1.0, + top_k=None, + top_p=None, + typical_p=None, ): - warpers = LogitsProcessorList() - # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files - # all samplers can be found in `generation_utils_samplers.py` - sampling = do_sample + self.warpers = [] - if watermark: - warpers.append(WatermarkLogitsProcessor(device=device)) - if repetition_penalty is not None and repetition_penalty != 1.0: - warpers.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)) if temperature is not None and temperature != 1.0: temperature = float(temperature) - warpers.append(TemperatureLogitsWarper(temperature)) - sampling = True + self.warpers.append(TemperatureLogitsWarper(temperature)) if top_k is not None and top_k != 0: - warpers.append(TopKLogitsWarper(top_k=top_k)) - sampling = True + self.warpers.append(TopKLogitsWarper(top_k=top_k)) if top_p is not None and top_p < 1.0: - warpers.append(TopPLogitsWarper(top_p=top_p)) - sampling = True + self.warpers.append(TopPLogitsWarper(top_p=top_p)) if typical_p is not None and typical_p < 1.0: - warpers.append(TypicalLogitsWarper(mass=typical_p)) - sampling = True + self.warpers.append(TypicalLogitsWarper(mass=typical_p)) - self.warpers = warpers - self.choice = Sampling(seed, device) if sampling else Greedy() + self.cuda_graph = None + self.static_scores = None + self.static_warped_scores = None + self.static_next_logprob = None + + def __call__(self, scores): + if self.cuda_graph is None: + self.static_scores = scores + self.cuda_graph = torch.cuda.CUDAGraph() + + capture_stream = torch.cuda.stream(torch.cuda.Stream()) + capture_stream.__enter__() + self.cuda_graph.capture_begin() + + for warper in self.warpers: + self.static_warped_scores = warper(None, self.static_scores) + + # Compute logprobs + self.static_next_logprob = torch.log_softmax(self.static_warped_scores, -1) + + self.cuda_graph.capture_end() + capture_stream.__exit__(None, None, None) + + self.static_scores.copy_(scores) + self.cuda_graph.replay() + + return self.static_warped_scores, self.static_next_logprob + + +@lru_cache(10) +def static_warper(temperature: Optional[float], top_k: Optional[int], top_p: Optional[float], + typical_p: Optional[float]) -> StaticWarper: + return StaticWarper(temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p) + + +class NextTokenChooser: + def __init__( + self, + watermark=False, + temperature=1.0, + repetition_penalty=1.0, + top_k=None, + top_p=None, + typical_p=None, + do_sample=False, + seed=0, + device="cpu", + ): + self.watermark_warper = WatermarkLogitsProcessor(device=device) if watermark else None + self.repetition_warper = RepetitionPenaltyLogitsProcessor( + penalty=repetition_penalty) if repetition_penalty else None + + sampling = do_sample or (temperature is not None and temperature != 1.0) or ( + top_k is not None and top_k != 0) or (top_p is not None and top_p < 1.0) or ( + typical_p is not None and typical_p < 1.0) + if sampling: + self.choice = Sampling(seed, device) + self.static_warper = static_warper(temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p) + else: + self.choice = Greedy() + self.static_warper = None def __call__(self, input_ids, scores): - # Warp logits - scores = self.warpers(input_ids, scores) + if self.watermark_warper: + scores = self.watermark_warper(input_ids, scores) + if self.repetition_warper: + scores = self.repetition_warper(input_ids, scores) - # Compute logprobs - logprobs = torch.log_softmax(scores, -1) + if self.static_warper is None: + next_logprob = torch.log_softmax(scores, -1) + else: + scores, next_logprob = self.static_warper(scores) - # Choose tokens - next_id = self.choice(scores[-1]) + next_id = self.choice(scores[-1]).view(1, 1) - return next_id.view(1, 1), logprobs + return next_id, next_logprob @classmethod def from_pb( - cls, - pb: generate_pb2.NextTokenChooserParameters, - device: torch.device, + cls, + pb: generate_pb2.NextTokenChooserParameters, + device: torch.device, ) -> "NextTokenChooser": return NextTokenChooser( watermark=pb.watermark, @@ -117,11 +163,11 @@ class StopSequenceCriteria: class StoppingCriteria: def __init__( - self, - eos_token_id: int, - stop_sequence_criterias: List[StopSequenceCriteria], - max_new_tokens: int = 20, - ignore_eos_token: bool = False, + self, + eos_token_id: int, + stop_sequence_criterias: List[StopSequenceCriteria], + max_new_tokens: int = 20, + ignore_eos_token: bool = False, ): self.eos_token_id = eos_token_id self.stop_sequence_criterias = stop_sequence_criterias @@ -147,9 +193,9 @@ class StoppingCriteria: @classmethod def from_pb( - cls, - pb: generate_pb2.StoppingCriteriaParameters, - tokenizer: PreTrainedTokenizerBase, + cls, + pb: generate_pb2.StoppingCriteriaParameters, + tokenizer: PreTrainedTokenizerBase, ) -> "StoppingCriteria": stop_sequence_criterias = [ StopSequenceCriteria(sequence) for sequence in pb.stop_sequences