import re
from typing import List, Optional, Tuple, Set, Union

import math
import torch
from text_generation_server.pb import generate_pb2
from text_generation_server.pb.generate_pb2 import FinishReason, GrammarType
from text_generation_server.utils.logits_process import (
    FrequencyPenaltyLogitsProcessor,
    GrammarLogitProcessor,
    HeterogeneousProcessorWrapper,
    HeterogeneousRepetitionPenaltyLogitsProcessor,
    HeterogeneousFrequencyPenaltyLogitsProcessor,
    HeterogeneousTemperatureLogitsWarper,
    HeterogeneousTopKLogitsWarper,
    HeterogeneousTopPLogitsWarper,
    HeterogeneousTypicalLogitsWarper,
    HeterogeneousGrammarLogitProcessor,
    static_warper,
)
from text_generation_server.utils.watermark import WatermarkLogitsProcessor
from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor


class NextTokenChooser:
    def __init__(
        self,
        watermark: bool = False,
        temperature: float = 1.0,
        repetition_penalty: float = 1.0,
        frequency_penalty: float = 0.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        typical_p: Optional[float] = None,
        do_sample: bool = False,
        seed: int = 0,
        device: str = "cpu",
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        grammar: str = "",
        grammar_type: GrammarType = GrammarType.GRAMMAR_TYPE_NONE,
        fsm_grammar_state: int = 0,
    ):
        self.watermark_processor = (
            WatermarkLogitsProcessor(device=device) if watermark else None
        )
        self.repetition_processor = (
            RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)
            if repetition_penalty and repetition_penalty != 1.0
            else None
        )
        self.frequency_processor = (
            FrequencyPenaltyLogitsProcessor(penalty=frequency_penalty)
            if frequency_penalty and frequency_penalty != 0.0
            else None
        )
        self.grammar_processor = (
            GrammarLogitProcessor(tokenizer, device, grammar, grammar_type)
            if grammar != ""
            else None
        )
        self.tokenizer = tokenizer

        has_warpers = (
            (temperature is not None and temperature != 1.0)
            or (top_k is not None and top_k != 0)
            or (top_p is not None and top_p < 1.0)
            or (typical_p is not None and typical_p < 1.0)
        )
        if has_warpers:
            self.static_warper = static_warper(
                temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p
            )
        else:
            self.static_warper = None

        sampling = do_sample or has_warpers

        self.choice = Sampling(seed, device) if sampling else Greedy()
        self.fsm_grammar_state = fsm_grammar_state
        self.grammar = grammar

    def __call__(self, input_ids, scores):
        if self.watermark_processor is not None:
            scores = self.watermark_processor(input_ids, scores)
        if self.repetition_processor is not None:
            scores = self.repetition_processor(input_ids, scores)
        if self.frequency_processor is not None:
            scores = self.frequency_processor(input_ids, scores)
        if self.grammar_processor is not None:
            scores = self.grammar_processor(scores, self.fsm_grammar_state)

        if self.static_warper is None:
            next_logprob = torch.log_softmax(scores, -1)
        else:
            scores, next_logprob = self.static_warper(scores)

        next_id = self.choice(scores[-1]).view(1, 1)

        return next_id, next_logprob

    def advance_grammar(self, next_id: int):
        if self.grammar_processor is not None:
            self.fsm_grammar_state = self.grammar_processor.advance(
                next_id, self.fsm_grammar_state
            )
        return self

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.NextTokenChooserParameters,
        device: torch.device,
        tokenizer: PreTrainedTokenizerBase,
    ) -> "NextTokenChooser":
        return NextTokenChooser(
            watermark=pb.watermark,
            temperature=pb.temperature,
            repetition_penalty=pb.repetition_penalty,
            frequency_penalty=pb.frequency_penalty,
            top_k=pb.top_k,
            top_p=pb.top_p,
            typical_p=pb.typical_p,
            do_sample=pb.do_sample,
            seed=pb.seed,
            device=device,
            tokenizer=tokenizer,
            grammar=pb.grammar,
            grammar_type=pb.grammar_type,
        )
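
# Illustrative usage sketch (not part of the upstream module): pick the next token
# for a single request with greedy decoding plus a repetition penalty. `input_ids`
# is the running sequence viewed as [1, seq_len] and `logits` is the last-position
# logits viewed as [1, vocab_size]; both are assumed to come from the caller's
# model forward pass.
#
#     chooser = NextTokenChooser(repetition_penalty=1.2, do_sample=False)
#     next_id, logprobs = chooser(input_ids, logits)
#     chooser.advance_grammar(next_id.item())  # no-op unless a grammar was supplied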


class StopSequenceCriteria:
    def __init__(self, stop_sequence: str):
        stop_sequence = re.escape(stop_sequence)
        self.regex = re.compile(f"{stop_sequence}$")

    def __call__(self, output: str) -> bool:
        if self.regex.findall(output):
            return True
        return False
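
# A small illustrative check (not part of the upstream module): the pattern above is
# anchored with "$", so a stop sequence only triggers once the accumulated output
# ends with it.
#
#     StopSequenceCriteria("###")("some text ###")  # True
#     StopSequenceCriteria("###")("### some text")  # False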


class StoppingCriteria:
    def __init__(
        self,
        eos_token_ids: Optional[Union[Set[int], int]],
        stop_sequence_criterias: List[StopSequenceCriteria],
        max_new_tokens: int = 20,
        ignore_eos_token: bool = False,
    ):
        if eos_token_ids is None:
            eos_token_ids = set()
        elif isinstance(eos_token_ids, int):
            eos_token_ids = set([eos_token_ids])
        elif isinstance(eos_token_ids, set):
            eos_token_ids = eos_token_ids
        else:
            raise RuntimeError(
                f"eos_token_ids is of invalid type {type(eos_token_ids)}, expected int, None or set[int]"
            )
        self.eos_token_ids = eos_token_ids
        self.stop_sequence_criterias = stop_sequence_criterias
        self.max_new_tokens = max_new_tokens
        self.current_tokens = 0
        self.current_output = ""
        self.ignore_eos_token = ignore_eos_token

    def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:
        self.current_tokens += 1
        if self.current_tokens >= self.max_new_tokens:
            return True, FinishReason.FINISH_REASON_LENGTH

        if isinstance(last_token, torch.Tensor):
            last_token = last_token.item()

        if not self.ignore_eos_token and last_token in self.eos_token_ids:
            return True, FinishReason.FINISH_REASON_EOS_TOKEN

        if self.stop_sequence_criterias:
            self.current_output += last_output
            # There is no need to keep an output that is too long
            if len(self.current_output) > 300:
                # Slice to -200 to avoid doing it all the time
                self.current_output = self.current_output[-200:]
            for stop_sequence_criteria in self.stop_sequence_criterias:
                if stop_sequence_criteria(self.current_output):
                    return True, FinishReason.FINISH_REASON_STOP_SEQUENCE

        return False, None

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.StoppingCriteriaParameters,
        tokenizer: PreTrainedTokenizerBase,
    ) -> "StoppingCriteria":
        stop_sequence_criterias = [
            StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
        ]
        # TODO Hack because eos_token_id cannot be what we want.
        eos_token_id = getattr(tokenizer, "_eos_token_ids", tokenizer.eos_token_id)
        return StoppingCriteria(
            eos_token_id,
            stop_sequence_criterias,
            pb.max_new_tokens,
            pb.ignore_eos_token,
        )
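
# Illustrative usage sketch (not part of the upstream module): drive a decode loop
# with a StoppingCriteria limited to 32 new tokens and a "###" stop sequence.
# `generate_one_token` is a placeholder for the caller's model/tokenizer step.
#
#     criteria = StoppingCriteria(
#         eos_token_ids={2},
#         stop_sequence_criterias=[StopSequenceCriteria("###")],
#         max_new_tokens=32,
#     )
#     while True:
#         token_id, token_text = generate_one_token()
#         stop, reason = criteria(token_id, token_text)
#         if stop:
#             break  # reason is a FinishReason (LENGTH, EOS_TOKEN or STOP_SEQUENCE)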


def create_n_gram_speculation(
    input_ids: torch.Tensor,
    next_ids: torch.Tensor,
    accepted_ids: torch.Tensor,
    speculate: int,
    verbose: bool,
):
    # Very trivial approach, find first match in the string.
    # This is much less refined than actual n-gram but seems to work
    # relatively OK in grounded mode and is by far much faster with
    # much less worst case complexity as everything happens on device.
    B = accepted_ids.shape[0]
    device = input_ids.device
    seeds = next_ids[accepted_ids.cumsum(dim=-1) - 1]
    indices = (input_ids == seeds.unsqueeze(-1)).max(dim=1).indices + 1
    all_indices = indices.unsqueeze(-1).expand(B, speculate) + torch.arange(
        speculate, device=device
    )
    all_indices = torch.clamp(all_indices, max=input_ids.shape[1] - 1)

    speculative_ids = input_ids.gather(dim=-1, index=all_indices)
    return speculative_ids
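
# Toy example of the n-gram fallback above (illustrative only): per the "find first
# match" comment, with input_ids = [[5, 8, 9, 5, 3]], a single accepted next_id of 5
# and speculate=2, the match at position 0 is used, so the proposal is the two
# tokens that followed it historically: [8, 9].
#
#     input_ids = torch.tensor([[5, 8, 9, 5, 3]])
#     next_ids = torch.tensor([5])
#     accepted_ids = torch.tensor([1])
#     create_n_gram_speculation(input_ids, next_ids, accepted_ids, 2, False)
#     # -> tensor([[8, 9]])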


class HeterogeneousNextTokenChooser:
    def __init__(
        self,
        dtype: torch.dtype,
        device: torch.device,
        watermark: List[bool],
        temperature: List[float],
        repetition_penalty: List[float],
        frequency_penalty: List[float],
        top_k: List[int],
        top_p: List[float],
        typical_p: List[float],
        do_sample: List[bool],
        seeds: List[int],
        tokenizer: PreTrainedTokenizerBase,
        grammars: List[str],
        grammar_types: List[int],
        fsm_grammar_states: List[int],
    ):
        warpers = []

        self.watermark_processor = (
            HeterogeneousProcessorWrapper(
                {
                    i: WatermarkLogitsProcessor(device=device)
                    for i, do_watermark in enumerate(watermark)
                    if do_watermark
                }
            )
            if any(watermark)
            else None
        )

        self.repetition_processor = (
            HeterogeneousRepetitionPenaltyLogitsProcessor(
                repetition_penalty, dtype, device
            )
            if any([x != 1.0 for x in repetition_penalty])
            else None
        )

        self.frequency_processor = (
            HeterogeneousFrequencyPenaltyLogitsProcessor(
                frequency_penalty, dtype, device
            )
            if any([x != 0.0 for x in frequency_penalty])
            else None
        )

        self.grammar_processor = (
            HeterogeneousGrammarLogitProcessor(
                tokenizer, device, grammars, grammar_types
            )
            if any([grammar != "" for grammar in grammars])
            else None
        )

        if any(x != 1.0 for x in temperature):
            do_sample = [
                sample or x != 1.0 for x, sample in zip(temperature, do_sample)
            ]
            warpers.append(
                HeterogeneousTemperatureLogitsWarper(temperature, dtype, device)
            )

        if any(x != 0 for x in top_k):
            do_sample = [sample or x != 0 for x, sample in zip(top_k, do_sample)]
            warpers.append(HeterogeneousTopKLogitsWarper(top_k, device))

        if any(x < 1.0 for x in top_p):
            do_sample = [sample or x < 1.0 for x, sample in zip(top_p, do_sample)]
            warpers.append(HeterogeneousTopPLogitsWarper(top_p, dtype, device))

        if any(x < 1.0 for x in typical_p):
            do_sample = [sample or x < 1.0 for x, sample in zip(typical_p, do_sample)]
            warpers.append(HeterogeneousTypicalLogitsWarper(typical_p, dtype, device))

        self.warpers = warpers

        if any(do_sample):
            self.choice = HeterogeneousSampling(do_sample, seeds, device)
        else:
            self.choice = Greedy()

        self.seeds = seeds
        self.do_sample = do_sample
        self.dtype = dtype
        self.device = device
        self.tokenizer = tokenizer
        self.fsm_grammar_states = fsm_grammar_states
        self.grammars = grammars
        self.grammar_types = grammar_types
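
    # The __call__ below handles both the plain path (one logit row per sequence) and
    # the speculative path: when `speculated_ids` is provided, the flattened scores of
    # shape [B * (speculate + 1), vocab] are re-scored column by column, previously
    # speculated tokens are kept only while they match the freshly chosen tokens, and
    # the number of kept tokens plus the one guaranteed new token is reported per
    # sequence in `accepted_ids`.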
    def __call__(
        self,
        input_ids: torch.Tensor,
        scores: torch.Tensor,
        speculate: int,
        speculated_ids: Optional[torch.Tensor] = None,
        speculative_scores: Optional[torch.Tensor] = None,
        verbose=False,
    ):
        if speculated_ids is not None:
            B = scores.shape[0] // (speculated_ids.shape[1] + 1)
            S = speculated_ids.shape[1] + 1
            scores = scores.view(B, S, -1)
        else:
            B = scores.shape[0]
            S = 1
            scores = scores.view(B, S, -1)

        next_ids = torch.zeros((B, S), device=scores.device, dtype=torch.long)

        for j in range(S):
            _scores = scores[:, j]
            if self.watermark_processor is not None:
                _scores = self.watermark_processor(input_ids, _scores)
            if self.repetition_processor is not None:
                _scores = self.repetition_processor(input_ids, _scores)
            if self.frequency_processor is not None:
                _scores = self.frequency_processor(input_ids, _scores)
            if self.grammar_processor is not None:
                _scores = self.grammar_processor(_scores, self.fsm_grammar_states)
            for warper in self.warpers:
                _scores = warper(input_ids, _scores)
            _next_ids = self.choice(_scores)
            scores[:, j] = _scores
            next_ids[:, j] = _next_ids
        next_ids = next_ids.view(B * S)
        allscores = scores.view(B * S, -1)
        alllogprobs = torch.log_softmax(allscores, -1)

        if speculated_ids is not None:
            accepted_ids = []
            B = next_ids.shape[0] // (speculated_ids.shape[1] + 1)
            S = speculated_ids.shape[1] + 1
            indices = []
            for i in range(B):
                _next_ids = next_ids[i * S : (i + 1) * S]
                _speculated_ids = speculated_ids[i]
                validate_speculative = _next_ids[:-1] == _speculated_ids
                index = i * S
                accepted = 1
                # First is always valid
                indices.append(index)
                for valid in validate_speculative.tolist():
                    if valid:
                        index += 1
                        accepted += 1
                        indices.append(index)
                    else:
                        break
                accepted_ids.append(accepted)

            accepted_ids = torch.tensor(
                accepted_ids, device=input_ids.device, dtype=input_ids.dtype
            )
            next_ids = next_ids[indices]
            logprobs = alllogprobs[indices]
            indices = torch.arange(B, device=input_ids.device) * S
            if speculative_scores is not None:
                speculative_scores = speculative_scores[indices + accepted_ids - 1]
        else:
            accepted_ids = torch.ones_like(next_ids)
            logprobs = alllogprobs

        next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)

        if speculate > 0:
            if speculative_scores is not None:
                # Medusa provided some scores
                speculative_ids = Greedy()(speculative_scores)
            else:
                # n-gram
                speculative_ids = create_n_gram_speculation(
                    input_ids, next_ids, accepted_ids, speculate, verbose
                )
        else:
            speculative_ids = None

        return next_ids, next_logprobs, alllogprobs, accepted_ids, speculative_ids

    def advance_grammar(self, next_ids: List[int]):
        if self.grammar_processor is not None:
            other_new_states = self.grammar_processor.advance_batch(
                next_ids, self.fsm_grammar_states
            )
            self.fsm_grammar_states = other_new_states
        return self

    def advance_grammar_single(self, grammar_state_index: int, next_id: int):
        if self.grammar_processor is not None:
            self.fsm_grammar_states[grammar_state_index] = (
                self.grammar_processor.advance_at_index(
                    next_id,
                    self.fsm_grammar_states[grammar_state_index],
                    grammar_state_index,
                )
            )
        return self

    def filter(self, indices):
        if self.watermark_processor is not None:
            self.watermark_processor = self.watermark_processor.filter(indices)

        if self.repetition_processor is not None:
            self.repetition_processor = self.repetition_processor.filter(indices)

        if self.frequency_processor is not None:
            self.frequency_processor = self.frequency_processor.filter(indices)

        if self.grammar_processor is not None:
            self.grammar_processor = self.grammar_processor.filter(indices)

        filtered_warpers = []
        for warper in self.warpers:
            filtered_warper = warper.filter(indices)
            if filtered_warper is not None:
                filtered_warpers.append(filtered_warper)
        self.warpers = filtered_warpers

        self.seeds = [self.seeds[i] for i in indices]
        self.do_sample = [self.do_sample[i] for i in indices]

        new_grammars = []
        new_fsm_grammar_states = []
        new_grammar_types = []
        for i in indices:
            new_grammars.append(self.grammars[i])
            new_fsm_grammar_states.append(self.fsm_grammar_states[i])
            new_grammar_types.append(self.grammar_types[i])

        self.grammars = new_grammars
        self.fsm_grammar_states = new_fsm_grammar_states
        self.grammar_types = new_grammar_types

        if any(self.do_sample):
            self.choice.filter(indices)
        else:
            self.choice = Greedy()

        return self

    @classmethod
    def from_pb(
        cls,
        pb: List[generate_pb2.NextTokenChooserParameters],
        dtype: torch.dtype,
        device: torch.device,
        tokenizer: PreTrainedTokenizerBase,
        fsm_grammar_states: Optional[List[int]] = None,
    ) -> "HeterogeneousNextTokenChooser":
        return HeterogeneousNextTokenChooser(
            watermark=[pb_.watermark for pb_ in pb],
            temperature=[pb_.temperature for pb_ in pb],
            repetition_penalty=[pb_.repetition_penalty for pb_ in pb],
            frequency_penalty=[pb_.frequency_penalty for pb_ in pb],
            top_k=[pb_.top_k for pb_ in pb],
            top_p=[pb_.top_p for pb_ in pb],
            typical_p=[pb_.typical_p for pb_ in pb],
            do_sample=[pb_.do_sample for pb_ in pb],
            seeds=[pb_.seed for pb_ in pb],
            device=device,
            dtype=dtype,
            tokenizer=tokenizer,
            grammars=[pb_.grammar for pb_ in pb],
            grammar_types=[pb_.grammar_type for pb_ in pb],
            fsm_grammar_states=(
                fsm_grammar_states if fsm_grammar_states else [0] * len(pb)
            ),
        )
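
# Illustrative usage sketch (not part of the upstream module): build a batched
# chooser from per-request protobuf parameters and pick one token per sequence.
# `pb_params` (a list of generate_pb2.NextTokenChooserParameters), `tokenizer`,
# `input_ids` and `logits` ([batch_size, vocab_size]) are assumed to come from
# the caller.
#
#     chooser = HeterogeneousNextTokenChooser.from_pb(
#         pb_params, dtype=torch.float32, device=torch.device("cuda"), tokenizer=tokenizer
#     )
#     next_ids, next_logprobs, logprobs, accepted_ids, speculative_ids = chooser(
#         input_ids, logits, speculate=0
#     )
#     chooser = chooser.filter([0, 2])  # keep only requests 0 and 2 once others finish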


class Sampling:
    def __init__(self, seed: int, device: str = "cpu"):
        self.generator = torch.Generator(device)
        self.generator.manual_seed(seed)
        self.seed = seed

    def __call__(self, logits):
        probs = torch.nn.functional.softmax(logits, -1)
        # Avoid GPU<->CPU sync done by torch multinomial
        # See: https://github.com/pytorch/pytorch/blob/925a3788ec5c06db62ca732a0e9425a26a00916f/aten/src/ATen/native/Distributions.cpp#L631-L637
        q = torch.empty_like(probs).exponential_(1, generator=self.generator)
        return probs.div_(q).argmax()
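
# Why Sampling.__call__ is a valid categorical draw (explanatory note, illustrative
# code only): with q_i ~ Exp(1) i.i.d., p_i / q_i is maximal exactly when q_i / p_i
# is minimal, and q_i / p_i ~ Exp(rate=p_i); the minimum of independent exponentials
# lands on index i with probability p_i / sum_j p_j. So argmax(p / q) reproduces
# torch.multinomial(p, 1) without the device sync.
#
#     sampler = Sampling(seed=42, device="cpu")
#     token = sampler(torch.tensor([0.1, 2.0, 0.5]))  # 0-d LongTensor token index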


class Greedy:
    def __call__(self, logits):
        return logits.argmax(dim=-1)


class HeterogeneousSampling:
    r"""
    Mixed greedy and probabilistic sampling. Compute both and pick the right one for each sample.
    """

    def __init__(self, do_sample: List[bool], seeds: List[int], device: torch.device):
        self.seeds = seeds

        self.greedy_indices = []
        self.sampling_mapping = {}
        for i, (sample, seed) in enumerate(zip(do_sample, seeds)):
            if sample:
                self.sampling_mapping[i] = Sampling(seed, device)
            else:
                self.greedy_indices.append(i)

        self.greedy = Greedy()

    def __call__(self, logits):
        out = torch.empty(logits.shape[0], dtype=torch.int64, device=logits.device)
        if self.greedy_indices:
            # Computing for all indices is faster than slicing
            torch.argmax(logits, -1, out=out)

        for i, sampling in self.sampling_mapping.items():
            out[i] = sampling(logits[i])
        return out

    def filter(self, indices):
        new_greedy_indices = []
        new_sampling_mapping = {}
        for i, idx in enumerate(indices):
            if idx in self.sampling_mapping:
                new_sampling_mapping[i] = self.sampling_mapping[idx]
            else:
                new_greedy_indices.append(i)

        self.greedy_indices = new_greedy_indices
        self.sampling_mapping = new_sampling_mapping
        return self
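
# Illustrative usage sketch (not part of the upstream module): mix greedy and seeded
# sampling within one batch; rows 0 and 2 are greedy, row 1 samples with seed 7.
#
#     chooser = HeterogeneousSampling(
#         do_sample=[False, True, False], seeds=[0, 7, 0], device=torch.device("cpu")
#     )
#     next_ids = chooser(logits)  # logits: [3, vocab_size] -> LongTensor of shape [3]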


def batch_top_tokens(
    top_n_tokens: List[int],
    top_n_tokens_tensor: torch.Tensor,
    logprobs: torch.Tensor,
    accepted_ids: torch.Tensor,
) -> Tuple[List[List[List[int]]], List[List[List[float]]]]:
    """Find the top n most likely tokens for a batch of generations.

    When multiple tokens have equal probabilities and they don't all fit, the
    remaining tokens are also returned.
    """
    max_top_n = max(top_n_tokens)
    # Early exit when top_n_tokens is not used
    if max_top_n == 0:
        return [[[]]] * len(top_n_tokens), [[[]]] * len(top_n_tokens)

    batch_size = accepted_ids.shape[0]
    speculate_size = logprobs.shape[0] // batch_size
    top_n_tokens_tensor = top_n_tokens_tensor.repeat_interleave(speculate_size)
    # Ensure top_n doesn't exceed vocab size
    top_n_tokens = [
        min(tok, logprobs.size(-1))
        for tok in top_n_tokens
        for _ in range(speculate_size)
    ]

    # Parallel kthvalue adapted from https://discuss.pytorch.org/t/how-to-efficiently-get-the-k-th-largest-values-in-parallel/160529/2
    # Sorted topk is faster than torch.sort() since we only need a small subset
    sorted_top_k = torch.topk(logprobs, k=max_top_n, dim=-1, sorted=True).values

    nth_highest = torch.gather(
        sorted_top_k, 1, (top_n_tokens_tensor - 1).clip(min=0).unsqueeze(1)
    )
    nth_highest[nth_highest == -float("inf")] = torch.finfo(logprobs.dtype).min

    # Find the new "fuzzy" top n values
    top_n_indices = (logprobs >= nth_highest).nonzero()
    _, top_n_ishes = torch.unique_consecutive(top_n_indices[:, 0], return_counts=True)

    k = 1 if top_n_ishes.numel() == 0 else top_n_ishes.max()
    # Take a new topk for these new max n values
    top_k = torch.topk(logprobs, k=k, dim=1, sorted=True)

    top_n_ishes = top_n_ishes.tolist()
    top_indices = top_k.indices.tolist()
    top_values = top_k.values.tolist()

    batch_top_token_ids = []
    batch_top_token_logprobs = []
    accepted_ids_list = accepted_ids.tolist()
    for i, n_accepted_ids in enumerate(accepted_ids_list):
        start = speculate_size * i
        stop = speculate_size * (i + 1)
        _top_indices = top_indices[start:stop]
        _top_values = top_values[start:stop]
        _top_n_ishes = top_n_ishes[start:stop]
        _top_n_tokens = top_n_tokens[start:stop]

        _top_indices = _top_indices[:n_accepted_ids]
        _top_values = _top_values[:n_accepted_ids]
        _top_n_ishes = _top_n_ishes[:n_accepted_ids]
        _top_n_tokens = _top_n_tokens[:n_accepted_ids]

        row_top_token_ids = []
        row_top_token_logprobs = []

        for idxs, vals, n, req_n in zip(
            _top_indices, _top_values, _top_n_ishes, _top_n_tokens
        ):
            indices = idxs[:n] if req_n > 0 else []
            values = vals[:n] if req_n > 0 else []

            row_top_token_ids.append(indices)
            row_top_token_logprobs.append(values)

        batch_top_token_ids.append(row_top_token_ids)
        batch_top_token_logprobs.append(row_top_token_logprobs)

    return batch_top_token_ids, batch_top_token_logprobs
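
# Illustrative usage sketch (not part of the upstream module): request the top-2
# tokens for each of two sequences without speculation (accepted_ids of all ones).
# The result is nested per request, then per accepted token, then per top token.
# `logits` ([2, vocab_size]) is assumed to come from the caller.
#
#     top_n = [2, 2]
#     ids, logprobs = batch_top_tokens(
#         top_n_tokens=top_n,
#         top_n_tokens_tensor=torch.tensor(top_n),
#         logprobs=torch.log_softmax(logits, -1),
#         accepted_ids=torch.ones(2, dtype=torch.long),
#     )
#     # ids[0][0] -> the 2 (or more, on ties) most likely token ids for request 0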