diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py index 38941067..045f7100 100644 --- a/server/text_generation_server/utils/tokens.py +++ b/server/text_generation_server/utils/tokens.py @@ -36,11 +36,11 @@ class Greedy: class StaticWarper: def __init__( - self, - temperature=1.0, - top_k=None, - top_p=None, - typical_p=None, + self, + temperature=1.0, + top_k=None, + top_p=None, + typical_p=None, ): self.warpers = [] @@ -64,18 +64,14 @@ class StaticWarper: self.static_scores = scores self.cuda_graph = torch.cuda.CUDAGraph() - capture_stream = torch.cuda.stream(torch.cuda.Stream()) - capture_stream.__enter__() - self.cuda_graph.capture_begin() + with torch.cuda.graph(self.cuda_graph): + for warper in self.warpers: + self.static_warped_scores = warper(None, self.static_scores) - for warper in self.warpers: - self.static_warped_scores = warper(None, self.static_scores) - - # Compute logprobs - self.static_next_logprob = torch.log_softmax(self.static_warped_scores, -1) - - self.cuda_graph.capture_end() - capture_stream.__exit__(None, None, None) + # Compute logprobs + self.static_next_logprob = torch.log_softmax( + self.static_warped_scores, -1 + ) self.static_scores.copy_(scores) self.cuda_graph.replay() @@ -84,34 +80,51 @@ class StaticWarper: @lru_cache(10) -def static_warper(temperature: Optional[float], top_k: Optional[int], top_p: Optional[float], - typical_p: Optional[float]) -> StaticWarper: - return StaticWarper(temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p) +def static_warper( + temperature: Optional[float], + top_k: Optional[int], + top_p: Optional[float], + typical_p: Optional[float], +) -> StaticWarper: + return StaticWarper( + temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p + ) class NextTokenChooser: def __init__( - self, - watermark=False, - temperature=1.0, - repetition_penalty=1.0, - top_k=None, - top_p=None, - typical_p=None, - do_sample=False, - seed=0, - device="cpu", + self, + watermark=False, + temperature=1.0, + repetition_penalty=1.0, + top_k=None, + top_p=None, + typical_p=None, + do_sample=False, + seed=0, + device="cpu", ): - self.watermark_warper = WatermarkLogitsProcessor(device=device) if watermark else None - self.repetition_warper = RepetitionPenaltyLogitsProcessor( - penalty=repetition_penalty) if repetition_penalty else None + self.watermark_warper = ( + WatermarkLogitsProcessor(device=device) if watermark else None + ) + self.repetition_warper = ( + RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty) + if repetition_penalty + else None + ) - sampling = do_sample or (temperature is not None and temperature != 1.0) or ( - top_k is not None and top_k != 0) or (top_p is not None and top_p < 1.0) or ( - typical_p is not None and typical_p < 1.0) + sampling = ( + do_sample + or (temperature is not None and temperature != 1.0) + or (top_k is not None and top_k != 0) + or (top_p is not None and top_p < 1.0) + or (typical_p is not None and typical_p < 1.0) + ) if sampling: self.choice = Sampling(seed, device) - self.static_warper = static_warper(temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p) + self.static_warper = static_warper( + temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p + ) else: self.choice = Greedy() self.static_warper = None @@ -133,9 +146,9 @@ class NextTokenChooser: @classmethod def from_pb( - cls, - pb: generate_pb2.NextTokenChooserParameters, - device: torch.device, + cls, + pb: generate_pb2.NextTokenChooserParameters, + device: torch.device, ) -> "NextTokenChooser": return NextTokenChooser( watermark=pb.watermark, @@ -163,11 +176,11 @@ class StopSequenceCriteria: class StoppingCriteria: def __init__( - self, - eos_token_id: int, - stop_sequence_criterias: List[StopSequenceCriteria], - max_new_tokens: int = 20, - ignore_eos_token: bool = False, + self, + eos_token_id: int, + stop_sequence_criterias: List[StopSequenceCriteria], + max_new_tokens: int = 20, + ignore_eos_token: bool = False, ): self.eos_token_id = eos_token_id self.stop_sequence_criterias = stop_sequence_criterias @@ -193,9 +206,9 @@ class StoppingCriteria: @classmethod def from_pb( - cls, - pb: generate_pb2.StoppingCriteriaParameters, - tokenizer: PreTrainedTokenizerBase, + cls, + pb: generate_pb2.StoppingCriteriaParameters, + tokenizer: PreTrainedTokenizerBase, ) -> "StoppingCriteria": stop_sequence_criterias = [ StopSequenceCriteria(sequence) for sequence in pb.stop_sequences