From e44703d542461e7f41c2363d0409bc2d804ae6ba Mon Sep 17 00:00:00 2001
From: drbh
Date: Tue, 22 Apr 2025 16:36:34 +0000
Subject: [PATCH] fix: adjust the NextTokenChooser logit bias processor

---
 .../utils/logits_process.py                   | 47 +++++++++++++++++++
 server/text_generation_server/utils/tokens.py |  6 +++
 2 files changed, 53 insertions(+)

diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py
index b0dfe571..9f14b411 100644
--- a/server/text_generation_server/utils/logits_process.py
+++ b/server/text_generation_server/utils/logits_process.py
@@ -625,6 +625,53 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
         return self
 
 
+class LogitBiasProcessor:
+    """Process logits with logit biases."""
+
+    def __init__(
+        self, logit_biases: Optional[dict], tokenizer: PreTrainedTokenizerBase
+    ):
+        self.tokenizer = tokenizer
+        self.logit_biases = logit_biases or {}
+
+        # Pre-compute token IDs for each token string
+        self.token_id_mapping = {}
+
+    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
+        # If no logit biases, return scores unchanged
+        if not self.logit_biases:
+            return scores
+
+        # Apply bias to the corresponding scores
+        for token_str, bias_value in self.logit_biases.items():
+            # Get token ID, either from cache or by computing it
+            if token_str not in self.token_id_mapping:
+                if token_str.isdigit():
+                    # If the token string is already a numeric ID
+                    token_id = int(token_str)
+                else:
+                    # Otherwise, use the tokenizer to get the ID
+                    tokens = self.tokenizer.encode(token_str, add_special_tokens=False)
+                    token_id = tokens[0] if tokens else -1  # Use -1 for not found
+
+                self.token_id_mapping[token_str] = token_id
+
+            token_id = self.token_id_mapping[token_str]
+
+            # Apply bias if token ID is valid
+            if 0 <= token_id < scores.size(-1):
+                scores[:, token_id] += bias_value
+
+        return scores
+
+    def filter(self, indices):
+        """Keep only the logit biases for the specified indices."""
+        new_logit_biases = {
+            k: self.logit_biases[k] for k in indices if k in self.logit_biases
+        }
+        return LogitBiasProcessor(new_logit_biases, self.tokenizer)
+
+
 class HeterogeneousLogitBiasProcessor:
     """Process logits with different logit biases for each sequence in the batch."""
 
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index 8c916bfd..eeca7273 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -7,6 +7,7 @@ from text_generation_server.pb.generate_pb2 import FinishReason, GrammarType
 from text_generation_server.utils.logits_process import (
     FrequencyPenaltyLogitsProcessor,
     GrammarLogitProcessor,
+    LogitBiasProcessor,
     HeterogeneousProcessorWrapper,
     HeterogeneousRepetitionPenaltyLogitsProcessor,
     HeterogeneousFrequencyPenaltyLogitsProcessor,
@@ -59,6 +60,11 @@ class NextTokenChooser:
             if grammar != ""
             else None
        )
+        self.logit_bias_processor = (
+            LogitBiasProcessor(logit_bias, tokenizer)
+            if logit_bias is not None and len(logit_bias) > 0
+            else None
+        )
         self.tokenizer = tokenizer
         self.logit_bias = logit_bias
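
For reference, a minimal sketch of how the new LogitBiasProcessor could be exercised on its own once the patch is applied. The tokenizer name ("gpt2"), the bias entries, and the tensor shapes below are illustrative assumptions, not part of the patch:

    import torch
    from transformers import AutoTokenizer

    from text_generation_server.utils.logits_process import LogitBiasProcessor

    # Illustrative tokenizer; any PreTrainedTokenizerBase with `encode` works.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    # Keys may be token strings or numeric token IDs given as strings;
    # values are added to the corresponding logits.
    logit_bias = {"Hello": 5.0, "50256": -100.0}

    processor = LogitBiasProcessor(logit_bias, tokenizer)

    # Fake single-sequence logits over the vocabulary (batch size 1).
    input_ids = torch.tensor([[tokenizer.bos_token_id]])
    scores = torch.zeros(1, len(tokenizer))

    biased = processor(input_ids, scores)
    print(biased[0, 50256].item())  # expected: -100.0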