mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 03:44:54 +00:00
fix imports
This commit is contained in:
parent
f9e3a3bb91
commit
e7826855a3
@ -18,10 +18,7 @@ from text_generation_server.models.types import (
|
||||
GeneratedText,
|
||||
)
|
||||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.utils import (
|
||||
StoppingCriteria,
|
||||
HeterogeneousNextTokenChooser
|
||||
)
|
||||
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
|
@ -14,8 +14,9 @@ from text_generation_server.utils.tokens import (
|
||||
StoppingCriteria,
|
||||
StopSequenceCriteria,
|
||||
FinishReason,
|
||||
Sampling,
|
||||
Greedy,
|
||||
)
|
||||
from text_generation_server.utils.logits_process import Sampling, Greedy
|
||||
|
||||
__all__ = [
|
||||
"convert_file",
|
||||
|
@ -14,25 +14,6 @@ from transformers import (
|
||||
)
|
||||
|
||||
|
||||
class Sampling:
|
||||
def __init__(self, seed: int, device: str = "cpu"):
|
||||
self.generator = torch.Generator(device)
|
||||
self.generator.manual_seed(seed)
|
||||
self.seed = seed
|
||||
|
||||
def __call__(self, logits):
|
||||
probs = torch.nn.functional.softmax(logits, -1)
|
||||
# Avoid GPU<->CPU sync done by torch multinomial
|
||||
# See: https://github.com/pytorch/pytorch/blob/925a3788ec5c06db62ca732a0e9425a26a00916f/aten/src/ATen/native/Distributions.cpp#L631-L637
|
||||
q = torch.empty_like(probs).exponential_(1, generator=self.generator)
|
||||
return probs.div_(q).argmax()
|
||||
|
||||
|
||||
class Greedy:
|
||||
def __call__(self, logits):
|
||||
return logits.argmax(dim=-1)
|
||||
|
||||
|
||||
class StaticWarper:
|
||||
def __init__(
|
||||
self,
|
||||
@ -329,46 +310,3 @@ class HeterogeneousTypicalLogitsWarper(LogitsWarper):
|
||||
def filter(self, indices):
|
||||
self.mass = self.mass[indices]
|
||||
return self
|
||||
|
||||
|
||||
class HeterogeneousSampling:
|
||||
r"""
|
||||
Mixed greedy and probabilistic sampling. Compute both and pick the right one for each sample.
|
||||
"""
|
||||
|
||||
def __init__(self, do_sample: List[bool], seeds: List[int], device: torch.device):
|
||||
self.seeds = seeds
|
||||
|
||||
self.greedy_indices = []
|
||||
self.sampling_mapping = {}
|
||||
for i, (sample, seed) in enumerate(zip(do_sample, seeds)):
|
||||
if sample:
|
||||
self.sampling_mapping[i] = Sampling(seed, device)
|
||||
else:
|
||||
self.greedy_indices.append(i)
|
||||
|
||||
self.greedy = Greedy()
|
||||
|
||||
def __call__(self, logits):
|
||||
out = torch.empty(logits.shape[0], dtype=torch.int64, device=logits.device)
|
||||
if self.greedy_indices:
|
||||
out[self.greedy_indices] = torch.argmax(logits[self.greedy_indices], -1)
|
||||
|
||||
for i, sampling in self.sampling_mapping.items():
|
||||
out[i] = sampling(logits[i])
|
||||
return out
|
||||
|
||||
def filter(self, indices):
|
||||
new_greedy_indices = []
|
||||
new_sampling_mapping = {}
|
||||
for i, idx in enumerate(indices):
|
||||
if idx in self.sampling_mapping:
|
||||
new_sampling_mapping[i] = self.sampling_mapping[idx]
|
||||
else:
|
||||
new_greedy_indices.append(i)
|
||||
|
||||
self.greedy_indices = new_greedy_indices
|
||||
self.sampling_mapping = new_sampling_mapping
|
||||
return self
|
||||
|
||||
|
||||
|
@ -3,17 +3,22 @@ import torch
|
||||
|
||||
from transformers import (
|
||||
RepetitionPenaltyLogitsProcessor,
|
||||
PreTrainedTokenizerBase, LogitsProcessorList,
|
||||
PreTrainedTokenizerBase,
|
||||
LogitsProcessorList,
|
||||
)
|
||||
from typing import List, Tuple, Optional
|
||||
|
||||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.pb.generate_pb2 import FinishReason
|
||||
from text_generation_server.utils.watermark import WatermarkLogitsProcessor
|
||||
from text_generation_server.utils import Sampling, Greedy
|
||||
from text_generation_server.utils.logits_process import static_warper, HeterogeneousRepetitionPenaltyLogitsProcessor, \
|
||||
HeterogeneousTemperatureLogitsWarper, HeterogeneousTopKLogitsWarper, HeterogeneousTopPLogitsWarper, \
|
||||
HeterogeneousTypicalLogitsWarper, HeterogeneousSampling
|
||||
from text_generation_server.utils.logits_process import (
|
||||
static_warper,
|
||||
HeterogeneousRepetitionPenaltyLogitsProcessor,
|
||||
HeterogeneousTemperatureLogitsWarper,
|
||||
HeterogeneousTopKLogitsWarper,
|
||||
HeterogeneousTopPLogitsWarper,
|
||||
HeterogeneousTypicalLogitsWarper,
|
||||
)
|
||||
|
||||
|
||||
class NextTokenChooser:
|
||||
@ -240,3 +245,63 @@ class HeterogeneousNextTokenChooser:
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
|
||||
class Sampling:
|
||||
def __init__(self, seed: int, device: str = "cpu"):
|
||||
self.generator = torch.Generator(device)
|
||||
self.generator.manual_seed(seed)
|
||||
self.seed = seed
|
||||
|
||||
def __call__(self, logits):
|
||||
probs = torch.nn.functional.softmax(logits, -1)
|
||||
# Avoid GPU<->CPU sync done by torch multinomial
|
||||
# See: https://github.com/pytorch/pytorch/blob/925a3788ec5c06db62ca732a0e9425a26a00916f/aten/src/ATen/native/Distributions.cpp#L631-L637
|
||||
q = torch.empty_like(probs).exponential_(1, generator=self.generator)
|
||||
return probs.div_(q).argmax()
|
||||
|
||||
|
||||
class Greedy:
|
||||
def __call__(self, logits):
|
||||
return logits.argmax(dim=-1)
|
||||
|
||||
|
||||
class HeterogeneousSampling:
|
||||
r"""
|
||||
Mixed greedy and probabilistic sampling. Compute both and pick the right one for each sample.
|
||||
"""
|
||||
|
||||
def __init__(self, do_sample: List[bool], seeds: List[int], device: torch.device):
|
||||
self.seeds = seeds
|
||||
|
||||
self.greedy_indices = []
|
||||
self.sampling_mapping = {}
|
||||
for i, (sample, seed) in enumerate(zip(do_sample, seeds)):
|
||||
if sample:
|
||||
self.sampling_mapping[i] = Sampling(seed, device)
|
||||
else:
|
||||
self.greedy_indices.append(i)
|
||||
|
||||
self.greedy = Greedy()
|
||||
|
||||
def __call__(self, logits):
|
||||
out = torch.empty(logits.shape[0], dtype=torch.int64, device=logits.device)
|
||||
if self.greedy_indices:
|
||||
out[self.greedy_indices] = torch.argmax(logits[self.greedy_indices], -1)
|
||||
|
||||
for i, sampling in self.sampling_mapping.items():
|
||||
out[i] = sampling(logits[i])
|
||||
return out
|
||||
|
||||
def filter(self, indices):
|
||||
new_greedy_indices = []
|
||||
new_sampling_mapping = {}
|
||||
for i, idx in enumerate(indices):
|
||||
if idx in self.sampling_mapping:
|
||||
new_sampling_mapping[i] = self.sampling_mapping[idx]
|
||||
else:
|
||||
new_greedy_indices.append(i)
|
||||
|
||||
self.greedy_indices = new_greedy_indices
|
||||
self.sampling_mapping = new_sampling_mapping
|
||||
return self
|
||||
|
Loading…
Reference in New Issue
Block a user