diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py index 789e2e85..ae14b337 100644 --- a/server/text_generation_server/utils/tokens.py +++ b/server/text_generation_server/utils/tokens.py @@ -29,7 +29,7 @@ class Sampling: q = torch.empty_like(probs).exponential_(1, generator=self.generator) torch.div(probs, q, out=q) - return torch.argmax(q, dim=-1, keepdim=True) + return q.argmax() class Greedy: @@ -107,36 +107,36 @@ class NextTokenChooser: seed=0, device="cpu", ): - self.watermark_warper = ( + self.watermark_processor = ( WatermarkLogitsProcessor(device=device) if watermark else None ) - self.repetition_warper = ( + self.repetition_processor = ( RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty) if repetition_penalty else None ) - sampling = ( - do_sample - or (temperature is not None and temperature != 1.0) + has_warpers = ( + (temperature is not None and temperature != 1.0) or (top_k is not None and top_k != 0) or (top_p is not None and top_p < 1.0) or (typical_p is not None and typical_p < 1.0) ) - if sampling: - self.choice = Sampling(seed, device) + if has_warpers: self.static_warper = static_warper( temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p ) else: - self.choice = Greedy() self.static_warper = None + sampling = do_sample or has_warpers + self.choice = Sampling(seed, device) if sampling else Greedy() + def __call__(self, input_ids, scores): - if self.watermark_warper: - scores = self.watermark_warper(input_ids, scores) - if self.repetition_warper: - scores = self.repetition_warper(input_ids, scores) + if self.watermark_processor: + scores = self.watermark_processor(input_ids, scores) + if self.repetition_processor: + scores = self.repetition_processor(input_ids, scores) if self.static_warper is None: next_logprob = torch.log_softmax(scores, -1)