Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-06-12 12:22:07 +00:00
fix: adjust the NextTokenChooser logit bias processor
parent da3f18e5c8
commit e44703d542
@@ -625,6 +625,53 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
         return self
 
 
+class LogitBiasProcessor:
+    """Process logits with logit biases."""
+
+    def __init__(
+        self, logit_biases: Optional[dict], tokenizer: PreTrainedTokenizerBase
+    ):
+        self.tokenizer = tokenizer
+        self.logit_biases = logit_biases or {}
+
+        # Cache of token IDs resolved lazily from token strings
+        self.token_id_mapping = {}
+
+    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
+        # If there are no logit biases, return the scores unchanged
+        if not self.logit_biases:
+            return scores
+
+        # Apply each bias to its corresponding score column
+        for token_str, bias_value in self.logit_biases.items():
+            # Get the token ID, either from the cache or by resolving it
+            if token_str not in self.token_id_mapping:
+                if token_str.isdigit():
+                    # The token string is already a numeric ID
+                    token_id = int(token_str)
+                else:
+                    # Otherwise, use the tokenizer to look up the ID
+                    tokens = self.tokenizer.encode(token_str, add_special_tokens=False)
+                    token_id = tokens[0] if tokens else -1  # Use -1 for not found
+
+                self.token_id_mapping[token_str] = token_id
+
+            token_id = self.token_id_mapping[token_str]
+
+            # Apply the bias only if the token ID is valid for this vocabulary
+            if 0 <= token_id < scores.size(-1):
+                scores[:, token_id] += bias_value
+
+        return scores
+
+    def filter(self, indices):
+        """Keep only the logit biases for the specified indices."""
+        new_logit_biases = {
+            k: self.logit_biases[k] for k in indices if k in self.logit_biases
+        }
+        return LogitBiasProcessor(new_logit_biases, self.tokenizer)
+
+
 class HeterogeneousLogitBiasProcessor:
     """Process logits with different logit biases for each sequence in the batch."""
 
@@ -7,6 +7,7 @@ from text_generation_server.pb.generate_pb2 import FinishReason, GrammarType
 from text_generation_server.utils.logits_process import (
     FrequencyPenaltyLogitsProcessor,
     GrammarLogitProcessor,
+    LogitBiasProcessor,
     HeterogeneousProcessorWrapper,
     HeterogeneousRepetitionPenaltyLogitsProcessor,
     HeterogeneousFrequencyPenaltyLogitsProcessor,
@@ -59,6 +60,11 @@ class NextTokenChooser:
             if grammar != ""
             else None
         )
+        self.logit_bias_processor = (
+            LogitBiasProcessor(logit_bias, tokenizer)
+            if logit_bias is not None and len(logit_bias) > 0
+            else None
+        )
         self.tokenizer = tokenizer
         self.logit_bias = logit_bias
 
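The hunk above only constructs the processor in __init__. Presumably it is applied alongside the other processors when scores are computed; a hedged sketch of that step follows (the method name and the neighboring processors are assumptions, since this diff does not show the call site):

# Assumed call-site sketch, NOT shown in this diff: how the new
# logit_bias_processor would slot in next to the existing processors.
def _apply_processors(self, input_ids, scores):
    if self.repetition_processor is not None:
        scores = self.repetition_processor(input_ids, scores)
    if self.frequency_processor is not None:
        scores = self.frequency_processor(input_ids, scores)
    if self.logit_bias_processor is not None:
        # New wiring: bias selected token logits before sampling
        scores = self.logit_bias_processor(input_ids, scores)
    return scores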