mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 19:34:53 +00:00
Refactor next token chooser
This commit is contained in:
parent
e29bb90e88
commit
0e648a71f9
@ -3,10 +3,10 @@ import os
|
||||
import math
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from opentelemetry import trace
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
|
||||
from typing import Optional, Tuple, List, Type, Dict, Union
|
||||
from loguru import logger
|
||||
|
||||
|
||||
from text_generation_server.models import Model
|
||||
@ -18,6 +18,7 @@ from text_generation_server.models.types import (
|
||||
)
|
||||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.utils import StoppingCriteria
|
||||
from text_generation_server.utils.tokens_heterogeneous import HeterogeneousNextTokenChooser
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
@ -42,7 +43,7 @@ class VectorizedCausalLMBatch(Batch):
|
||||
token_offsets: List[Optional[int]]
|
||||
|
||||
# Generation helpers
|
||||
next_token_chooser: "VectorizedNextTokenChooser"
|
||||
next_token_chooser: "HeterogeneousNextTokenChooser"
|
||||
stopping_criterias: List[StoppingCriteria]
|
||||
|
||||
# Metadata used for padding
|
||||
@ -81,7 +82,7 @@ class VectorizedCausalLMBatch(Batch):
|
||||
stopping_criterias = [StoppingCriteria.from_pb(r.stopping_parameters, tokenizer) for r in pb.requests]
|
||||
max_new_tokens=(stopping_criteria.max_new_tokens for stopping_criteria in stopping_criterias)
|
||||
|
||||
next_token_chooser=VectorizedNextTokenChooser.from_pb([r.parameters for r in pb.requests], device)
|
||||
next_token_chooser= HeterogeneousNextTokenChooser.from_pb([r.parameters for r in pb.requests], device)
|
||||
|
||||
tokenized_inputs = tokenizer(
|
||||
inputs,
|
||||
@ -146,7 +147,9 @@ class VectorizedCausalLMBatch(Batch):
|
||||
self.input_lengths=[self.input_lengths[i] for i in keep_indices]
|
||||
self.offsets = [self.offsets[i] for i in keep_indices]
|
||||
self.token_offsets = [self.token_offsets[i] for i in keep_indices]
|
||||
self.next_token_chooser=self.next_token_chooser.filter(keep_indices)
|
||||
|
||||
self.next_token_chooser=HeterogeneousNextTokenChooser.from_pb([r.parameters for r in self.requests], self.input_ids.device)
|
||||
|
||||
self.stopping_criterias = [self.stopping_criterias[i] for i in keep_indices]
|
||||
remaining_decode_tokens=[stopping_criteria.max_new_tokens - stopping_criteria.current_tokens for stopping_criteria in self.stopping_criterias]
|
||||
|
||||
@ -194,7 +197,7 @@ class VectorizedCausalLMBatch(Batch):
|
||||
input_lengths = [length for batch in batches for length in batch.input_lengths]
|
||||
offsets = [offset for batch in batches for offset in batch.offsets]
|
||||
token_offsets = [token_offset for batch in batches for token_offset in batch.token_offsets]
|
||||
next_token_chooser=VectorizedNextTokenChooser.concatenate([batch.next_token_chooser for batch in batches])
|
||||
next_token_chooser=HeterogeneousNextTokenChooser.from_pb([r.parameters for r in requests], batches[0].input_ids.device)
|
||||
stopping_criterias = [stopping_criteria for batch in batches for stopping_criteria in batch.stopping_criterias]
|
||||
|
||||
requests_idx_mapping = {k: v + start_index for batch, start_index in zip(batches, start_indices) for k, v in batch.requests_idx_mapping.items()}
|
||||
@ -290,218 +293,6 @@ class VectorizedCausalLMBatch(Batch):
|
||||
return len(self.requests)
|
||||
|
||||
|
||||
class VectorizedNextTokenChooser:
|
||||
def __init__(
|
||||
self,
|
||||
batch_size:int,
|
||||
watermark:Optional[List[Optional[bool]]]=None,
|
||||
temperature:Optional[List[Optional[float]]]=None,
|
||||
repetition_penalty:Optional[List[Optional[float]]]=None,
|
||||
top_k:Optional[List[Optional[int]]]=None,
|
||||
top_p:Optional[List[Optional[float]]]=None,
|
||||
typical_p:Optional[List[Optional[float]]]=None,
|
||||
do_sample:Optional[List[Optional[bool]]]=None,
|
||||
seeds:Optional[List[Optional[int]]]=None,
|
||||
device:torch.device="cpu",
|
||||
):
|
||||
self.batch_size=batch_size
|
||||
self.filter_value = -math.inf
|
||||
self.device=device
|
||||
|
||||
# TODO: Seeds are ignored
|
||||
self.seeds=self._standardize(seeds, 0)
|
||||
self.do_sample=self._standardize(do_sample, False)
|
||||
|
||||
self.watermark=self._standardize(watermark, False)
|
||||
if any(self.watermark):
|
||||
raise NotImplementedError("Watermarking not implemented")
|
||||
|
||||
self.repetition_penalty=self._standardize(repetition_penalty, 1.0)
|
||||
if any([x!=1.0 for x in self.repetition_penalty]):
|
||||
self.repetition_penalty_t=torch.tensor(self.repetition_penalty, dtype=torch.float32, device=self.device).unsqueeze(1)
|
||||
else:
|
||||
self.repetition_penalty_t=None
|
||||
|
||||
self.temperature=self._standardize(temperature, 1.0)
|
||||
if any([x!=1.0 for x in self.temperature]):
|
||||
self.do_sample=[sample or x!=1.0 for x, sample in zip(self.temperature, self.do_sample)]
|
||||
self.temperature_t=torch.tensor(self.temperature, dtype=torch.float32, device=self.device).unsqueeze(1)
|
||||
else:
|
||||
self.temperature_t=None
|
||||
|
||||
self.top_k=self._standardize(top_k, 0)
|
||||
n_top_k=sum([x!=0 for x in top_k])
|
||||
if n_top_k>0:
|
||||
self.do_sample=[sample or x!=0 for x, sample in zip(self.top_k, self.do_sample)]
|
||||
self.max_top_k=max(self.top_k)
|
||||
self.top_k_t=torch.tensor([max(x-1,0) for x in self.top_k], dtype=torch.int64, device=self.device).unsqueeze(1)
|
||||
if n_top_k<self.batch_size:
|
||||
self.top_k_mask=torch.tensor([x==0 for x in self.top_k], dtype=torch.bool, device=self.device)
|
||||
else:
|
||||
self.top_k_mask=None
|
||||
else:
|
||||
self.max_top_k=None
|
||||
self.top_k_t=None
|
||||
self.top_k_mask=None
|
||||
|
||||
self.top_p=self._standardize(top_p, 1.0)
|
||||
if any([x<1.0 for x in self.top_p]):
|
||||
self.do_sample=[sample or x<1.0 for x, sample in zip(temperature, self.top_p)]
|
||||
self.top_p_t=torch.tensor([1.0-x for x in self.top_p], dtype=torch.float32, device=self.device).unsqueeze(1)
|
||||
else:
|
||||
self.top_p_t=None
|
||||
|
||||
self.typical_p=self._standardize(typical_p, 1.0)
|
||||
if any([x<1.0 for x in self.typical_p]):
|
||||
self.do_sample=[sample or x<1.0 for x, sample in zip(self.typical_p, self.do_sample)]
|
||||
self.typical_p_t=torch.tensor(self.typical_p, dtype=torch.float32, device=self.device).unsqueeze(1)
|
||||
else:
|
||||
self.typical_p_t=None
|
||||
|
||||
self.num_do_sample=sum(self.do_sample)
|
||||
if 0<self.num_do_sample<self.batch_size:
|
||||
# Mixed greedy and probabilistic sampling. Compute both and pick the right one.
|
||||
self.do_sample_t=torch.tensor(self.do_sample, dtype=torch.bool, device=self.device)
|
||||
else:
|
||||
self.do_sample_t=None
|
||||
|
||||
def _standardize(self, values, default):
|
||||
if isinstance(values, list):
|
||||
values=values.copy()
|
||||
else:
|
||||
values=[values]*self.batch_size
|
||||
assert len(values)==self.batch_size
|
||||
for i, v in enumerate(values):
|
||||
if v is None:
|
||||
values[i]=default
|
||||
return values
|
||||
|
||||
def __call__(self, input_ids:torch.Tensor, scores:torch.Tensor, return_logprobs:bool):
|
||||
last_token_scores=scores[:, -1, :]
|
||||
|
||||
if self.repetition_penalty_t is not None:
|
||||
score = torch.gather(last_token_scores, 1, input_ids)
|
||||
# if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
|
||||
score = torch.where(score < 0, score * self.repetition_penalty_t, score / self.repetition_penalty_t)
|
||||
last_token_scores.scatter_(1, input_ids, score)
|
||||
|
||||
if self.temperature_t is not None:
|
||||
last_token_scores.div_(self.temperature_t)
|
||||
|
||||
if self.top_k_t is not None:
|
||||
if last_token_scores.size(-1)>self.max_top_k: # Safety check
|
||||
max_top_k=last_token_scores.size(-1)
|
||||
top_k=torch.clamp_max(self.top_k_t,max_top_k) # Run only if needed.
|
||||
else:
|
||||
max_top_k=self.max_top_k
|
||||
top_k=self.top_k_t
|
||||
kth_scores=torch.gather(torch.topk(last_token_scores, max_top_k)[0], 1, top_k)
|
||||
if self.top_k_mask is not None:
|
||||
kth_scores.masked_fill_(self.top_k_mask, self.filter_value)
|
||||
# Remove all tokens with a probability less than the last token of the top-k
|
||||
indices_to_remove = last_token_scores < kth_scores
|
||||
last_token_scores = last_token_scores.masked_fill(indices_to_remove, self.filter_value)
|
||||
|
||||
if self.top_p_t is not None:
|
||||
# TODO: Merge wit top_k
|
||||
sorted_logits, sorted_indices = torch.sort(last_token_scores, descending=True)
|
||||
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
|
||||
# Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
|
||||
sorted_indices_to_remove = cumulative_probs <= self.top_p_t
|
||||
# scatter sorted tensors to original indexing
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
|
||||
last_token_scores = last_token_scores.masked_fill(indices_to_remove, self.filter_value)
|
||||
|
||||
if self.typical_p_t is not None:
|
||||
# calculate entropy
|
||||
normalized = torch.nn.functional.log_softmax(last_token_scores, dim=-1)
|
||||
p = torch.exp(normalized)
|
||||
ent = -(normalized * p).nansum(-1, keepdim=True)
|
||||
# shift and sort
|
||||
shifted_scores = torch.abs((-normalized) - ent)
|
||||
sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
|
||||
sorted_logits = last_token_scores.gather(-1, sorted_indices)
|
||||
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
|
||||
# Remove tokens with cumulative mass above the threshold
|
||||
last_ind = (cumulative_probs < self.typical_p_t).sum(dim=1)
|
||||
last_ind[last_ind < 0] = 0
|
||||
sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
|
||||
last_token_scores = last_token_scores.masked_fill(indices_to_remove, self.filter_value)
|
||||
|
||||
|
||||
if self.num_do_sample:
|
||||
probs = torch.nn.functional.softmax(last_token_scores, -1)
|
||||
next_token_ids = torch.multinomial(probs, num_samples=1)
|
||||
if self.do_sample_t is not None:
|
||||
next_token_ids=torch.where(self.do_sample_t, next_token_ids, torch.argmax(last_token_scores, dim=-1))
|
||||
else:
|
||||
next_token_ids = torch.argmax(last_token_scores, dim=-1)
|
||||
|
||||
if return_logprobs:
|
||||
# Compute logprobs
|
||||
if scores.size(1)==1:
|
||||
scores=last_token_scores.unsqueeze(1)
|
||||
else:
|
||||
# TODO: Post-process all the tokens?
|
||||
scores[:, -1, :]=last_token_scores
|
||||
|
||||
logprobs = torch.log_softmax(scores, dim=-1)
|
||||
else:
|
||||
logprobs=None
|
||||
|
||||
return next_token_ids, logprobs
|
||||
|
||||
@classmethod
|
||||
def from_pb(
|
||||
cls,
|
||||
pb: List[generate_pb2.NextTokenChooserParameters],
|
||||
device: torch.device,
|
||||
) -> "VectorizedNextTokenChooser":
|
||||
# TODO: Seeds are ignored
|
||||
return VectorizedNextTokenChooser(
|
||||
batch_size=len(pb),
|
||||
watermark=[pb_.watermark for pb_ in pb],
|
||||
temperature=[pb_.temperature for pb_ in pb],
|
||||
repetition_penalty=[pb_.repetition_penalty for pb_ in pb],
|
||||
top_k=[pb_.top_k for pb_ in pb],
|
||||
top_p=[pb_.top_p for pb_ in pb],
|
||||
typical_p=[pb_.typical_p for pb_ in pb],
|
||||
do_sample=[pb_.do_sample for pb_ in pb],
|
||||
seeds=[pb_.seed for pb_ in pb],
|
||||
device=device,
|
||||
)
|
||||
|
||||
def filter(self, keep_indices: List[int]) -> "VectorizedNextTokenChooser":
|
||||
return VectorizedNextTokenChooser(
|
||||
batch_size=len(keep_indices),
|
||||
watermark=[self.watermark[i] for i in keep_indices],
|
||||
temperature=[self.temperature[i] for i in keep_indices],
|
||||
repetition_penalty=[self.repetition_penalty[i] for i in keep_indices],
|
||||
top_k=[self.top_k[i] for i in keep_indices],
|
||||
top_p=[self.top_p[i] for i in keep_indices],
|
||||
typical_p=[self.typical_p[i] for i in keep_indices],
|
||||
do_sample=[self.do_sample[i] for i in keep_indices],
|
||||
seeds=[self.seeds[i] for i in keep_indices],
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def concatenate(cls, next_token_choosers: List["VectorizedNextTokenChooser"]) -> "VectorizedNextTokenChooser":
|
||||
return cls(
|
||||
batch_size=sum(next_token_chooser.batch_size for next_token_chooser in next_token_choosers),
|
||||
watermark=[x for next_token_chooser in next_token_choosers for x in next_token_chooser.watermark],
|
||||
temperature=[x for next_token_chooser in next_token_choosers for x in next_token_chooser.temperature],
|
||||
repetition_penalty=[x for next_token_chooser in next_token_choosers for x in next_token_chooser.repetition_penalty],
|
||||
top_k=[x for next_token_chooser in next_token_choosers for x in next_token_chooser.top_k],
|
||||
top_p=[x for next_token_chooser in next_token_choosers for x in next_token_chooser.top_p],
|
||||
typical_p=[x for next_token_chooser in next_token_choosers for x in next_token_chooser.typical_p],
|
||||
do_sample=[x for next_token_chooser in next_token_choosers for x in next_token_chooser.do_sample],
|
||||
seeds=[x for next_token_chooser in next_token_choosers for x in next_token_chooser.seeds],
|
||||
device=next_token_choosers[0].device,
|
||||
)
|
||||
|
||||
|
||||
class VectorizedCausalLM(Model):
|
||||
def __init__(
|
||||
self,
|
||||
@ -633,7 +424,7 @@ class VectorizedCausalLM(Model):
|
||||
|
||||
generation = Generation(
|
||||
batch.requests[i].id,
|
||||
prefill_tokens[i] if batch.details and query_length>1 else None, # TODO: Prefill tokens
|
||||
prefill_tokens[i] if batch.details and query_length>1 else None,
|
||||
next_token_id,
|
||||
token_logprobs[i] if batch.details else 0.0,
|
||||
next_token_text,
|
||||
|
@ -32,7 +32,7 @@ class Sampling:
|
||||
|
||||
class Greedy:
|
||||
def __call__(self, logits):
|
||||
return logits.argmax()
|
||||
return logits.argmax(dim=-1)
|
||||
|
||||
|
||||
class NextTokenChooser:
|
||||
|
302
server/text_generation_server/utils/tokens_heterogeneous.py
Normal file
302
server/text_generation_server/utils/tokens_heterogeneous.py
Normal file
@ -0,0 +1,302 @@
|
||||
import math
|
||||
from typing import Optional, List, Union
|
||||
|
||||
from transformers import (
|
||||
LogitsWarper,
|
||||
LogitsProcessor,
|
||||
LogitsProcessorList,
|
||||
)
|
||||
import torch
|
||||
|
||||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.utils.tokens import Greedy, Sampling
|
||||
|
||||
|
||||
class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor):
|
||||
r"""
|
||||
[`LogitsProcessor`] enforcing an exponential penalty on repeated sequences.
|
||||
This version allows for a separate value for each sample and runs inplace when possible.
|
||||
It doesn't validate inputs.
|
||||
|
||||
Args:
|
||||
repetition_penalty (`List[float]`):
|
||||
The parameter for repetition penalty. 1.0 means no penalty. See [this
|
||||
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
|
||||
"""
|
||||
|
||||
def __init__(self, penalty: List[float], device:torch.device):
|
||||
self.penalty = torch.tensor(penalty, dtype=torch.float32, device=device).unsqueeze(1)
|
||||
|
||||
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
|
||||
score = torch.gather(scores, 1, input_ids)
|
||||
|
||||
# if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
|
||||
score = torch.where(score < 0, score * self.penalty, score / self.penalty)
|
||||
|
||||
scores.scatter_(1, input_ids, score)
|
||||
return scores
|
||||
|
||||
class HeterogeneousTemperatureLogitsWarper(LogitsWarper):
|
||||
r"""
|
||||
[`LogitsWarper`] for temperature (exponential scaling output probability distribution).
|
||||
This version allows for a separate value for each sample and runs inplace when possible.
|
||||
It doesn't validate inputs.
|
||||
|
||||
Args:
|
||||
temperature (`float`):
|
||||
The value used to module the logits distribution.
|
||||
"""
|
||||
|
||||
def __init__(self, temperature: List[float], device:torch.device):
|
||||
self.temperature = torch.tensor(temperature, dtype=torch.float32, device=device).unsqueeze(1)
|
||||
|
||||
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
|
||||
scores.div_(self.temperature)
|
||||
return scores
|
||||
|
||||
class HeterogeneousTopPLogitsWarper(LogitsWarper):
|
||||
"""
|
||||
[`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off.
|
||||
This version allows for a separate value for each sample and runs inplace when possible.
|
||||
It doesn't validate inputs.
|
||||
|
||||
Args:
|
||||
top_p (`float`):
|
||||
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
|
||||
higher are kept for generation.
|
||||
filter_value (`float`, *optional*, defaults to `-float("Inf")`):
|
||||
All filtered values will be set to this float value.
|
||||
min_tokens_to_keep (`int`, *optional*, defaults to 1):
|
||||
Minimum number of tokens that cannot be filtered.
|
||||
"""
|
||||
|
||||
def __init__(self, top_p: List[float], device:torch.device, filter_value: float = -math.inf, min_tokens_to_keep: int = 1):
|
||||
self.top_p = torch.tensor(top_p, dtype=torch.float32, device=device).unsqueeze(1)
|
||||
self.filter_value = filter_value
|
||||
self.min_tokens_to_keep = min_tokens_to_keep
|
||||
|
||||
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
|
||||
sorted_logits, sorted_indices = torch.sort(scores, descending=False)
|
||||
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
|
||||
|
||||
# Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
|
||||
sorted_indices_to_remove = cumulative_probs <= (1 - self.top_p)
|
||||
if self.min_tokens_to_keep > 1:
|
||||
# Keep at least min_tokens_to_keep
|
||||
sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
|
||||
|
||||
# scatter sorted tensors to original indexing
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
|
||||
scores.masked_fill_(indices_to_remove, self.filter_value)
|
||||
return scores
|
||||
|
||||
class HeterogeneousTopKLogitsWarper(LogitsWarper):
|
||||
r"""
|
||||
[`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements.
|
||||
This version allows for a separate value for each sample and runs inplace when possible.
|
||||
It doesn't validate inputs.
|
||||
|
||||
Args:
|
||||
top_k (`int`):
|
||||
The number of highest probability vocabulary tokens to keep for top-k-filtering.
|
||||
filter_value (`float`, *optional*, defaults to `-float("Inf")`):
|
||||
All filtered values will be set to this float value.
|
||||
min_tokens_to_keep (`int`, *optional*, defaults to 1):
|
||||
Minimum number of tokens that cannot be filtered.
|
||||
"""
|
||||
|
||||
def __init__(self, top_k: List[int], device:torch.device, filter_value: float = -math.inf, min_tokens_to_keep: int = 1):
|
||||
self.max_top_k = max(top_k)
|
||||
self.top_k = torch.tensor([max(x - 1, min_tokens_to_keep-1) for x in top_k], dtype=torch.int64,device=device).unsqueeze(1)
|
||||
zeros=[x == 0 for x in top_k]
|
||||
if any(zeros):
|
||||
self.top_k_mask = torch.tensor(zeros, dtype=torch.bool, device=device)
|
||||
else:
|
||||
self.top_k_mask = None
|
||||
self.filter_value = filter_value
|
||||
|
||||
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
|
||||
if scores.size(-1)>self.max_top_k: # Safety check
|
||||
max_top_k=scores.size(-1)
|
||||
top_k=torch.clamp_max(self.top_k,max_top_k) # Run only if needed.
|
||||
else:
|
||||
max_top_k=self.max_top_k
|
||||
top_k=self.top_k
|
||||
kth_scores=torch.gather(torch.topk(scores, max_top_k)[0], 1, top_k)
|
||||
if self.top_k_mask is not None:
|
||||
kth_scores.masked_fill_(self.top_k_mask, self.filter_value)
|
||||
# Remove all tokens with a probability less than the last token of the top-k
|
||||
indices_to_remove = scores < kth_scores
|
||||
scores.masked_fill_(indices_to_remove, self.filter_value)
|
||||
return scores
|
||||
|
||||
|
||||
class HeterogeneousTypicalLogitsWarper(LogitsWarper):
|
||||
r"""
|
||||
[`LogitsWarper`] that performs typical decoding. See [Typical Decoding for Natural Language
|
||||
Generation](https://arxiv.org/abs/2202.00666) for more information.
|
||||
This version allows for a separate value for each sample and runs inplace when possible.
|
||||
It doesn't validate inputs.
|
||||
|
||||
Args:
|
||||
mass (`float`):
|
||||
Value of typical_p between 0 and 1 inclusive, defaults to 0.9.
|
||||
filter_value (`float`, *optional*, defaults to `-float("Inf")`):
|
||||
All filtered values will be set to this float value.
|
||||
min_tokens_to_keep (`int`, *optional*, defaults to 1):
|
||||
Minimum number of tokens that cannot be filtered.
|
||||
"""
|
||||
|
||||
def __init__(self, mass: List[float], device:torch.device, filter_value: float = -math.inf, min_tokens_to_keep: int = 1):
|
||||
self.filter_value = filter_value
|
||||
self.mass = torch.tensor(mass, dtype=torch.float32, device=device).unsqueeze(1)
|
||||
self.min_tokens_to_keep = min_tokens_to_keep
|
||||
|
||||
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
|
||||
# calculate entropy
|
||||
normalized = torch.nn.functional.log_softmax(scores, dim=-1)
|
||||
p = torch.exp(normalized)
|
||||
ent = -(normalized * p).nansum(-1, keepdim=True)
|
||||
|
||||
# shift and sort
|
||||
shifted_scores = torch.abs((-normalized) - ent)
|
||||
sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
|
||||
sorted_logits = scores.gather(-1, sorted_indices)
|
||||
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
|
||||
|
||||
# Remove tokens with cumulative mass above the threshold
|
||||
last_ind = (cumulative_probs < self.mass).sum(dim=1)
|
||||
last_ind[last_ind < 0] = 0
|
||||
sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
|
||||
if self.min_tokens_to_keep > 1:
|
||||
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
|
||||
sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
|
||||
|
||||
scores = scores.masked_fill_(indices_to_remove, self.filter_value)
|
||||
return scores
|
||||
|
||||
|
||||
class HeterogeneousSampling:
|
||||
r"""
|
||||
Mixed greedy and probabilistic sampling. Compute both and pick the right one for each sample.
|
||||
"""
|
||||
def __init__(self, do_sample:List[bool], seeds: List[int], device: torch.device):
|
||||
self.seeds = seeds
|
||||
self.greedy=Greedy()
|
||||
# TODO: Most seeds are ignored
|
||||
self.sampling=Sampling(seeds[0], device)
|
||||
self.do_sample=torch.tensor(do_sample, dtype=torch.bool, device=device)
|
||||
|
||||
def __call__(self, logits):
|
||||
return torch.where(self.do_sample, self.sampling(logits), self.greedy(logits))
|
||||
|
||||
class HeterogeneousNextTokenChooser:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
batch_size:int,
|
||||
device:torch.device,
|
||||
watermark:Optional[Union[bool,List[Optional[bool]]]]=None,
|
||||
temperature:Optional[Union[float,List[Optional[float]]]]=None,
|
||||
repetition_penalty:Optional[Union[float,List[Optional[float]]]]=None,
|
||||
top_k:Optional[Union[int,List[Optional[int]]]]=None,
|
||||
top_p:Optional[Union[float,List[Optional[float]]]]=None,
|
||||
typical_p:Optional[Union[float,List[Optional[float]]]]=None,
|
||||
do_sample:Optional[Union[bool,List[Optional[bool]]]]=None,
|
||||
seeds:Optional[Union[int,List[Optional[int]]]]=None,
|
||||
):
|
||||
# TODO: Most seeds are ignored
|
||||
seeds=self._standardize(seeds, batch_size, 0)
|
||||
do_sample=self._standardize(do_sample, batch_size, False)
|
||||
|
||||
warpers = LogitsProcessorList()
|
||||
|
||||
watermark=self._standardize(watermark, batch_size, False)
|
||||
if any(watermark):
|
||||
raise NotImplementedError("Watermarking not implemented")
|
||||
|
||||
repetition_penalty=self._standardize(repetition_penalty, batch_size, 1.0)
|
||||
if any([x!=1.0 for x in repetition_penalty]):
|
||||
warpers.append(HeterogeneousRepetitionPenaltyLogitsProcessor(repetition_penalty, device))
|
||||
|
||||
temperature=self._standardize(temperature, batch_size, 1.0)
|
||||
if any([x!=1.0 for x in temperature]):
|
||||
do_sample=[sample or x!=1.0 for x, sample in zip(temperature, do_sample)]
|
||||
warpers.append(HeterogeneousTemperatureLogitsWarper(temperature, device))
|
||||
|
||||
top_k=self._standardize(top_k, batch_size, 0)
|
||||
n_top_k=sum([x!=0 for x in top_k])
|
||||
if n_top_k>0:
|
||||
do_sample=[sample or x!=0 for x, sample in zip(top_k, do_sample)]
|
||||
warpers.append(HeterogeneousTopKLogitsWarper(top_k, device))
|
||||
|
||||
top_p=self._standardize(top_p, batch_size, 1.0)
|
||||
if any([x<1.0 for x in top_p]):
|
||||
do_sample=[sample or x<1.0 for x, sample in zip(top_p, do_sample)]
|
||||
warpers.append(HeterogeneousTopPLogitsWarper(top_p, device))
|
||||
|
||||
typical_p=self._standardize(typical_p, batch_size, 1.0)
|
||||
if any([x<1.0 for x in typical_p]):
|
||||
do_sample=[sample or x<1.0 for x, sample in zip(typical_p, do_sample)]
|
||||
warpers.append(HeterogeneousTypicalLogitsWarper(typical_p, device))
|
||||
|
||||
self.warpers=warpers
|
||||
|
||||
num_do_sample=sum(do_sample)
|
||||
if num_do_sample==0:
|
||||
self.choice=Greedy()
|
||||
elif num_do_sample<batch_size:
|
||||
self.choice=HeterogeneousSampling(do_sample, seeds, device)
|
||||
else:
|
||||
# TODO: Most seeds are ignored
|
||||
self.choice=Sampling(seeds[0], device)
|
||||
|
||||
@staticmethod
|
||||
def _standardize(values, batch_size, default):
|
||||
if isinstance(values, list):
|
||||
values=values.copy()
|
||||
else:
|
||||
values=[values]*batch_size
|
||||
assert len(values)==batch_size
|
||||
for i, v in enumerate(values):
|
||||
if v is None:
|
||||
values[i]=default
|
||||
return values
|
||||
|
||||
def __call__(self, input_ids:torch.Tensor, scores:torch.Tensor, return_logprobs:bool):
|
||||
last_token_scores=self.warpers(input_ids, scores[:, -1, :])
|
||||
next_token_ids=self.choice(last_token_scores)
|
||||
|
||||
if return_logprobs:
|
||||
# Compute logprobs
|
||||
if scores.size(1)==1:
|
||||
scores=last_token_scores.unsqueeze(1)
|
||||
else:
|
||||
# TODO: Post-process all the tokens?
|
||||
scores[:, -1, :]=last_token_scores
|
||||
logprobs = torch.log_softmax(scores, dim=-1)
|
||||
else:
|
||||
logprobs=None
|
||||
|
||||
return next_token_ids, logprobs
|
||||
|
||||
@classmethod
|
||||
def from_pb(
|
||||
cls,
|
||||
pb: List[generate_pb2.NextTokenChooserParameters],
|
||||
device: torch.device,
|
||||
) -> "HeterogeneousNextTokenChooser":
|
||||
# TODO: Seeds are ignored
|
||||
return HeterogeneousNextTokenChooser(
|
||||
batch_size=len(pb),
|
||||
watermark=[pb_.watermark for pb_ in pb],
|
||||
temperature=[pb_.temperature for pb_ in pb],
|
||||
repetition_penalty=[pb_.repetition_penalty for pb_ in pb],
|
||||
top_k=[pb_.top_k for pb_ in pb],
|
||||
top_p=[pb_.top_p for pb_ in pb],
|
||||
typical_p=[pb_.typical_p for pb_ in pb],
|
||||
do_sample=[pb_.do_sample for pb_ in pb],
|
||||
seeds=[pb_.seed for pb_ in pb],
|
||||
device=device,
|
||||
)
|
Loading…
Reference in New Issue
Block a user