From 9eeccbf9a5273f23f66c22ca9c9d9042f7ff77d5 Mon Sep 17 00:00:00 2001
From: drbh
Date: Mon, 28 Apr 2025 13:44:37 +0000
Subject: [PATCH] fix: improve processor logic and refactor

---
 router/src/server.rs                          |   1 -
 router/src/validation.rs                      |  59 ++++++--
 .../utils/logits_process.py                   | 138 ++++++++----
 server/text_generation_server/utils/tokens.py |  13 +-
 4 files changed, 110 insertions(+), 101 deletions(-)

diff --git a/router/src/server.rs b/router/src/server.rs
index 4f9fdc87..42bfeb7c 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -1192,7 +1192,6 @@ pub(crate) async fn chat_completions(
     let (generate_request, using_tools): (GenerateRequest, bool) =
         chat.clone().try_into_generate(&infer)?;
     span.record("parameters", format!("{:?}", generate_request.parameters));
-    println!("ChatRequest: {:#?}", generate_request);
     let logprobs = logprobs.unwrap_or_default();
 
     // extract model id from request if specified
diff --git a/router/src/validation.rs b/router/src/validation.rs
index 6379b145..bda19224 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -34,6 +34,7 @@ pub struct Validation {
     max_input_length: usize,
     max_total_tokens: usize,
     disable_grammar_support: bool,
+    vocab_size: u32,
     /// Channel to communicate with the background tokenization task
     sender: mpsc::UnboundedSender<TokenizerRequest>,
 }
@@ -88,6 +89,19 @@ impl Validation {
             validation_sender
         };
 
+        let vocab_size = match &tokenizer {
+            Tokenizer::Python { tokenizer_name, .. } => {
+                warn!(
+                    "Tokenizer {} is not supported for logit_bias validation",
+                    tokenizer_name
+                );
+                0
+            }
+            Tokenizer::Rust(tokenizer) => tokenizer.get_vocab_size(false),
+        }
+        .try_into()
+        .unwrap_or(0);
+
         Self {
             max_best_of,
             sender,
@@ -96,6 +110,7 @@
             max_input_length,
             max_total_tokens,
             disable_grammar_support,
+            vocab_size,
         }
     }
 
@@ -409,6 +424,35 @@
             None => None,
         };
 
+        let logit_bias = match &request.parameters.logit_bias {
+            Some(bias) if !bias.is_empty() => {
+                for (token_str, _) in bias.iter() {
+                    let token_id = token_str.parse::<u32>().map_err(|_| {
+                        ValidationError::LogitBiasInvalid(format!(
+                            "Token ID {} is not a valid number.",
+                            token_str
+                        ))
+                    })?;
+
+                    if self.vocab_size > 0 && token_id >= self.vocab_size {
+                        return Err(ValidationError::LogitBiasInvalid(format!(
+                            "Token ID {} is out of range. Must be between 0 and {}.",
+                            token_id,
+                            self.vocab_size - 1
+                        )));
+                    }
+                }
+
+                // Transform into the required format
+                Some(
+                    bias.iter()
+                        .map(|(k, v)| (k.parse::<u32>().unwrap(), *v as f32))
+                        .collect(),
+                )
+            }
+            _ => None,
+        };
+
         let parameters = ValidParameters {
             temperature,
             repetition_penalty,
@@ -420,18 +464,7 @@
             seed,
             watermark,
             grammar,
-            logit_bias: Some(
-                request
-                    .parameters
-                    .logit_bias
-                    .iter()
-                    .flat_map(|bias| {
-                        bias.iter()
-                            .map(|(k, v)| (k.parse::<u32>().unwrap(), *v as f32))
-                            .collect::<Vec<_>>()
-                    })
-                    .collect(),
-            ),
+            logit_bias,
         };
         let stopping_parameters = ValidStoppingParameters {
             max_new_tokens,
@@ -1011,6 +1044,8 @@ pub enum ValidationError {
     FailedFetchImage(#[from] reqwest::Error),
     #[error("{0} modality is not supported")]
     UnsupportedModality(&'static str),
+    #[error("logit_bias is not valid: {0}")]
+    LogitBiasInvalid(String),
 }
 
 #[cfg(test)]
diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py
index 9f14b411..ad769990 100644
--- a/server/text_generation_server/utils/logits_process.py
+++ b/server/text_generation_server/utils/logits_process.py
@@ -625,55 +625,49 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
         return self
 
 
-class LogitBiasProcessor:
-    """Process logits with logit biases."""
+class LogitBiasProcessor(LogitsProcessor):
+    """
+    A `LogitsProcessor` that builds a bias tensor from a dictionary of token IDs and their
+    corresponding bias values. Biases are applied to the logits during each forward pass.
+
+    Supports token IDs provided as strings (e.g., {"9707": -100}).
+    """
 
     def __init__(
-        self, logit_biases: Optional[dict], tokenizer: PreTrainedTokenizerBase
+        self,
+        logit_biases: dict,
+        tokenizer: PreTrainedTokenizerBase,
+        device: torch.device,
     ):
-        self.tokenizer = tokenizer
-        self.logit_biases = logit_biases or {}
+        assert logit_biases, "LogitBiasProcessor requires non-empty logit_biases"
 
-        # Pre-compute token IDs for each token string
-        self.token_id_mapping = {}
+        vocab_size = len(tokenizer)
+
+        # Parse string keys to integer token IDs and collect the bias values
+        token_ids = torch.tensor(
+            [int(k) for k in logit_biases.keys()], dtype=torch.long
+        )
+        bias_values = torch.tensor(list(logit_biases.values()), dtype=torch.float)
+
+        # Scatter the bias values into a vocab-sized tensor at the corresponding indices
+        self.bias_tensor = torch.zeros(vocab_size, dtype=torch.float)
+        self.bias_tensor.index_put_((token_ids,), bias_values, accumulate=True)
 
     def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
-        # If no logit biases, return scores unchanged
-        if not self.logit_biases:
-            return scores
-
-        # Apply bias to the corresponding scores
-        for token_str, bias_value in self.logit_biases.items():
-            # Get token ID, either from cache or by computing it
-            if token_str not in self.token_id_mapping:
-                if token_str.isdigit():
-                    # If the token string is already a numeric ID
-                    token_id = int(token_str)
-                else:
-                    # Otherwise, use the tokenizer to get the ID
-                    tokens = self.tokenizer.encode(token_str, add_special_tokens=False)
-                    token_id = tokens[0] if tokens else -1  # Use -1 for not found
-
-                self.token_id_mapping[token_str] = token_id
-
-            token_id = self.token_id_mapping[token_str]
-
-            # Apply bias if token ID is valid
-            if 0 <= token_id < scores.size(-1):
-                scores[:, token_id] += bias_value
-
+        # Apply the bias tensor as a broadcasted addition
+        if self.bias_tensor.shape[0] != scores.shape[1]:
+            # Pad the bias tensor if the logits are wider (e.g. a padded vocabulary)
+            self.bias_tensor = torch.nn.functional.pad(
+                self.bias_tensor, (0, scores.shape[1] - self.bias_tensor.shape[0])
+            )
+        scores.add_(self.bias_tensor.to(device=scores.device, dtype=scores.dtype))
         return scores
 
-    def filter(self, indices):
-        """Keep only the logit biases for the specified indices."""
-        new_logit_biases = {
-            k: self.logit_biases[k] for k in indices if k in self.logit_biases
-        }
-        return LogitBiasProcessor(new_logit_biases, self.tokenizer)
-
 
-class HeterogeneousLogitBiasProcessor:
-    """Process logits with different logit biases for each sequence in the batch."""
+class HeterogeneousLogitBiasProcessor(LogitsProcessor):
+    """
+    Process logits with different logit biases for each sequence in the batch.
+    """
 
     def __init__(
         self,
@@ -681,54 +675,42 @@ class HeterogeneousLogitBiasProcessor:
         tokenizer: PreTrainedTokenizerBase,
         device: torch.device,
     ):
-        self.device = device
         self.tokenizer = tokenizer
         self.logit_biases = logit_biases
-        self.batch_size = len(logit_biases)
+        # Vocab size from the tokenizer; the logits may be padded to a larger size
+        self.vocab_size = len(tokenizer)
 
-        # Pre-compute token IDs for each token string
-        self.token_id_mapping = {}
+        # Create batch_size x vocab_size bias matrix
+        self.bias_matrix = torch.zeros(
+            (len(logit_biases), self.vocab_size), dtype=torch.float, device=device
+        )
 
-        # Create a mapping of indices that have logit biases
-        self.indices_with_biases = {
-            i: bias_dict
-            for i, bias_dict in enumerate(self.logit_biases)
-            if bias_dict is not None and len(bias_dict) > 0
-        }
+        # For each sequence, parse its bias dict keys to token IDs and scatter the values
+        for i, logit_bias in enumerate(logit_biases):
+            token_ids = torch.tensor(
+                [int(k) for k in logit_bias.keys()], dtype=torch.long
+            ).to(device=device)
+            bias_values = torch.tensor(list(logit_bias.values()), dtype=torch.float).to(
+                device=device
+            )
+            # Scatter this sequence's bias values into row i at the corresponding indices
+            self.bias_matrix[i].index_put_((token_ids,), bias_values, accumulate=True)
 
     def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
-        # If no indices have biases, return scores unchanged
-        if not self.indices_with_biases:
-            return scores
-
-        # For each index with a bias, apply the bias to the corresponding scores
-        for i, bias_dict in self.indices_with_biases.items():
-            for token_str, bias_value in bias_dict.items():
-                # Get token ID, either from cache or by computing it
-                if token_str not in self.token_id_mapping:
-                    if token_str.isdigit():
-                        # If the token string is already a numeric ID
-                        token_id = int(token_str)
-                    else:
-                        # Otherwise, use the tokenizer to get the ID
-                        tokens = self.tokenizer.encode(
-                            token_str, add_special_tokens=False
-                        )
-                        token_id = tokens[0] if tokens else -1  # Use -1 for not found
-
-                    self.token_id_mapping[token_str] = token_id
-
-                token_id = self.token_id_mapping[token_str]
-
-                # Apply bias if token ID is valid
-                if 0 <= token_id < scores.size(-1):
-                    scores[i, token_id] += bias_value
+        # Apply the bias matrix as a broadcasted addition
+        if self.bias_matrix.shape[1] != scores.shape[1]:
+            # Pad the bias matrix if the logits are wider (e.g. a padded vocabulary)
+            self.bias_matrix = torch.nn.functional.pad(
+                self.bias_matrix, (0, scores.shape[1] - self.bias_matrix.shape[1])
+            )
+        scores.add_(self.bias_matrix.to(device=scores.device, dtype=scores.dtype))
 
         return scores
 
-    def filter(self, indices: List[int]):
-        """Keep only the logit biases for the specified indices."""
+    def filter(self, indices):
         new_logit_biases = [self.logit_biases[i] for i in indices]
+        if not any(bias and len(bias) > 0 for bias in new_logit_biases):
+            return None
         return HeterogeneousLogitBiasProcessor(
-            new_logit_biases, self.tokenizer, self.device
+            new_logit_biases, self.tokenizer, self.bias_matrix.device
         )
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index eeca7273..fa982c30 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -66,7 +66,6 @@ class NextTokenChooser:
             else None
         )
         self.tokenizer = tokenizer
-        self.logit_bias = logit_bias
 
         has_warpers = (
             (temperature is not None and temperature != 1.0)
@@ -136,7 +135,7 @@ class NextTokenChooser:
             tokenizer=tokenizer,
             grammar=pb.grammar,
             grammar_type=pb.grammar_type,
-            logit_bias=dict(pb.logit_bias) if pb.logit_bias else None,
+            logit_bias=pb.logit_bias,
         )
 
 
@@ -264,10 +263,6 @@ class HeterogeneousNextTokenChooser:
     ):
         warpers = []
 
-        # Initialize with empty logit biases if none provided
-        if logit_biases is None:
-            logit_biases = [None] * len(do_sample)
-
         self.watermark_processor = (
             HeterogeneousProcessorWrapper(
                 {
@@ -306,7 +301,7 @@ class HeterogeneousNextTokenChooser:
 
         self.logit_bias_processor = (
             HeterogeneousLogitBiasProcessor(logit_biases, tokenizer, device)
-            if any([bias is not None and len(bias) > 0 for bias in logit_biases])
+            if any(logit_bias for logit_bias in logit_biases)
             else None
         )
 
@@ -530,9 +525,7 @@ class HeterogeneousNextTokenChooser:
             fsm_grammar_states=(
                 fsm_grammar_states if fsm_grammar_states else [0] * len(pb)
             ),
-            logit_biases=[
-                dict(pb_.logit_bias) if pb_.logit_bias else None for pb_ in pb
-            ],
+            logit_biases=[pb_.logit_bias for pb_ in pb],
         )
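
Below, for reference only (not part of the diff), is a minimal standalone sketch of the bias-tensor technique the new processors rely on, assuming just `torch` is installed; `build_bias_tensor` is a hypothetical helper used for illustration. String token IDs are parsed to integer indices, scattered into a zero tensor of vocabulary size, and broadcast-added to the logits on each forward pass:

    import torch

    def build_bias_tensor(logit_biases: dict, vocab_size: int) -> torch.Tensor:
        # Keys arrive as strings (e.g. {"9707": -100}); parse them to integer token IDs.
        token_ids = torch.tensor([int(k) for k in logit_biases.keys()], dtype=torch.long)
        bias_values = torch.tensor(list(logit_biases.values()), dtype=torch.float)
        bias = torch.zeros(vocab_size, dtype=torch.float)
        # Scatter each bias value at its token index.
        bias.index_put_((token_ids,), bias_values, accumulate=True)
        return bias

    # Toy batch of 2 sequences over a 10-token vocabulary.
    scores = torch.zeros(2, 10)
    bias = build_bias_tensor({"3": -100.0, "7": 5.0}, vocab_size=10)
    scores.add_(bias)  # the (vocab,) tensor broadcasts over the batch dimension
    assert scores[0, 3] == -100.0 and scores[1, 7] == 5.0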