fix: remove small unneeded changes

This commit is contained in:
drbh 2024-02-10 01:43:31 +00:00
parent f1d43f2df4
commit ffc228831c

View File

@ -22,13 +22,6 @@ from text_generation_server.utils.logits_process import (
from text_generation_server.utils.watermark import WatermarkLogitsProcessor
from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
from outlines.fsm.fsm import RegexFSM
from outlines.fsm.json_schema import build_regex_from_object
from functools import lru_cache
# TODO: remove when done debugging
import time
class NextTokenChooser:
def __init__(
@ -481,7 +474,7 @@ class Sampling:
self.generator.manual_seed(seed)
self.seed = seed
def __call__(self, logits, *args):
def __call__(self, logits):
probs = torch.nn.functional.softmax(logits, -1)
# Avoid GPU<->CPU sync done by torch multinomial
# See: https://github.com/pytorch/pytorch/blob/925a3788ec5c06db62ca732a0e9425a26a00916f/aten/src/ATen/native/Distributions.cpp#L631-L637
@ -490,13 +483,9 @@ class Sampling:
class Greedy:
def __call__(self, logits, *args):
def __call__(self, logits):
return logits.argmax(dim=-1)
def filter(self, indices, *args):
return self
class HeterogeneousSampling:
r"""
Mixed greedy and probabilistic sampling. Compute both and pick the right one for each sample.