mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 19:34:53 +00:00)

Refactor next token chooser
parent e29bb90e88
commit 0e648a71f9
@@ -3,10 +3,10 @@ import os
 import math
 
 from dataclasses import dataclass
 
 from opentelemetry import trace
 from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type, Dict, Union
-from loguru import logger
 
 
 from text_generation_server.models import Model
@@ -18,6 +18,7 @@ from text_generation_server.models.types import (
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import StoppingCriteria
+from text_generation_server.utils.tokens_heterogeneous import HeterogeneousNextTokenChooser
 
 tracer = trace.get_tracer(__name__)
 
@@ -42,7 +43,7 @@ class VectorizedCausalLMBatch(Batch):
     token_offsets: List[Optional[int]]
 
     # Generation helpers
-    next_token_chooser: "VectorizedNextTokenChooser"
+    next_token_chooser: "HeterogeneousNextTokenChooser"
     stopping_criterias: List[StoppingCriteria]
 
     # Metadata used for padding
@@ -81,7 +82,7 @@ class VectorizedCausalLMBatch(Batch):
         stopping_criterias = [StoppingCriteria.from_pb(r.stopping_parameters, tokenizer) for r in pb.requests]
         max_new_tokens=(stopping_criteria.max_new_tokens for stopping_criteria in stopping_criterias)
 
-        next_token_chooser=VectorizedNextTokenChooser.from_pb([r.parameters for r in pb.requests], device)
+        next_token_chooser= HeterogeneousNextTokenChooser.from_pb([r.parameters for r in pb.requests], device)
 
         tokenized_inputs = tokenizer(
             inputs,
@@ -146,7 +147,9 @@ class VectorizedCausalLMBatch(Batch):
         self.input_lengths=[self.input_lengths[i] for i in keep_indices]
         self.offsets = [self.offsets[i] for i in keep_indices]
         self.token_offsets = [self.token_offsets[i] for i in keep_indices]
-        self.next_token_chooser=self.next_token_chooser.filter(keep_indices)
+        self.next_token_chooser=HeterogeneousNextTokenChooser.from_pb([r.parameters for r in self.requests], self.input_ids.device)
 
         self.stopping_criterias = [self.stopping_criterias[i] for i in keep_indices]
         remaining_decode_tokens=[stopping_criteria.max_new_tokens - stopping_criteria.current_tokens for stopping_criteria in self.stopping_criterias]
@@ -194,7 +197,7 @@ class VectorizedCausalLMBatch(Batch):
         input_lengths = [length for batch in batches for length in batch.input_lengths]
         offsets = [offset for batch in batches for offset in batch.offsets]
         token_offsets = [token_offset for batch in batches for token_offset in batch.token_offsets]
-        next_token_chooser=VectorizedNextTokenChooser.concatenate([batch.next_token_chooser for batch in batches])
+        next_token_chooser=HeterogeneousNextTokenChooser.from_pb([r.parameters for r in requests], batches[0].input_ids.device)
         stopping_criterias = [stopping_criteria for batch in batches for stopping_criteria in batch.stopping_criterias]
 
         requests_idx_mapping = {k: v + start_index for batch, start_index in zip(batches, start_indices) for k, v in batch.requests_idx_mapping.items()}
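
The filter and concatenate hunks above stop slicing per-request chooser state and instead rebuild the chooser from the kept requests' protobuf parameters. A minimal sketch of that pattern (illustrative, not part of the commit), assuming the server package and its generated protobuf module are importable; the make_params helper and its default values are made up for the example:

import torch

from text_generation_server.pb import generate_pb2
from text_generation_server.utils.tokens_heterogeneous import HeterogeneousNextTokenChooser

device = torch.device("cpu")

def make_params(**overrides):
    # Neutral defaults (illustrative only); in the real server the router fills these in.
    defaults = dict(watermark=False, temperature=1.0, repetition_penalty=1.0,
                    top_k=0, top_p=1.0, typical_p=1.0, do_sample=False, seed=0)
    defaults.update(overrides)
    return generate_pb2.NextTokenChooserParameters(**defaults)

# One parameter set per request in the batch.
params = [
    make_params(),
    make_params(temperature=0.7, top_p=0.9, do_sample=True),
    make_params(top_k=50, do_sample=True),
]
chooser = HeterogeneousNextTokenChooser.from_pb(params, device)

# Filtering no longer calls chooser.filter(): the chooser is rebuilt from the
# surviving requests' parameters, as the hunks above do.
keep_indices = [0, 2]
filtered = HeterogeneousNextTokenChooser.from_pb([params[i] for i in keep_indices], device)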
@@ -290,218 +293,6 @@ class VectorizedCausalLMBatch(Batch):
         return len(self.requests)
 
 
-class VectorizedNextTokenChooser:
-    def __init__(
-        self,
-        batch_size:int,
-        watermark:Optional[List[Optional[bool]]]=None,
-        temperature:Optional[List[Optional[float]]]=None,
-        repetition_penalty:Optional[List[Optional[float]]]=None,
-        top_k:Optional[List[Optional[int]]]=None,
-        top_p:Optional[List[Optional[float]]]=None,
-        typical_p:Optional[List[Optional[float]]]=None,
-        do_sample:Optional[List[Optional[bool]]]=None,
-        seeds:Optional[List[Optional[int]]]=None,
-        device:torch.device="cpu",
-    ):
-        self.batch_size=batch_size
-        self.filter_value = -math.inf
-        self.device=device
-
-        # TODO: Seeds are ignored
-        self.seeds=self._standardize(seeds, 0)
-        self.do_sample=self._standardize(do_sample, False)
-
-        self.watermark=self._standardize(watermark, False)
-        if any(self.watermark):
-            raise NotImplementedError("Watermarking not implemented")
-
-        self.repetition_penalty=self._standardize(repetition_penalty, 1.0)
-        if any([x!=1.0 for x in self.repetition_penalty]):
-            self.repetition_penalty_t=torch.tensor(self.repetition_penalty, dtype=torch.float32, device=self.device).unsqueeze(1)
-        else:
-            self.repetition_penalty_t=None
-
-        self.temperature=self._standardize(temperature, 1.0)
-        if any([x!=1.0 for x in self.temperature]):
-            self.do_sample=[sample or x!=1.0 for x, sample in zip(self.temperature, self.do_sample)]
-            self.temperature_t=torch.tensor(self.temperature, dtype=torch.float32, device=self.device).unsqueeze(1)
-        else:
-            self.temperature_t=None
-
-        self.top_k=self._standardize(top_k, 0)
-        n_top_k=sum([x!=0 for x in top_k])
-        if n_top_k>0:
-            self.do_sample=[sample or x!=0 for x, sample in zip(self.top_k, self.do_sample)]
-            self.max_top_k=max(self.top_k)
-            self.top_k_t=torch.tensor([max(x-1,0) for x in self.top_k], dtype=torch.int64, device=self.device).unsqueeze(1)
-            if n_top_k<self.batch_size:
-                self.top_k_mask=torch.tensor([x==0 for x in self.top_k], dtype=torch.bool, device=self.device)
-            else:
-                self.top_k_mask=None
-        else:
-            self.max_top_k=None
-            self.top_k_t=None
-            self.top_k_mask=None
-
-        self.top_p=self._standardize(top_p, 1.0)
-        if any([x<1.0 for x in self.top_p]):
-            self.do_sample=[sample or x<1.0 for x, sample in zip(temperature, self.top_p)]
-            self.top_p_t=torch.tensor([1.0-x for x in self.top_p], dtype=torch.float32, device=self.device).unsqueeze(1)
-        else:
-            self.top_p_t=None
-
-        self.typical_p=self._standardize(typical_p, 1.0)
-        if any([x<1.0 for x in self.typical_p]):
-            self.do_sample=[sample or x<1.0 for x, sample in zip(self.typical_p, self.do_sample)]
-            self.typical_p_t=torch.tensor(self.typical_p, dtype=torch.float32, device=self.device).unsqueeze(1)
-        else:
-            self.typical_p_t=None
-
-        self.num_do_sample=sum(self.do_sample)
-        if 0<self.num_do_sample<self.batch_size:
-            # Mixed greedy and probabilistic sampling. Compute both and pick the right one.
-            self.do_sample_t=torch.tensor(self.do_sample, dtype=torch.bool, device=self.device)
-        else:
-            self.do_sample_t=None
-
-    def _standardize(self, values, default):
-        if isinstance(values, list):
-            values=values.copy()
-        else:
-            values=[values]*self.batch_size
-        assert len(values)==self.batch_size
-        for i, v in enumerate(values):
-            if v is None:
-                values[i]=default
-        return values
-
-    def __call__(self, input_ids:torch.Tensor, scores:torch.Tensor, return_logprobs:bool):
-        last_token_scores=scores[:, -1, :]
-
-        if self.repetition_penalty_t is not None:
-            score = torch.gather(last_token_scores, 1, input_ids)
-            # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
-            score = torch.where(score < 0, score * self.repetition_penalty_t, score / self.repetition_penalty_t)
-            last_token_scores.scatter_(1, input_ids, score)
-
-        if self.temperature_t is not None:
-            last_token_scores.div_(self.temperature_t)
-
-        if self.top_k_t is not None:
-            if last_token_scores.size(-1)>self.max_top_k: # Safety check
-                max_top_k=last_token_scores.size(-1)
-                top_k=torch.clamp_max(self.top_k_t,max_top_k) # Run only if needed.
-            else:
-                max_top_k=self.max_top_k
-                top_k=self.top_k_t
-            kth_scores=torch.gather(torch.topk(last_token_scores, max_top_k)[0], 1, top_k)
-            if self.top_k_mask is not None:
-                kth_scores.masked_fill_(self.top_k_mask, self.filter_value)
-            # Remove all tokens with a probability less than the last token of the top-k
-            indices_to_remove = last_token_scores < kth_scores
-            last_token_scores = last_token_scores.masked_fill(indices_to_remove, self.filter_value)
-
-        if self.top_p_t is not None:
-            # TODO: Merge wit top_k
-            sorted_logits, sorted_indices = torch.sort(last_token_scores, descending=True)
-            cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
-            # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
-            sorted_indices_to_remove = cumulative_probs <= self.top_p_t
-            # scatter sorted tensors to original indexing
-            indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
-            last_token_scores = last_token_scores.masked_fill(indices_to_remove, self.filter_value)
-
-        if self.typical_p_t is not None:
-            # calculate entropy
-            normalized = torch.nn.functional.log_softmax(last_token_scores, dim=-1)
-            p = torch.exp(normalized)
-            ent = -(normalized * p).nansum(-1, keepdim=True)
-            # shift and sort
-            shifted_scores = torch.abs((-normalized) - ent)
-            sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
-            sorted_logits = last_token_scores.gather(-1, sorted_indices)
-            cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
-            # Remove tokens with cumulative mass above the threshold
-            last_ind = (cumulative_probs < self.typical_p_t).sum(dim=1)
-            last_ind[last_ind < 0] = 0
-            sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
-            indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
-            last_token_scores = last_token_scores.masked_fill(indices_to_remove, self.filter_value)
-
-        if self.num_do_sample:
-            probs = torch.nn.functional.softmax(last_token_scores, -1)
-            next_token_ids = torch.multinomial(probs, num_samples=1)
-            if self.do_sample_t is not None:
-                next_token_ids=torch.where(self.do_sample_t, next_token_ids, torch.argmax(last_token_scores, dim=-1))
-        else:
-            next_token_ids = torch.argmax(last_token_scores, dim=-1)
-
-        if return_logprobs:
-            # Compute logprobs
-            if scores.size(1)==1:
-                scores=last_token_scores.unsqueeze(1)
-            else:
-                # TODO: Post-process all the tokens?
-                scores[:, -1, :]=last_token_scores
-
-            logprobs = torch.log_softmax(scores, dim=-1)
-        else:
-            logprobs=None
-
-        return next_token_ids, logprobs
-
-    @classmethod
-    def from_pb(
-        cls,
-        pb: List[generate_pb2.NextTokenChooserParameters],
-        device: torch.device,
-    ) -> "VectorizedNextTokenChooser":
-        # TODO: Seeds are ignored
-        return VectorizedNextTokenChooser(
-            batch_size=len(pb),
-            watermark=[pb_.watermark for pb_ in pb],
-            temperature=[pb_.temperature for pb_ in pb],
-            repetition_penalty=[pb_.repetition_penalty for pb_ in pb],
-            top_k=[pb_.top_k for pb_ in pb],
-            top_p=[pb_.top_p for pb_ in pb],
-            typical_p=[pb_.typical_p for pb_ in pb],
-            do_sample=[pb_.do_sample for pb_ in pb],
-            seeds=[pb_.seed for pb_ in pb],
-            device=device,
-        )
-
-    def filter(self, keep_indices: List[int]) -> "VectorizedNextTokenChooser":
-        return VectorizedNextTokenChooser(
-            batch_size=len(keep_indices),
-            watermark=[self.watermark[i] for i in keep_indices],
-            temperature=[self.temperature[i] for i in keep_indices],
-            repetition_penalty=[self.repetition_penalty[i] for i in keep_indices],
-            top_k=[self.top_k[i] for i in keep_indices],
-            top_p=[self.top_p[i] for i in keep_indices],
-            typical_p=[self.typical_p[i] for i in keep_indices],
-            do_sample=[self.do_sample[i] for i in keep_indices],
-            seeds=[self.seeds[i] for i in keep_indices],
-            device=self.device,
-        )
-
-    @classmethod
-    def concatenate(cls, next_token_choosers: List["VectorizedNextTokenChooser"]) -> "VectorizedNextTokenChooser":
-        return cls(
-            batch_size=sum(next_token_chooser.batch_size for next_token_chooser in next_token_choosers),
-            watermark=[x for next_token_chooser in next_token_choosers for x in next_token_chooser.watermark],
-            temperature=[x for next_token_chooser in next_token_choosers for x in next_token_chooser.temperature],
-            repetition_penalty=[x for next_token_chooser in next_token_choosers for x in next_token_chooser.repetition_penalty],
-            top_k=[x for next_token_chooser in next_token_choosers for x in next_token_chooser.top_k],
-            top_p=[x for next_token_chooser in next_token_choosers for x in next_token_chooser.top_p],
-            typical_p=[x for next_token_chooser in next_token_choosers for x in next_token_chooser.typical_p],
-            do_sample=[x for next_token_chooser in next_token_choosers for x in next_token_chooser.do_sample],
-            seeds=[x for next_token_chooser in next_token_choosers for x in next_token_chooser.seeds],
-            device=next_token_choosers[0].device,
-        )
-
-
 class VectorizedCausalLM(Model):
     def __init__(
         self,
@@ -633,7 +424,7 @@ class VectorizedCausalLM(Model):
 
             generation = Generation(
                 batch.requests[i].id,
-                prefill_tokens[i] if batch.details and query_length>1 else None, # TODO: Prefill tokens
+                prefill_tokens[i] if batch.details and query_length>1 else None,
                 next_token_id,
                 token_logprobs[i] if batch.details else 0.0,
                 next_token_text,
@@ -32,7 +32,7 @@ class Sampling:
 
 class Greedy:
     def __call__(self, logits):
-        return logits.argmax()
+        return logits.argmax(dim=-1)
 
 
 class NextTokenChooser:
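
The Greedy change above matters for batched logits: without a dim argument, argmax flattens the tensor and returns a single flat index instead of one token id per row. A small illustration in plain PyTorch (not part of the commit):

import torch

logits = torch.tensor([[0.1, 2.0, 0.3],
                       [1.5, 0.2, 0.1]])

print(logits.argmax())        # tensor(1): one flat index over the whole tensor
print(logits.argmax(dim=-1))  # tensor([1, 0]): one token id per batch row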
server/text_generation_server/utils/tokens_heterogeneous.py (new file, 302 lines)
@@ -0,0 +1,302 @@
+import math
+from typing import Optional, List, Union
+
+from transformers import (
+    LogitsWarper,
+    LogitsProcessor,
+    LogitsProcessorList,
+)
+import torch
+
+from text_generation_server.pb import generate_pb2
+from text_generation_server.utils.tokens import Greedy, Sampling
+
+
+class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor):
+    r"""
+    [`LogitsProcessor`] enforcing an exponential penalty on repeated sequences.
+    This version allows for a separate value for each sample and runs inplace when possible.
+    It doesn't validate inputs.
+
+    Args:
+        repetition_penalty (`List[float]`):
+            The parameter for repetition penalty. 1.0 means no penalty. See [this
+            paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
+    """
+
+    def __init__(self, penalty: List[float], device:torch.device):
+        self.penalty = torch.tensor(penalty, dtype=torch.float32, device=device).unsqueeze(1)
+
+    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
+        score = torch.gather(scores, 1, input_ids)
+
+        # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
+        score = torch.where(score < 0, score * self.penalty, score / self.penalty)
+
+        scores.scatter_(1, input_ids, score)
+        return scores
+
+
+class HeterogeneousTemperatureLogitsWarper(LogitsWarper):
+    r"""
+    [`LogitsWarper`] for temperature (exponential scaling output probability distribution).
+    This version allows for a separate value for each sample and runs inplace when possible.
+    It doesn't validate inputs.
+
+    Args:
+        temperature (`float`):
+            The value used to module the logits distribution.
+    """
+
+    def __init__(self, temperature: List[float], device:torch.device):
+        self.temperature = torch.tensor(temperature, dtype=torch.float32, device=device).unsqueeze(1)
+
+    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
+        scores.div_(self.temperature)
+        return scores
+
+
+class HeterogeneousTopPLogitsWarper(LogitsWarper):
+    """
+    [`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off.
+    This version allows for a separate value for each sample and runs inplace when possible.
+    It doesn't validate inputs.
+
+    Args:
+        top_p (`float`):
+            If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
+            higher are kept for generation.
+        filter_value (`float`, *optional*, defaults to `-float("Inf")`):
+            All filtered values will be set to this float value.
+        min_tokens_to_keep (`int`, *optional*, defaults to 1):
+            Minimum number of tokens that cannot be filtered.
+    """
+
+    def __init__(self, top_p: List[float], device:torch.device, filter_value: float = -math.inf, min_tokens_to_keep: int = 1):
+        self.top_p = torch.tensor(top_p, dtype=torch.float32, device=device).unsqueeze(1)
+        self.filter_value = filter_value
+        self.min_tokens_to_keep = min_tokens_to_keep
+
+    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
+        sorted_logits, sorted_indices = torch.sort(scores, descending=False)
+        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
+
+        # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
+        sorted_indices_to_remove = cumulative_probs <= (1 - self.top_p)
+        if self.min_tokens_to_keep > 1:
+            # Keep at least min_tokens_to_keep
+            sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
+
+        # scatter sorted tensors to original indexing
+        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+        scores.masked_fill_(indices_to_remove, self.filter_value)
+        return scores
+
+
+class HeterogeneousTopKLogitsWarper(LogitsWarper):
+    r"""
+    [`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements.
+    This version allows for a separate value for each sample and runs inplace when possible.
+    It doesn't validate inputs.
+
+    Args:
+        top_k (`int`):
+            The number of highest probability vocabulary tokens to keep for top-k-filtering.
+        filter_value (`float`, *optional*, defaults to `-float("Inf")`):
+            All filtered values will be set to this float value.
+        min_tokens_to_keep (`int`, *optional*, defaults to 1):
+            Minimum number of tokens that cannot be filtered.
+    """
+
+    def __init__(self, top_k: List[int], device:torch.device, filter_value: float = -math.inf, min_tokens_to_keep: int = 1):
+        self.max_top_k = max(top_k)
+        self.top_k = torch.tensor([max(x - 1, min_tokens_to_keep-1) for x in top_k], dtype=torch.int64,device=device).unsqueeze(1)
+        zeros=[x == 0 for x in top_k]
+        if any(zeros):
+            self.top_k_mask = torch.tensor(zeros, dtype=torch.bool, device=device)
+        else:
+            self.top_k_mask = None
+        self.filter_value = filter_value
+
+    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
+        if scores.size(-1)>self.max_top_k: # Safety check
+            max_top_k=scores.size(-1)
+            top_k=torch.clamp_max(self.top_k,max_top_k) # Run only if needed.
+        else:
+            max_top_k=self.max_top_k
+            top_k=self.top_k
+        kth_scores=torch.gather(torch.topk(scores, max_top_k)[0], 1, top_k)
+        if self.top_k_mask is not None:
+            kth_scores.masked_fill_(self.top_k_mask, self.filter_value)
+        # Remove all tokens with a probability less than the last token of the top-k
+        indices_to_remove = scores < kth_scores
+        scores.masked_fill_(indices_to_remove, self.filter_value)
+        return scores
+
+
+class HeterogeneousTypicalLogitsWarper(LogitsWarper):
+    r"""
+    [`LogitsWarper`] that performs typical decoding. See [Typical Decoding for Natural Language
+    Generation](https://arxiv.org/abs/2202.00666) for more information.
+    This version allows for a separate value for each sample and runs inplace when possible.
+    It doesn't validate inputs.
+
+    Args:
+        mass (`float`):
+            Value of typical_p between 0 and 1 inclusive, defaults to 0.9.
+        filter_value (`float`, *optional*, defaults to `-float("Inf")`):
+            All filtered values will be set to this float value.
+        min_tokens_to_keep (`int`, *optional*, defaults to 1):
+            Minimum number of tokens that cannot be filtered.
+    """
+
+    def __init__(self, mass: List[float], device:torch.device, filter_value: float = -math.inf, min_tokens_to_keep: int = 1):
+        self.filter_value = filter_value
+        self.mass = torch.tensor(mass, dtype=torch.float32, device=device).unsqueeze(1)
+        self.min_tokens_to_keep = min_tokens_to_keep
+
+    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
+        # calculate entropy
+        normalized = torch.nn.functional.log_softmax(scores, dim=-1)
+        p = torch.exp(normalized)
+        ent = -(normalized * p).nansum(-1, keepdim=True)
+
+        # shift and sort
+        shifted_scores = torch.abs((-normalized) - ent)
+        sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
+        sorted_logits = scores.gather(-1, sorted_indices)
+        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
+
+        # Remove tokens with cumulative mass above the threshold
+        last_ind = (cumulative_probs < self.mass).sum(dim=1)
+        last_ind[last_ind < 0] = 0
+        sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
+        if self.min_tokens_to_keep > 1:
+            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
+            sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
+        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+
+        scores = scores.masked_fill_(indices_to_remove, self.filter_value)
+        return scores
+
+
+class HeterogeneousSampling:
+    r"""
+    Mixed greedy and probabilistic sampling. Compute both and pick the right one for each sample.
+    """
+    def __init__(self, do_sample:List[bool], seeds: List[int], device: torch.device):
+        self.seeds = seeds
+        self.greedy=Greedy()
+        # TODO: Most seeds are ignored
+        self.sampling=Sampling(seeds[0], device)
+        self.do_sample=torch.tensor(do_sample, dtype=torch.bool, device=device)
+
+    def __call__(self, logits):
+        return torch.where(self.do_sample, self.sampling(logits), self.greedy(logits))
+
+
+class HeterogeneousNextTokenChooser:
+    def __init__(
+        self,
+        *,
+        batch_size:int,
+        device:torch.device,
+        watermark:Optional[Union[bool,List[Optional[bool]]]]=None,
+        temperature:Optional[Union[float,List[Optional[float]]]]=None,
+        repetition_penalty:Optional[Union[float,List[Optional[float]]]]=None,
+        top_k:Optional[Union[int,List[Optional[int]]]]=None,
+        top_p:Optional[Union[float,List[Optional[float]]]]=None,
+        typical_p:Optional[Union[float,List[Optional[float]]]]=None,
+        do_sample:Optional[Union[bool,List[Optional[bool]]]]=None,
+        seeds:Optional[Union[int,List[Optional[int]]]]=None,
+    ):
+        # TODO: Most seeds are ignored
+        seeds=self._standardize(seeds, batch_size, 0)
+        do_sample=self._standardize(do_sample, batch_size, False)
+
+        warpers = LogitsProcessorList()
+
+        watermark=self._standardize(watermark, batch_size, False)
+        if any(watermark):
+            raise NotImplementedError("Watermarking not implemented")
+
+        repetition_penalty=self._standardize(repetition_penalty, batch_size, 1.0)
+        if any([x!=1.0 for x in repetition_penalty]):
+            warpers.append(HeterogeneousRepetitionPenaltyLogitsProcessor(repetition_penalty, device))
+
+        temperature=self._standardize(temperature, batch_size, 1.0)
+        if any([x!=1.0 for x in temperature]):
+            do_sample=[sample or x!=1.0 for x, sample in zip(temperature, do_sample)]
+            warpers.append(HeterogeneousTemperatureLogitsWarper(temperature, device))
+
+        top_k=self._standardize(top_k, batch_size, 0)
+        n_top_k=sum([x!=0 for x in top_k])
+        if n_top_k>0:
+            do_sample=[sample or x!=0 for x, sample in zip(top_k, do_sample)]
+            warpers.append(HeterogeneousTopKLogitsWarper(top_k, device))
+
+        top_p=self._standardize(top_p, batch_size, 1.0)
+        if any([x<1.0 for x in top_p]):
+            do_sample=[sample or x<1.0 for x, sample in zip(top_p, do_sample)]
+            warpers.append(HeterogeneousTopPLogitsWarper(top_p, device))
+
+        typical_p=self._standardize(typical_p, batch_size, 1.0)
+        if any([x<1.0 for x in typical_p]):
+            do_sample=[sample or x<1.0 for x, sample in zip(typical_p, do_sample)]
+            warpers.append(HeterogeneousTypicalLogitsWarper(typical_p, device))
+
+        self.warpers=warpers
+
+        num_do_sample=sum(do_sample)
+        if num_do_sample==0:
+            self.choice=Greedy()
+        elif num_do_sample<batch_size:
+            self.choice=HeterogeneousSampling(do_sample, seeds, device)
+        else:
+            # TODO: Most seeds are ignored
+            self.choice=Sampling(seeds[0], device)
+
+    @staticmethod
+    def _standardize(values, batch_size, default):
+        if isinstance(values, list):
+            values=values.copy()
+        else:
+            values=[values]*batch_size
+        assert len(values)==batch_size
+        for i, v in enumerate(values):
+            if v is None:
+                values[i]=default
+        return values
+
+    def __call__(self, input_ids:torch.Tensor, scores:torch.Tensor, return_logprobs:bool):
+        last_token_scores=self.warpers(input_ids, scores[:, -1, :])
+        next_token_ids=self.choice(last_token_scores)
+
+        if return_logprobs:
+            # Compute logprobs
+            if scores.size(1)==1:
+                scores=last_token_scores.unsqueeze(1)
+            else:
+                # TODO: Post-process all the tokens?
+                scores[:, -1, :]=last_token_scores
+            logprobs = torch.log_softmax(scores, dim=-1)
+        else:
+            logprobs=None
+
+        return next_token_ids, logprobs
+
+    @classmethod
+    def from_pb(
+        cls,
+        pb: List[generate_pb2.NextTokenChooserParameters],
+        device: torch.device,
+    ) -> "HeterogeneousNextTokenChooser":
+        # TODO: Seeds are ignored
+        return HeterogeneousNextTokenChooser(
+            batch_size=len(pb),
+            watermark=[pb_.watermark for pb_ in pb],
+            temperature=[pb_.temperature for pb_ in pb],
+            repetition_penalty=[pb_.repetition_penalty for pb_ in pb],
+            top_k=[pb_.top_k for pb_ in pb],
+            top_p=[pb_.top_p for pb_ in pb],
+            typical_p=[pb_.typical_p for pb_ in pb],
+            do_sample=[pb_.do_sample for pb_ in pb],
+            seeds=[pb_.seed for pb_ in pb],
+            device=device,
+        )
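
A minimal sketch of how the new HeterogeneousNextTokenChooser might be exercised, assuming the server package and torch are importable; the batch size, vocabulary size, and parameter values are illustrative only and not taken from the commit:

import torch

from text_generation_server.utils.tokens_heterogeneous import HeterogeneousNextTokenChooser

batch_size, seq_len, vocab_size = 3, 5, 32
device = torch.device("cpu")

# One set of sampling parameters per request: request 0 is pure greedy,
# requests 1 and 2 use temperature / top-p / top-k sampling.
chooser = HeterogeneousNextTokenChooser(
    batch_size=batch_size,
    device=device,
    temperature=[1.0, 0.7, 1.0],
    repetition_penalty=[1.0, 1.0, 1.2],
    top_k=[0, 0, 5],
    top_p=[1.0, 0.9, 1.0],
    typical_p=[1.0, 1.0, 1.0],
    do_sample=[False, True, True],
    seeds=[0, 1, 2],
)

input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
scores = torch.randn(batch_size, seq_len, vocab_size, device=device)

# The warpers adjust the last-token logits per request, then the chooser picks
# greedy or sampled next tokens and (optionally) returns log-probabilities.
next_token_ids, logprobs = chooser(input_ids, scores, return_logprobs=True)
print(next_token_ids.shape, logprobs.shape)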
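
For intuition about what the per-sample warpers compute, the top-p cutoff used by HeterogeneousTopPLogitsWarper can be reproduced in a few lines of plain PyTorch; the probabilities and thresholds below are made up for illustration:

import torch

# Two rows of token probabilities (written as logits via log) and one top-p
# threshold per row: row 0 keeps roughly the top 60% of mass, row 1 keeps everything.
scores = torch.log(torch.tensor([[0.5, 0.3, 0.1, 0.1],
                                 [0.4, 0.3, 0.2, 0.1]]))
top_p = torch.tensor([[0.6], [1.0]])

sorted_logits, sorted_indices = torch.sort(scores, descending=False)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
# Drop tokens whose cumulative mass (counted from the smallest up) does not exceed 1 - top_p.
sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
print(indices_to_remove)
# tensor([[False, False,  True,  True],
#         [False, False, False, False]])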