fix: adjust the NextTokenChooser logit bias processor

This commit is contained in:
drbh 2025-04-22 16:36:34 +00:00
parent da3f18e5c8
commit e44703d542
2 changed files with 53 additions and 0 deletions

View File

@ -625,6 +625,53 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
return self
class LogitBiasProcessor:
"""Process logits with logit biases."""
def __init__(
self, logit_biases: Optional[dict], tokenizer: PreTrainedTokenizerBase
):
self.tokenizer = tokenizer
self.logit_biases = logit_biases or {}
# Pre-compute token IDs for each token string
self.token_id_mapping = {}
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
# If no logit biases, return scores unchanged
if not self.logit_biases:
return scores
# Apply bias to the corresponding scores
for token_str, bias_value in self.logit_biases.items():
# Get token ID, either from cache or by computing it
if token_str not in self.token_id_mapping:
if token_str.isdigit():
# If the token string is already a numeric ID
token_id = int(token_str)
else:
# Otherwise, use the tokenizer to get the ID
tokens = self.tokenizer.encode(token_str, add_special_tokens=False)
token_id = tokens[0] if tokens else -1 # Use -1 for not found
self.token_id_mapping[token_str] = token_id
token_id = self.token_id_mapping[token_str]
# Apply bias if token ID is valid
if 0 <= token_id < scores.size(-1):
scores[:, token_id] += bias_value
return scores
def filter(self, indices):
"""Keep only the logit biases for the specified indices."""
new_logit_biases = {
k: self.logit_biases[k] for k in indices if k in self.logit_biases
}
return LogitBiasProcessor(new_logit_biases, self.tokenizer)
class HeterogeneousLogitBiasProcessor:
"""Process logits with different logit biases for each sequence in the batch."""

View File

@ -7,6 +7,7 @@ from text_generation_server.pb.generate_pb2 import FinishReason, GrammarType
from text_generation_server.utils.logits_process import (
FrequencyPenaltyLogitsProcessor,
GrammarLogitProcessor,
LogitBiasProcessor,
HeterogeneousProcessorWrapper,
HeterogeneousRepetitionPenaltyLogitsProcessor,
HeterogeneousFrequencyPenaltyLogitsProcessor,
@ -59,6 +60,11 @@ class NextTokenChooser:
if grammar != ""
else None
)
self.logit_bias_processor = (
LogitBiasProcessor(logit_bias, tokenizer, device)
if logit_bias is not None and len(logit_bias) > 0
else None
)
self.tokenizer = tokenizer
self.logit_bias = logit_bias