fix: remove small unneeded changes

This commit is contained in:
drbh 2024-02-10 01:43:31 +00:00
parent f1d43f2df4
commit ffc228831c

View File

@@ -22,13 +22,6 @@ from text_generation_server.utils.logits_process import (
from text_generation_server.utils.watermark import WatermarkLogitsProcessor from text_generation_server.utils.watermark import WatermarkLogitsProcessor
from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
from outlines.fsm.fsm import RegexFSM
from outlines.fsm.json_schema import build_regex_from_object
from functools import lru_cache
# TODO: remove when done debugging
import time
class NextTokenChooser: class NextTokenChooser:
def __init__( def __init__(
@@ -481,7 +474,7 @@ class Sampling:
self.generator.manual_seed(seed) self.generator.manual_seed(seed)
self.seed = seed self.seed = seed
def __call__(self, logits, *args): def __call__(self, logits):
probs = torch.nn.functional.softmax(logits, -1) probs = torch.nn.functional.softmax(logits, -1)
# Avoid GPU<->CPU sync done by torch multinomial # Avoid GPU<->CPU sync done by torch multinomial
# See: https://github.com/pytorch/pytorch/blob/925a3788ec5c06db62ca732a0e9425a26a00916f/aten/src/ATen/native/Distributions.cpp#L631-L637 # See: https://github.com/pytorch/pytorch/blob/925a3788ec5c06db62ca732a0e9425a26a00916f/aten/src/ATen/native/Distributions.cpp#L631-L637
@@ -490,13 +483,9 @@ class Sampling:
class Greedy: class Greedy:
def __call__(self, logits, *args): def __call__(self, logits):
return logits.argmax(dim=-1) return logits.argmax(dim=-1)
def filter(self, indices, *args):
return self
class HeterogeneousSampling: class HeterogeneousSampling:
r""" r"""
Mixed greedy and probabilistic sampling. Compute both and pick the right one for each sample. Mixed greedy and probabilistic sampling. Compute both and pick the right one for each sample.