mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
fix: remove small unneeded changes
This commit is contained in:
parent
f1d43f2df4
commit
ffc228831c
@ -22,13 +22,6 @@ from text_generation_server.utils.logits_process import (
|
||||
from text_generation_server.utils.watermark import WatermarkLogitsProcessor
|
||||
from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
|
||||
|
||||
from outlines.fsm.fsm import RegexFSM
|
||||
from outlines.fsm.json_schema import build_regex_from_object
|
||||
from functools import lru_cache
|
||||
|
||||
# TODO: remove when done debugging
|
||||
import time
|
||||
|
||||
|
||||
class NextTokenChooser:
|
||||
def __init__(
|
||||
@ -481,7 +474,7 @@ class Sampling:
|
||||
self.generator.manual_seed(seed)
|
||||
self.seed = seed
|
||||
|
||||
def __call__(self, logits, *args):
|
||||
def __call__(self, logits):
|
||||
probs = torch.nn.functional.softmax(logits, -1)
|
||||
# Avoid GPU<->CPU sync done by torch multinomial
|
||||
# See: https://github.com/pytorch/pytorch/blob/925a3788ec5c06db62ca732a0e9425a26a00916f/aten/src/ATen/native/Distributions.cpp#L631-L637
|
||||
@ -490,13 +483,9 @@ class Sampling:
|
||||
|
||||
|
||||
class Greedy:
|
||||
def __call__(self, logits, *args):
|
||||
def __call__(self, logits):
|
||||
return logits.argmax(dim=-1)
|
||||
|
||||
def filter(self, indices, *args):
|
||||
return self
|
||||
|
||||
|
||||
class HeterogeneousSampling:
|
||||
r"""
|
||||
Mixed greedy and probabilistic sampling. Compute both and pick the right one for each sample.
|
||||
|
Loading…
Reference in New Issue
Block a user