From fae510b8f6128bca2b864c5410fca54d2c25d78c Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 22 Apr 2025 16:09:32 +0000 Subject: [PATCH 01/14] feat: support logit bias in chat request --- backends/client/src/v3/client.rs | 2 + backends/client/src/v3/sharded_client.rs | 2 + backends/v3/src/client/grpc_client.rs | 2 + backends/v3/src/client/sharded_client.rs | 2 + backends/v3/src/queue.rs | 9 ++ benchmark/src/lib.rs | 1 + clients/python/text_generation/types.py | 4 +- docs/openapi.json | 21 +++- .../test_logit_bias_baseline.json | 26 +++++ .../test_logit_bias_english_to_spanish.json | 26 +++++ .../test_logit_bias_multiple_tokens.json | 26 +++++ .../test_logit_bias_streaming.json | 20 ++++ .../models/test_flash_logit_bias.py | 109 ++++++++++++++++++ router/src/lib.rs | 18 ++- router/src/server.rs | 2 + router/src/validation.rs | 14 +++ .../utils/logits_process.py | 62 ++++++++++ server/text_generation_server/utils/tokens.py | 27 +++++ 18 files changed, 363 insertions(+), 10 deletions(-) create mode 100644 integration-tests/models/__snapshots__/test_flash_logit_bias/test_logit_bias_baseline.json create mode 100644 integration-tests/models/__snapshots__/test_flash_logit_bias/test_logit_bias_english_to_spanish.json create mode 100644 integration-tests/models/__snapshots__/test_flash_logit_bias/test_logit_bias_multiple_tokens.json create mode 100644 integration-tests/models/__snapshots__/test_flash_logit_bias/test_logit_bias_streaming.json create mode 100644 integration-tests/models/test_flash_logit_bias.py diff --git a/backends/client/src/v3/client.rs b/backends/client/src/v3/client.rs index 968c1f45..b5dd081f 100644 --- a/backends/client/src/v3/client.rs +++ b/backends/client/src/v3/client.rs @@ -7,6 +7,7 @@ use grpc_metadata::InjectTelemetryContext; use pb::generate::v3::text_generation_service_client::TextGenerationServiceClient; use pb::generate::v3::*; use std::cmp::min; +use std::collections::HashMap; use std::time::Duration; use tonic::transport::{Channel, Uri}; use tracing::instrument; @@ -181,6 +182,7 @@ impl Client { watermark: true, grammar: String::new(), grammar_type: GrammarType::None as i32, + logit_bias: HashMap::new(), }), stopping_parameters: Some(StoppingCriteriaParameters { max_new_tokens, diff --git a/backends/client/src/v3/sharded_client.rs b/backends/client/src/v3/sharded_client.rs index dc3bcdde..99688668 100644 --- a/backends/client/src/v3/sharded_client.rs +++ b/backends/client/src/v3/sharded_client.rs @@ -5,6 +5,7 @@ use crate::{ClientError, Result}; use crate::v3::{Chunk, InfoResponse, Input}; use async_trait::async_trait; use futures::future::join_all; +use std::collections::HashMap; use tonic::transport::Uri; use tracing::instrument; use v3::client::{DecodeTimings, PrefillTimings}; @@ -244,6 +245,7 @@ impl Health for ShardedClient { watermark: false, grammar: String::new(), grammar_type: GrammarType::None as i32, + logit_bias: HashMap::new(), }), stopping_parameters: Some(StoppingCriteriaParameters { max_new_tokens: 1, diff --git a/backends/v3/src/client/grpc_client.rs b/backends/v3/src/client/grpc_client.rs index f4942f64..9d5e042a 100644 --- a/backends/v3/src/client/grpc_client.rs +++ b/backends/v3/src/client/grpc_client.rs @@ -7,6 +7,7 @@ use grpc_metadata::InjectTelemetryContext; use pb::generate::v3::text_generation_service_client::TextGenerationServiceClient; use pb::generate::v3::*; use std::cmp::min; +use std::collections::HashMap; use std::time::Duration; use tonic::transport::{Channel, Uri}; use tracing::instrument; @@ -181,6 +182,7 @@ impl Client { 
watermark: true, grammar: String::new(), grammar_type: GrammarType::None as i32, + logit_bias: HashMap::new(), }), stopping_parameters: Some(StoppingCriteriaParameters { max_new_tokens, diff --git a/backends/v3/src/client/sharded_client.rs b/backends/v3/src/client/sharded_client.rs index 4701c560..b9667dc4 100644 --- a/backends/v3/src/client/sharded_client.rs +++ b/backends/v3/src/client/sharded_client.rs @@ -10,6 +10,7 @@ use crate::client::{ use crate::client::{Chunk, InfoResponse, Input}; use async_trait::async_trait; use futures::future::join_all; +use std::collections::HashMap; use tonic::transport::Uri; use tracing::instrument; @@ -232,6 +233,7 @@ impl Health for ShardedClient { watermark: false, grammar: String::new(), grammar_type: GrammarType::None as i32, + logit_bias: HashMap::new(), }), stopping_parameters: Some(StoppingCriteriaParameters { max_new_tokens: 1, diff --git a/backends/v3/src/queue.rs b/backends/v3/src/queue.rs index d3bf4b9c..4c143deb 100644 --- a/backends/v3/src/queue.rs +++ b/backends/v3/src/queue.rs @@ -5,6 +5,7 @@ use crate::client::{ }; use nohash_hasher::{BuildNoHashHasher, IntMap}; use std::cmp::max; +use std::collections::HashMap; use std::collections::VecDeque; use text_generation_router::infer::InferError; use text_generation_router::infer::InferStreamResponse; @@ -522,6 +523,14 @@ impl From<ValidParameters> for NextTokenChooserParameters { watermark: value.watermark, grammar, grammar_type: grammar_type.into(), + logit_bias: value + .logit_bias + .map(|bias| { + bias.into_iter() + .map(|(token, bias)| (token.to_string(), bias as i32)) + .collect::<HashMap<String, i32>>() + }) + .unwrap_or_default(), } } } diff --git a/benchmark/src/lib.rs b/benchmark/src/lib.rs index bb4b6a77..cd1f9446 100644 --- a/benchmark/src/lib.rs +++ b/benchmark/src/lib.rs @@ -47,6 +47,7 @@ pub async fn run( watermark, grammar: String::new(), grammar_type: GrammarType::None as i32, + logit_bias: std::collections::HashMap::new(), }; // Initialize terminal properties diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py index 6f51c153..b9e00fae 100644 --- a/clients/python/text_generation/types.py +++ b/clients/python/text_generation/types.py @@ -1,6 +1,6 @@ from enum import Enum from pydantic import BaseModel, field_validator, ConfigDict -from typing import Optional, List, Union, Any +from typing import Optional, List, Union, Any, Dict from text_generation.errors import ValidationError @@ -137,7 +137,7 @@ class ChatRequest(BaseModel): # decreasing the model's likelihood to repeat the same line verbatim. frequency_penalty: Optional[float] = None # Bias values for token selection - logit_bias: Optional[List[float]] = None + logit_bias: Optional[Dict[str, int]] = None # Whether to return log probabilities logprobs: Optional[bool] = None # Number of most likely tokens to return at each position diff --git a/docs/openapi.json b/docs/openapi.json index 84ba5885..22145191 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -995,12 +995,12 @@ "nullable": true }, "logit_bias": { - "type": "array", - "items": { - "type": "number", - "format": "float" + "type": "object", + "description": "Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens\n(specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,\nthe bias is added to the logits generated by the model prior to sampling. 
The exact effect will vary per model,\nbut values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should\nresult in a ban or exclusive selection of the relevant token.", + "additionalProperties": { + "type": "integer", + "format": "int32" }, - "description": "UNUSED\nModify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens\n(specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,\nthe bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model,\nbut values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should\nresult in a ban or exclusive selection of the relevant token.", "nullable": true }, "logprobs": { @@ -1589,6 +1589,17 @@ "default": "null", "nullable": true }, + "logit_bias": { + "type": "object", + "description": "Modify the likelihood of specified tokens appearing in the completion.\nAccepts a hash map that maps token strings to an associated bias value.", + "default": "null", + "additionalProperties": { + "type": "integer", + "format": "int32" + }, + "example": "{\"1923\": 100, \"1924\": -100}", + "nullable": true + }, "max_new_tokens": { "type": "integer", "format": "int32", diff --git a/integration-tests/models/__snapshots__/test_flash_logit_bias/test_logit_bias_baseline.json b/integration-tests/models/__snapshots__/test_flash_logit_bias/test_logit_bias_baseline.json new file mode 100644 index 00000000..b683b090 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_logit_bias/test_logit_bias_baseline.json @@ -0,0 +1,26 @@ +{ + "choices": [ + { + "finish_reason": "length", + "index": 0, + "logprobs": null, + "message": { + "content": "Hello! How can I help you today?", + "name": null, + "role": "assistant", + "tool_calls": null + }, + "usage": null + } + ], + "created": 1745337495, + "id": "", + "model": "Qwen/Qwen2-VL-2B-Instruct", + "object": "chat.completion", + "system_fingerprint": "3.2.3-dev0-native", + "usage": { + "completion_tokens": 10, + "prompt_tokens": 21, + "total_tokens": 31 + } +} diff --git a/integration-tests/models/__snapshots__/test_flash_logit_bias/test_logit_bias_english_to_spanish.json b/integration-tests/models/__snapshots__/test_flash_logit_bias/test_logit_bias_english_to_spanish.json new file mode 100644 index 00000000..3c4369f1 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_logit_bias/test_logit_bias_english_to_spanish.json @@ -0,0 +1,26 @@ +{ + "choices": [ + { + "finish_reason": "length", + "index": 0, + "logprobs": null, + "message": { + "content": "¡Hola! 
¿Cómo puedo ayudarte?", + "name": null, + "role": "assistant", + "tool_calls": null + }, + "usage": null + } + ], + "created": 1745337456, + "id": "", + "model": "Qwen/Qwen2-VL-2B-Instruct", + "object": "chat.completion", + "system_fingerprint": "3.2.3-dev0-native", + "usage": { + "completion_tokens": 10, + "prompt_tokens": 21, + "total_tokens": 31 + } +} diff --git a/integration-tests/models/__snapshots__/test_flash_logit_bias/test_logit_bias_multiple_tokens.json b/integration-tests/models/__snapshots__/test_flash_logit_bias/test_logit_bias_multiple_tokens.json new file mode 100644 index 00000000..ea8f3a83 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_logit_bias/test_logit_bias_multiple_tokens.json @@ -0,0 +1,26 @@ +{ + "choices": [ + { + "finish_reason": "stop", + "index": 0, + "logprobs": null, + "message": { + "content": "Chat!", + "name": null, + "role": "assistant", + "tool_calls": null + }, + "usage": null + } + ], + "created": 1745337878, + "id": "", + "model": "Qwen/Qwen2-VL-2B-Instruct", + "object": "chat.completion", + "system_fingerprint": "3.2.3-dev0-native", + "usage": { + "completion_tokens": 3, + "prompt_tokens": 25, + "total_tokens": 28 + } +} diff --git a/integration-tests/models/__snapshots__/test_flash_logit_bias/test_logit_bias_streaming.json b/integration-tests/models/__snapshots__/test_flash_logit_bias/test_logit_bias_streaming.json new file mode 100644 index 00000000..1668717e --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_logit_bias/test_logit_bias_streaming.json @@ -0,0 +1,20 @@ +{ + "choices": [ + { + "delta": { + "content": "", + "role": "assistant", + "tool_calls": null + }, + "finish_reason": "length", + "index": 0, + "logprobs": null + } + ], + "created": 1745337495, + "id": "", + "model": "Qwen/Qwen2-VL-2B-Instruct", + "object": "chat.completion.chunk", + "system_fingerprint": "3.2.3-dev0-native", + "usage": null +} diff --git a/integration-tests/models/test_flash_logit_bias.py b/integration-tests/models/test_flash_logit_bias.py new file mode 100644 index 00000000..5ce793b7 --- /dev/null +++ b/integration-tests/models/test_flash_logit_bias.py @@ -0,0 +1,109 @@ +import pytest + + +@pytest.fixture(scope="module") +def logit_bias_model_handle(launcher): + with launcher("Qwen/Qwen2-VL-2B-Instruct") as handle: + yield handle + + +@pytest.fixture(scope="module") +async def logit_bias_model(logit_bias_model_handle): + await logit_bias_model_handle.health(300) + return logit_bias_model_handle.client + + +@pytest.mark.private +async def test_logit_bias_english_to_spanish(logit_bias_model, response_snapshot): + """Test that setting negative bias on English tokens forces output to be in Spanish""" + response = await logit_bias_model.chat( + seed=42, + max_tokens=10, + logit_bias={"9707": -100}, # Bias against 'Hello' token + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "say Hello"}, + ], + }, + ], + ) + assert "¡Hola!" 
in response.choices[0].message.content + assert "Hello" not in response.choices[0].message.content + assert response == response_snapshot + + +@pytest.mark.private +async def test_logit_bias_baseline(logit_bias_model, response_snapshot): + """Test baseline behavior without logit bias for comparison""" + response = await logit_bias_model.chat( + seed=42, + max_tokens=10, + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "say Hello"}, + ], + }, + ], + ) + assert "Hello" in response.choices[0].message.content + assert response == response_snapshot + + +@pytest.mark.private +async def test_logit_bias_multiple_tokens(logit_bias_model, response_snapshot): + """Test applying bias to multiple tokens simultaneously""" + response = await logit_bias_model.chat( + seed=42, + max_tokens=15, + logit_bias={ + "9707": -100, # Bias against 'Hello' token + "2880": -100, # Bias against 'hi' token + }, + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "Give me a one-word greeting"}, + ], + }, + ], + ) + assert "Hello" not in response.choices[0].message.content.lower() + assert "hi" not in response.choices[0].message.content.lower() + assert response == response_snapshot + + +@pytest.mark.private +async def test_logit_bias_streaming(logit_bias_model, response_snapshot): + """Test logit bias works correctly with streaming enabled""" + responses = await logit_bias_model.chat( + seed=42, + max_tokens=10, + logit_bias={"9707": -100}, # Bias against 'Hello' token + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "say Hello"}, + ], + }, + ], + stream=True, + ) + + count = 0 + generated = "" + last_response = None + + async for response in responses: + count += 1 + generated += response.choices[0].delta.content + last_response = response + + assert "¡Hola!" in generated + assert "Hello" not in generated + assert last_response == response_snapshot diff --git a/router/src/lib.rs b/router/src/lib.rs index e5622fc2..88ab86e9 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -18,6 +18,7 @@ use crate::infer::{Infer, InferError}; use pyo3::prelude::*; use pyo3::types::IntoPyDict; use serde::{Deserialize, Serialize}; +use std::collections::HashMap; use tokenizers::Encoding; use tracing::warn; use utoipa::ToSchema; @@ -431,6 +432,16 @@ pub(crate) struct GenerateParameters { #[serde(default)] #[schema(nullable = true, default = "null", example = "null")] pub adapter_id: Option<String>, + + /// Modify the likelihood of specified tokens appearing in the completion. + /// Accepts a hash map that maps token strings to an associated bias value. + #[serde(default)] + #[schema( + nullable = true, + default = "null", + example = "{\"1923\": 100, \"1924\": -100}" + )] + pub logit_bias: Option<HashMap<String, i32>>, } fn default_parameters() -> GenerateParameters { @@ -454,9 +465,9 @@ fn default_parameters() -> GenerateParameters { top_n_tokens: None, grammar: None, adapter_id: None, + logit_bias: None, } } - #[derive(Clone, Deserialize, Serialize, ToSchema, Debug)] #[serde(try_from = "PromptDeserializer")] pub struct Prompt(pub Vec<String>); @@ -841,14 +852,13 @@ pub(crate) struct ChatRequest { #[schema(example = "1.0")] pub frequency_penalty: Option<f32>, - /// UNUSED /// Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens /// (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically, /// the bias is added to the logits generated by the model prior to sampling. 
The exact effect will vary per model, /// but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should /// result in a ban or exclusive selection of the relevant token. #[serde(default)] - pub logit_bias: Option<Vec<f32>>, + pub logit_bias: Option<HashMap<String, i32>>, /// Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each /// output token returned in the content of message. @@ -954,6 +964,7 @@ impl ChatRequest { frequency_penalty, top_p, top_logprobs, + logit_bias, .. } = self; @@ -1029,6 +1040,7 @@ impl ChatRequest { top_n_tokens: top_logprobs, grammar, adapter_id: model.filter(|m| *m != "tgi"), + logit_bias, }, }, using_tools, diff --git a/router/src/server.rs b/router/src/server.rs index 001a85e0..4f9fdc87 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -798,6 +798,7 @@ pub(crate) async fn completions( top_n_tokens: None, grammar: None, adapter_id: model.as_ref().filter(|m| *m != "tgi").map(String::from), + logit_bias: None, }, }) .collect(); @@ -1191,6 +1192,7 @@ pub(crate) async fn chat_completions( let (generate_request, using_tools): (GenerateRequest, bool) = chat.clone().try_into_generate(&infer)?; span.record("parameters", format!("{:?}", generate_request.parameters)); + println!("ChatRequest: {:#?}", generate_request); let logprobs = logprobs.unwrap_or_default(); // extract model id from request if specified diff --git a/router/src/validation.rs b/router/src/validation.rs index 28c7f2f8..6379b145 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -420,6 +420,18 @@ impl Validation { seed, watermark, grammar, + logit_bias: Some( + request + .parameters + .logit_bias + .iter() + .flat_map(|bias| { + bias.iter() + .map(|(k, v)| (k.parse::<u32>().unwrap(), *v as f32)) + .collect::<Vec<_>>() + }) + .collect(), + ), }; let stopping_parameters = ValidStoppingParameters { max_new_tokens, @@ -902,6 +914,8 @@ pub struct ValidParameters { pub watermark: bool, /// / grammar (applied if not empty) pub grammar: Option<ValidGrammar>, /// / logit bias + pub logit_bias: Option<Vec<(u32, f32)>>, } #[derive(Debug, Clone)] diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py index 64a285b9..b0dfe571 100644 --- a/server/text_generation_server/utils/logits_process.py +++ b/server/text_generation_server/utils/logits_process.py @@ -623,3 +623,65 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor): new_fsms.append(self.fsms[i]) self.fsms = new_fsms return self + + +class HeterogeneousLogitBiasProcessor: + """Process logits with different logit biases for each sequence in the batch.""" + + def __init__( + self, + logit_biases: List[Optional[dict]], + tokenizer: PreTrainedTokenizerBase, + device: torch.device, + ): + self.device = device + self.tokenizer = tokenizer + self.logit_biases = logit_biases + self.batch_size = len(logit_biases) + + # Pre-compute token IDs for each token string + self.token_id_mapping = {} + + # Create a mapping of indices that have logit biases + self.indices_with_biases = { + i: bias_dict + for i, bias_dict in enumerate(self.logit_biases) + if bias_dict is not None and len(bias_dict) > 0 + } + + def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: + # If no indices have biases, return scores unchanged + if not self.indices_with_biases: + return scores + + # For each index with a bias, apply the bias to the corresponding scores + for i, bias_dict in self.indices_with_biases.items(): + for 
token_str, bias_value in bias_dict.items(): + # Get token ID, either from cache or by computing it + if token_str not in self.token_id_mapping: + if token_str.isdigit(): + # If the token string is already a numeric ID + token_id = int(token_str) + else: + # Otherwise, use the tokenizer to get the ID + tokens = self.tokenizer.encode( + token_str, add_special_tokens=False + ) + token_id = tokens[0] if tokens else -1 # Use -1 for not found + + self.token_id_mapping[token_str] = token_id + + token_id = self.token_id_mapping[token_str] + + # Apply bias if token ID is valid + if 0 <= token_id < scores.size(-1): + scores[i, token_id] += bias_value + + return scores + + def filter(self, indices: List[int]): + """Keep only the logit biases for the specified indices.""" + new_logit_biases = [self.logit_biases[i] for i in indices] + return HeterogeneousLogitBiasProcessor( + new_logit_biases, self.tokenizer, self.device + ) diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py index 9ab49665..8c916bfd 100644 --- a/server/text_generation_server/utils/tokens.py +++ b/server/text_generation_server/utils/tokens.py @@ -15,6 +15,7 @@ from text_generation_server.utils.logits_process import ( HeterogeneousTopPLogitsWarper, HeterogeneousTypicalLogitsWarper, HeterogeneousGrammarLogitProcessor, + HeterogeneousLogitBiasProcessor, static_warper, ) from text_generation_server.utils.watermark import WatermarkLogitsProcessor @@ -38,6 +39,7 @@ class NextTokenChooser: grammar: str = "", grammar_type: GrammarType = GrammarType.GRAMMAR_TYPE_NONE, fsm_grammar_state: int = 0, + logit_bias: Optional[dict] = None, ): self.watermark_processor = ( WatermarkLogitsProcessor(device=device) if watermark else None @@ -58,6 +60,7 @@ class NextTokenChooser: else None ) self.tokenizer = tokenizer + self.logit_bias = logit_bias has_warpers = ( (temperature is not None and temperature != 1.0) @@ -87,6 +90,8 @@ class NextTokenChooser: scores = self.frequency_processor(input_ids, scores) if self.grammar_processor is not None: scores = self.grammar_processor(scores, self.fsm_grammar_state) + if self.logit_bias_processor is not None: + scores = self.logit_bias_processor(input_ids, scores) if self.static_warper is None: next_logprob = torch.log_softmax(scores, -1) @@ -125,6 +130,7 @@ class NextTokenChooser: tokenizer=tokenizer, grammar=pb.grammar, grammar_type=pb.grammar_type, + logit_bias=dict(pb.logit_bias) if pb.logit_bias else None, ) @@ -248,9 +254,14 @@ class HeterogeneousNextTokenChooser: grammars: List[str], grammar_types: List[int], fsm_grammar_states=List[int], + logit_biases: List[Optional[dict]] = None, ): warpers = [] + # Initialize with empty logit biases if none provided + if logit_biases is None: + logit_biases = [None] * len(do_sample) + self.watermark_processor = ( HeterogeneousProcessorWrapper( { @@ -287,6 +298,12 @@ class HeterogeneousNextTokenChooser: else None ) + self.logit_bias_processor = ( + HeterogeneousLogitBiasProcessor(logit_biases, tokenizer, device) + if any([bias is not None and len(bias) > 0 for bias in logit_biases]) + else None + ) + if any(x != 1.0 for x in temperature): do_sample = [ sample or x != 1.0 for x, sample in zip(temperature, do_sample) @@ -322,6 +339,7 @@ class HeterogeneousNextTokenChooser: self.fsm_grammar_states = fsm_grammar_states self.grammars = grammars self.grammar_types = grammar_types + self.logit_biases = logit_biases def __call__( self, @@ -353,6 +371,8 @@ class HeterogeneousNextTokenChooser: _scores = 
self.frequency_processor(input_ids, _scores) if self.grammar_processor is not None: _scores = self.grammar_processor(_scores, self.fsm_grammar_states) + if self.logit_bias_processor is not None: + _scores = self.logit_bias_processor(input_ids, _scores) for warper in self.warpers: _scores = warper(input_ids, _scores) _next_ids = self.choice(_scores) @@ -444,6 +464,9 @@ class HeterogeneousNextTokenChooser: if self.grammar_processor is not None: self.grammar_processor = self.grammar_processor.filter(indices) + if self.logit_bias_processor is not None: + self.logit_bias_processor = self.logit_bias_processor.filter(indices) + filtered_warpers = [] for warper in self.warpers: filtered_warper = warper.filter(indices) @@ -453,6 +476,7 @@ class HeterogeneousNextTokenChooser: self.seeds = [self.seeds[i] for i in indices] self.do_sample = [self.do_sample[i] for i in indices] + self.logit_biases = [self.logit_biases[i] for i in indices] new_grammars = [] new_fsm_grammar_states = [] @@ -500,6 +524,9 @@ class HeterogeneousNextTokenChooser: fsm_grammar_states=( fsm_grammar_states if fsm_grammar_states else [0] * len(pb) ), + logit_biases=[ + dict(pb_.logit_bias) if pb_.logit_bias else None for pb_ in pb + ], ) From da3f18e5c85336fb90bfc723cdc14e68d59f2398 Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 22 Apr 2025 16:26:56 +0000 Subject: [PATCH 02/14] feat: include proto changes --- proto/v3/generate.proto | 2 ++ 1 file changed, 2 insertions(+) diff --git a/proto/v3/generate.proto b/proto/v3/generate.proto index 02980b6f..1018e709 100644 --- a/proto/v3/generate.proto +++ b/proto/v3/generate.proto @@ -104,6 +104,8 @@ message NextTokenChooserParameters { string grammar = 10; /// grammar type GrammarType grammar_type = 11; + /// logit bias dictionary mapping token string to bias value + map<string, int32> logit_bias = 12; } message StoppingCriteriaParameters { From e44703d542461e7f41c2363d0409bc2d804ae6ba Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 22 Apr 2025 16:36:34 +0000 Subject: [PATCH 03/14] fix: adjust the NextTokenChooser logit bias processor --- .../utils/logits_process.py | 47 +++++++++++++++++++ server/text_generation_server/utils/tokens.py | 6 +++ 2 files changed, 53 insertions(+) diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py index b0dfe571..9f14b411 100644 --- a/server/text_generation_server/utils/logits_process.py +++ b/server/text_generation_server/utils/logits_process.py @@ -625,6 +625,53 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor): return self + +class LogitBiasProcessor: + """Process logits with logit biases.""" + + def __init__( + self, logit_biases: Optional[dict], tokenizer: PreTrainedTokenizerBase + ): + self.tokenizer = tokenizer + self.logit_biases = logit_biases or {} + + # Pre-compute token IDs for each token string + self.token_id_mapping = {} + + def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: + # If no logit biases, return scores unchanged + if not self.logit_biases: + return scores + + # Apply bias to the corresponding scores + for token_str, bias_value in self.logit_biases.items(): + # Get token ID, either from cache or by computing it + if token_str not in self.token_id_mapping: + if token_str.isdigit(): + # If the token string is already a numeric ID + token_id = int(token_str) + else: + # Otherwise, use the tokenizer to get the ID + tokens = self.tokenizer.encode(token_str, add_special_tokens=False) + token_id = tokens[0] if tokens else -1 # Use -1 for not found 
+ + self.token_id_mapping[token_str] = token_id + + token_id = self.token_id_mapping[token_str] + + # Apply bias if token ID is valid + if 0 <= token_id < scores.size(-1): + scores[:, token_id] += bias_value + + return scores + + def filter(self, indices): + """Keep only the logit biases for the specified indices.""" + new_logit_biases = { + k: self.logit_biases[k] for k in indices if k in self.logit_biases + } + return LogitBiasProcessor(new_logit_biases, self.tokenizer) + + class HeterogeneousLogitBiasProcessor: """Process logits with different logit biases for each sequence in the batch.""" diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py index 8c916bfd..eeca7273 100644 --- a/server/text_generation_server/utils/tokens.py +++ b/server/text_generation_server/utils/tokens.py @@ -7,6 +7,7 @@ from text_generation_server.pb.generate_pb2 import FinishReason, GrammarType from text_generation_server.utils.logits_process import ( FrequencyPenaltyLogitsProcessor, GrammarLogitProcessor, + LogitBiasProcessor, HeterogeneousProcessorWrapper, HeterogeneousRepetitionPenaltyLogitsProcessor, HeterogeneousFrequencyPenaltyLogitsProcessor, @@ -59,6 +60,11 @@ class NextTokenChooser: if grammar != "" else None ) + self.logit_bias_processor = ( + LogitBiasProcessor(logit_bias, tokenizer, device) + if logit_bias is not None and len(logit_bias) > 0 + else None + ) self.tokenizer = tokenizer self.logit_bias = logit_bias From 61a50a81c08cf02aff4d4c7e50a2fd293e68477b Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 22 Apr 2025 16:46:59 +0000 Subject: [PATCH 04/14] fix: include logit_bias in all ValidGenerateRequest's --- backends/v2/src/queue.rs | 2 ++ backends/v3/src/queue.rs | 2 ++ 2 files changed, 4 insertions(+) diff --git a/backends/v2/src/queue.rs b/backends/v2/src/queue.rs index c9a9335d..788ecee6 100644 --- a/backends/v2/src/queue.rs +++ b/backends/v2/src/queue.rs @@ -12,6 +12,7 @@ use text_generation_router::validation::{ use tokio::sync::{mpsc, oneshot}; use tokio::time::Instant; use tracing::{info_span, instrument, Span}; +use std::collections::HashMap; /// Queue entry #[derive(Debug)] @@ -429,6 +430,7 @@ mod tests { frequency_penalty: 0.0, watermark: false, grammar: None, + logit_bias: HashMap::new(), }, stopping_parameters: ValidStoppingParameters { ignore_eos_token: false, diff --git a/backends/v3/src/queue.rs b/backends/v3/src/queue.rs index 4c143deb..fe8963b8 100644 --- a/backends/v3/src/queue.rs +++ b/backends/v3/src/queue.rs @@ -16,6 +16,7 @@ use text_generation_router::validation::{ use tokio::sync::{mpsc, oneshot}; use tokio::time::Instant; use tracing::{info_span, instrument, Instrument, Span}; +use std::collections::HashMap; /// Queue entry #[derive(Debug)] @@ -577,6 +578,7 @@ mod tests { frequency_penalty: 0.0, watermark: false, grammar: None, + logit_bias: HashMap::new(), }, stopping_parameters: ValidStoppingParameters { ignore_eos_token: false, From 81656bd016ef795441088e13d2724a23fc504c47 Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 22 Apr 2025 17:07:28 +0000 Subject: [PATCH 05/14] fix: adjust imports --- backends/v2/src/queue.rs | 2 +- backends/v3/src/queue.rs | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/backends/v2/src/queue.rs b/backends/v2/src/queue.rs index 788ecee6..b5ec0742 100644 --- a/backends/v2/src/queue.rs +++ b/backends/v2/src/queue.rs @@ -12,7 +12,6 @@ use text_generation_router::validation::{ use tokio::sync::{mpsc, oneshot}; use tokio::time::Instant; use tracing::{info_span, instrument, Span}; 
-use std::collections::HashMap; /// Queue entry #[derive(Debug)] @@ -402,6 +401,7 @@ impl From for StoppingCriteriaParameters { #[cfg(test)] mod tests { use super::*; + use std::collections::HashMap; use std::sync::Arc; use tracing::info_span; diff --git a/backends/v3/src/queue.rs b/backends/v3/src/queue.rs index fe8963b8..6cd2fd76 100644 --- a/backends/v3/src/queue.rs +++ b/backends/v3/src/queue.rs @@ -16,7 +16,6 @@ use text_generation_router::validation::{ use tokio::sync::{mpsc, oneshot}; use tokio::time::Instant; use tracing::{info_span, instrument, Instrument, Span}; -use std::collections::HashMap; /// Queue entry #[derive(Debug)] From bb5c875f0b1cc16918b85d61422e1e6326dbeed4 Mon Sep 17 00:00:00 2001 From: drbh Date: Wed, 23 Apr 2025 17:06:22 +0000 Subject: [PATCH 06/14] fix: remove deprecated test and fix typing --- .github/workflows/client-tests.yaml | 26 -------------------------- backends/v2/src/queue.rs | 3 +-- backends/v3/src/queue.rs | 2 +- 3 files changed, 2 insertions(+), 29 deletions(-) delete mode 100644 .github/workflows/client-tests.yaml diff --git a/.github/workflows/client-tests.yaml b/.github/workflows/client-tests.yaml deleted file mode 100644 index ff2928c4..00000000 --- a/.github/workflows/client-tests.yaml +++ /dev/null @@ -1,26 +0,0 @@ -name: Python Client Tests - -on: - pull_request: - paths: - - ".github/workflows/client-tests.yaml" - - "clients/python/**" - -jobs: - run_tests: - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v2 - - name: Set up Python - uses: actions/setup-python@v1 - with: - python-version: 3.9 - - name: Install - run: | - cd clients/python && pip install . - - name: Run tests - run: | - pip install pytest pytest-asyncio - export HF_TOKEN=${{ secrets.HF_TOKEN }} - make python-client-tests diff --git a/backends/v2/src/queue.rs b/backends/v2/src/queue.rs index b5ec0742..0e40fdf2 100644 --- a/backends/v2/src/queue.rs +++ b/backends/v2/src/queue.rs @@ -401,7 +401,6 @@ impl From for StoppingCriteriaParameters { #[cfg(test)] mod tests { use super::*; - use std::collections::HashMap; use std::sync::Arc; use tracing::info_span; @@ -430,7 +429,7 @@ mod tests { frequency_penalty: 0.0, watermark: false, grammar: None, - logit_bias: HashMap::new(), + logit_bias: None, }, stopping_parameters: ValidStoppingParameters { ignore_eos_token: false, diff --git a/backends/v3/src/queue.rs b/backends/v3/src/queue.rs index 6cd2fd76..95e02109 100644 --- a/backends/v3/src/queue.rs +++ b/backends/v3/src/queue.rs @@ -577,7 +577,7 @@ mod tests { frequency_penalty: 0.0, watermark: false, grammar: None, - logit_bias: HashMap::new(), + logit_bias: None, }, stopping_parameters: ValidStoppingParameters { ignore_eos_token: false, From 9eeccbf9a5273f23f66c22ca9c9d9042f7ff77d5 Mon Sep 17 00:00:00 2001 From: drbh Date: Mon, 28 Apr 2025 13:44:37 +0000 Subject: [PATCH 07/14] fix: improve processor logic and refactor --- router/src/server.rs | 1 - router/src/validation.rs | 59 ++++++-- .../utils/logits_process.py | 138 ++++++++---------- server/text_generation_server/utils/tokens.py | 13 +- 4 files changed, 110 insertions(+), 101 deletions(-) diff --git a/router/src/server.rs b/router/src/server.rs index 4f9fdc87..42bfeb7c 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1192,7 +1192,6 @@ pub(crate) async fn chat_completions( let (generate_request, using_tools): (GenerateRequest, bool) = chat.clone().try_into_generate(&infer)?; span.record("parameters", format!("{:?}", generate_request.parameters)); - println!("ChatRequest: {:#?}", generate_request); let 
logprobs = logprobs.unwrap_or_default(); // extract model id from request if specified diff --git a/router/src/validation.rs b/router/src/validation.rs index 6379b145..bda19224 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -34,6 +34,7 @@ pub struct Validation { max_input_length: usize, max_total_tokens: usize, disable_grammar_support: bool, + vocab_size: u32, /// Channel to communicate with the background tokenization task sender: mpsc::UnboundedSender<TokenizerRequest>, } @@ -88,6 +89,19 @@ impl Validation { validation_sender }; + let vocab_size = match &tokenizer { + Tokenizer::Python { tokenizer_name, .. } => { + warn!( + "Tokenizer {} is not supported for validation", + tokenizer_name + ); + 0 + } + Tokenizer::Rust(tokenizer) => tokenizer.get_vocab_size(false), + } + .try_into() + .unwrap_or(0); + Self { max_best_of, sender, @@ -96,6 +110,7 @@ impl Validation { max_input_length, max_total_tokens, disable_grammar_support, + vocab_size, } } @@ -409,6 +424,35 @@ impl Validation { None => None, }; + let logit_bias = match &request.parameters.logit_bias { + Some(bias) if !bias.is_empty() => { + for (token_str, _) in bias.iter() { + let token_id = token_str.parse::<u32>().map_err(|_| { + ValidationError::LogitBiasInvalid(format!( + "Token ID {} is not a valid number.", + token_str + )) + })?; + + if token_id >= self.vocab_size { + return Err(ValidationError::LogitBiasInvalid(format!( + "Token ID {} is out of range. Must be between 0 and {}.", + token_id, + self.vocab_size - 1 + ))); + } + } + + // Transform into the required format + Some( + bias.iter() + .map(|(k, v)| (k.parse::<u32>().unwrap(), *v as f32)) + .collect(), + ) + } + _ => None, + }; + let parameters = ValidParameters { temperature, repetition_penalty, @@ -420,18 +464,7 @@ impl Validation { seed, watermark, grammar, - logit_bias: Some( - request - .parameters - .logit_bias - .iter() - .flat_map(|bias| { - bias.iter() - .map(|(k, v)| (k.parse::<u32>().unwrap(), *v as f32)) - .collect::<Vec<_>>() - }) - .collect(), - ), + logit_bias, }; let stopping_parameters = ValidStoppingParameters { max_new_tokens, @@ -1011,6 +1044,8 @@ pub enum ValidationError { FailedFetchImage(#[from] reqwest::Error), #[error("{0} modality is not supported")] UnsupportedModality(&'static str), + #[error("logit_bias is not valid: {0}")] + LogitBiasInvalid(String), } #[cfg(test)] diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py index 9f14b411..ad769990 100644 --- a/server/text_generation_server/utils/logits_process.py +++ b/server/text_generation_server/utils/logits_process.py @@ -625,55 +625,49 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor): return self -class LogitBiasProcessor: - """Process logits with logit biases.""" +class LogitBiasProcessor(LogitsProcessor): + """ + `LogitsProcessor` creates a bias tensor from a dictionary of token IDs and their + corresponding bias values. Bias are applied to the logits during each forward pass. + + Supports token IDs provided as strings (e.g., {"9707": -100}). 
+ """ def __init__( - self, logit_biases: Optional[dict], tokenizer: PreTrainedTokenizerBase + self, + logit_biases: dict, + tokenizer: PreTrainedTokenizerBase, + device: torch.device, ): - self.tokenizer = tokenizer - self.logit_biases = logit_biases or {} + assert logit_biases, "LogitBiasProcessor requires non-empty logit_biases" - # Pre-compute token IDs for each token string - self.token_id_mapping = {} + vocab_size = len(tokenizer) + + # Convert keys to integers and values to a list + token_ids = torch.tensor( + [int(k) for k in logit_biases.keys()], dtype=torch.long + ) + bias_values = torch.tensor(list(logit_biases.values()), dtype=torch.float) + + # Create a tensor and directly copy bias values at the corresponding indices + self.bias_tensor = torch.zeros(vocab_size, dtype=torch.float) + self.bias_tensor.index_put_((token_ids,), bias_values, accumulate=True) def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: - # If no logit biases, return scores unchanged - if not self.logit_biases: - return scores - - # Apply bias to the corresponding scores - for token_str, bias_value in self.logit_biases.items(): - # Get token ID, either from cache or by computing it - if token_str not in self.token_id_mapping: - if token_str.isdigit(): - # If the token string is already a numeric ID - token_id = int(token_str) - else: - # Otherwise, use the tokenizer to get the ID - tokens = self.tokenizer.encode(token_str, add_special_tokens=False) - token_id = tokens[0] if tokens else -1 # Use -1 for not found - - self.token_id_mapping[token_str] = token_id - - token_id = self.token_id_mapping[token_str] - - # Apply bias if token ID is valid - if 0 <= token_id < scores.size(-1): - scores[:, token_id] += bias_value - + # Apply bias tensor as a broadcasted addition + if self.bias_tensor.shape[0] != scores.shape[1]: + # Fix if the bias tensor is smaller than the scores + self.bias_tensor = torch.nn.functional.pad( + self.bias_tensor, (0, scores.shape[1] - self.bias_tensor.shape[0]) + ) + scores.add_(self.bias_tensor.to(device=scores.device, dtype=scores.dtype)) return scores - def filter(self, indices): - """Keep only the logit biases for the specified indices.""" - new_logit_biases = { - k: self.logit_biases[k] for k in indices if k in self.logit_biases - } - return LogitBiasProcessor(new_logit_biases, self.tokenizer) - -class HeterogeneousLogitBiasProcessor: - """Process logits with different logit biases for each sequence in the batch.""" +class HeterogeneousLogitBiasProcessor(LogitsProcessor): + """ + Process logits with different logit biases for each sequence in the batch. 
+ """ def __init__( self, @@ -681,54 +675,42 @@ class HeterogeneousLogitBiasProcessor: tokenizer: PreTrainedTokenizerBase, device: torch.device, ): - self.device = device self.tokenizer = tokenizer self.logit_biases = logit_biases - self.batch_size = len(logit_biases) + # import ipdb; ipdb.set_trace() + self.vocab_size = len(tokenizer) - # Pre-compute token IDs for each token string - self.token_id_mapping = {} + # Create batch_size x vocab_size bias matrix + self.bias_matrix = torch.zeros( + (len(logit_biases), self.vocab_size), dtype=torch.float, device=device + ) - # Create a mapping of indices that have logit biases - self.indices_with_biases = { - i: bias_dict - for i, bias_dict in enumerate(self.logit_biases) - if bias_dict is not None and len(bias_dict) > 0 - } + # for each logit bias dictionary, convert keys to integers and values to a list + for i, logit_bias in enumerate(logit_biases): + token_ids = torch.tensor( + [int(k) for k in logit_bias.keys()], dtype=torch.long + ).to(device=device) + bias_values = torch.tensor(list(logit_bias.values()), dtype=torch.float).to( + device=device + ) + # Create a tensor and directly copy bias values at the corresponding indices + self.bias_matrix[i].index_put_((token_ids,), bias_values, accumulate=True) def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: - # If no indices have biases, return scores unchanged - if not self.indices_with_biases: - return scores - - # For each index with a bias, apply the bias to the corresponding scores - for i, bias_dict in self.indices_with_biases.items(): - for token_str, bias_value in bias_dict.items(): - # Get token ID, either from cache or by computing it - if token_str not in self.token_id_mapping: - if token_str.isdigit(): - # If the token string is already a numeric ID - token_id = int(token_str) - else: - # Otherwise, use the tokenizer to get the ID - tokens = self.tokenizer.encode( - token_str, add_special_tokens=False - ) - token_id = tokens[0] if tokens else -1 # Use -1 for not found - - self.token_id_mapping[token_str] = token_id - - token_id = self.token_id_mapping[token_str] - - # Apply bias if token ID is valid - if 0 <= token_id < scores.size(-1): - scores[i, token_id] += bias_value + # Apply bias matrix as a broadcasted addition + if self.bias_matrix.shape[1] != scores.shape[1]: + # Fix if the bias matrix is smaller than the scores + self.bias_matrix = torch.nn.functional.pad( + self.bias_matrix, (0, scores.shape[1] - self.bias_matrix.shape[1]) + ) + scores.add_(self.bias_matrix.to(device=scores.device, dtype=scores.dtype)) return scores - def filter(self, indices: List[int]): - """Keep only the logit biases for the specified indices.""" + def filter(self, indices): new_logit_biases = [self.logit_biases[i] for i in indices] + if not any(bias and len(bias) > 0 for bias in new_logit_biases): + return None return HeterogeneousLogitBiasProcessor( new_logit_biases, self.tokenizer, self.device ) diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py index eeca7273..fa982c30 100644 --- a/server/text_generation_server/utils/tokens.py +++ b/server/text_generation_server/utils/tokens.py @@ -66,7 +66,6 @@ class NextTokenChooser: else None ) self.tokenizer = tokenizer - self.logit_bias = logit_bias has_warpers = ( (temperature is not None and temperature != 1.0) @@ -136,7 +135,7 @@ class NextTokenChooser: tokenizer=tokenizer, grammar=pb.grammar, grammar_type=pb.grammar_type, - logit_bias=dict(pb.logit_bias) if pb.logit_bias 
else None, + logit_bias=pb.logit_bias, ) @@ -264,10 +263,6 @@ class HeterogeneousNextTokenChooser: ): warpers = [] - # Initialize with empty logit biases if none provided - if logit_biases is None: - logit_biases = [None] * len(do_sample) - self.watermark_processor = ( HeterogeneousProcessorWrapper( { @@ -306,7 +301,7 @@ class HeterogeneousNextTokenChooser: self.logit_bias_processor = ( HeterogeneousLogitBiasProcessor(logit_biases, tokenizer, device) - if any([bias is not None and len(bias) > 0 for bias in logit_biases]) + if any([logit_bias is not None for logit_bias in logit_biases]) else None ) @@ -530,9 +525,7 @@ class HeterogeneousNextTokenChooser: fsm_grammar_states=( fsm_grammar_states if fsm_grammar_states else [0] * len(pb) ), - logit_biases=[ - dict(pb_.logit_bias) if pb_.logit_bias else None for pb_ in pb - ], + logit_biases=[pb_.logit_bias for pb_ in pb], ) From b3ead6e95994cfbff11f131f2efe75d88eb04e7d Mon Sep 17 00:00:00 2001 From: drbh Date: Mon, 28 Apr 2025 13:59:23 +0000 Subject: [PATCH 08/14] fix: cleanup typos --- server/text_generation_server/utils/logits_process.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py index ad769990..d106ce43 100644 --- a/server/text_generation_server/utils/logits_process.py +++ b/server/text_generation_server/utils/logits_process.py @@ -627,7 +627,7 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor): class LogitBiasProcessor(LogitsProcessor): """ - `LogitsProcessor` creates a bias tensor from a dictionary of token IDs and their + `LogitBiasProcessor` creates a bias tensor from a dictionary of token IDs and their corresponding bias values. Bias are applied to the logits during each forward pass. Supports token IDs provided as strings (e.g., {"9707": -100}). 
@@ -656,7 +656,7 @@ class LogitBiasProcessor(LogitsProcessor): def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: # Apply bias tensor as a broadcasted addition if self.bias_tensor.shape[0] != scores.shape[1]: - # Fix if the bias tensor is smaller than the scores + # Pad the bias matrix to match the scores if it's smaller self.bias_tensor = torch.nn.functional.pad( self.bias_tensor, (0, scores.shape[1] - self.bias_tensor.shape[0]) ) @@ -699,7 +699,7 @@ class HeterogeneousLogitBiasProcessor(LogitsProcessor): def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: # Apply bias matrix as a broadcasted addition if self.bias_matrix.shape[1] != scores.shape[1]: - # Fix if the bias matrix is smaller than the scores + # Pad the bias matrix to match the scores if it's smaller self.bias_matrix = torch.nn.functional.pad( self.bias_matrix, (0, scores.shape[1] - self.bias_matrix.shape[1]) ) From 465294d3de8c21bfaf83a47add66585d1ec48ea9 Mon Sep 17 00:00:00 2001 From: drbh Date: Wed, 30 Apr 2025 20:12:05 +0000 Subject: [PATCH 09/14] fix: avoid zero'd logit bias mask --- server/text_generation_server/utils/tokens.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py index fa982c30..e4a2698e 100644 --- a/server/text_generation_server/utils/tokens.py +++ b/server/text_generation_server/utils/tokens.py @@ -301,7 +301,7 @@ class HeterogeneousNextTokenChooser: self.logit_bias_processor = ( HeterogeneousLogitBiasProcessor(logit_biases, tokenizer, device) - if any([logit_bias is not None for logit_bias in logit_biases]) + if any(logit_bias for logit_bias in logit_biases) else None ) From 7659925d8552173ccb74bb99d69b093dc1499b8a Mon Sep 17 00:00:00 2001 From: drbh Date: Mon, 5 May 2025 13:59:02 -0400 Subject: [PATCH 10/14] fix: improve validation and transform logic --- router/src/validation.rs | 54 +++++++++++++++++++++------------------- 1 file changed, 28 insertions(+), 26 deletions(-) diff --git a/router/src/validation.rs b/router/src/validation.rs index bda19224..b3d4dd9a 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -424,34 +424,36 @@ impl Validation { None => None, }; - let logit_bias = match &request.parameters.logit_bias { - Some(bias) if !bias.is_empty() => { - for (token_str, _) in bias.iter() { - let token_id = token_str.parse::().map_err(|_| { - ValidationError::LogitBiasInvalid(format!( - "Token ID {} is not a valid number.", - token_str - )) - })?; + // Validate logit bias and convert to a vector of (token_id, bias_value) + let logit_bias = request + .parameters + .logit_bias + .as_ref() + .filter(|bias_map| !bias_map.is_empty()) + .map(|bias_map| { + bias_map + .iter() + .map(|(token_str, &bias_value)| { + let token_id: u32 = token_str.parse().map_err(|_| { + ValidationError::LogitBiasInvalid(format!( + "Token ID {token_str} is not a valid number." + )) + })?; - if token_id >= self.vocab_size { - return Err(ValidationError::LogitBiasInvalid(format!( - "Token ID {} is out of range. 
Must be between 0 and {}.", - token_id, - self.vocab_size - 1 - ))); - } - } + if token_id >= self.vocab_size { + return Err(ValidationError::LogitBiasInvalid(format!( + "Token ID {token_id} is out of range (0..{}).", + self.vocab_size - 1 + ))); + } - // Transform into the required format - Some( - bias.iter() - .map(|(k, v)| (k.parse::<u32>().unwrap(), *v as f32)) - .collect(), - ) - } - _ => None, - }; + Ok((token_id, bias_value as f32)) + }) + .collect::<Result<Vec<_>, _>>() + }) + // convert Option<Result<Vec<_>, E>> to Result<Option<Vec<_>>, E> to throw + // if any of the token IDs are invalid + .transpose()?; let parameters = ValidParameters { temperature, From 55d82d4654fe367d166c944e7584b4b53b43a3bf Mon Sep 17 00:00:00 2001 From: drbh Date: Mon, 5 May 2025 21:45:18 +0000 Subject: [PATCH 11/14] fix: remove the bias padding --- .../text_generation_server/utils/logits_process.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py index d106ce43..d9feb953 100644 --- a/server/text_generation_server/utils/logits_process.py +++ b/server/text_generation_server/utils/logits_process.py @@ -654,12 +654,6 @@ class LogitBiasProcessor(LogitsProcessor): self.bias_tensor.index_put_((token_ids,), bias_values, accumulate=True) def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: - # Apply bias tensor as a broadcasted addition - if self.bias_tensor.shape[0] != scores.shape[1]: - # Pad the bias matrix to match the scores if it's smaller - self.bias_tensor = torch.nn.functional.pad( - self.bias_tensor, (0, scores.shape[1] - self.bias_tensor.shape[0]) - ) scores.add_(self.bias_tensor.to(device=scores.device, dtype=scores.dtype)) return scores @@ -697,13 +691,6 @@ class HeterogeneousLogitBiasProcessor(LogitsProcessor): def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: - # Apply bias matrix as a broadcasted addition - if self.bias_matrix.shape[1] != scores.shape[1]: - # Pad the bias matrix to match the scores if it's smaller - self.bias_matrix = torch.nn.functional.pad( - self.bias_matrix, (0, scores.shape[1] - self.bias_matrix.shape[1]) - ) - scores.add_(self.bias_matrix.to(device=scores.device, dtype=scores.dtype)) return scores From b32cd97b71e9f4331fa7c5356667ffea67076a9a Mon Sep 17 00:00:00 2001 From: drbh Date: Mon, 5 May 2025 23:39:24 +0000 Subject: [PATCH 12/14] fix: read vocab size from tokenizer and add hacky patch for qwen2b --- .../test_logit_bias_english_to_spanish.json | 2 +- .../test_logit_bias_multiple_tokens.json | 2 +- .../test_logit_bias_streaming.json | 2 +- .../text_generation_server/models/flash_causal_lm.py | 9 +++++++++ .../text_generation_server/utils/logits_process.py | 12 ++++++++---- 5 files changed, 20 insertions(+), 7 deletions(-) diff --git a/integration-tests/models/__snapshots__/test_flash_logit_bias/test_logit_bias_english_to_spanish.json b/integration-tests/models/__snapshots__/test_flash_logit_bias/test_logit_bias_english_to_spanish.json index 3c4369f1..98a90c07 100644 --- a/integration-tests/models/__snapshots__/test_flash_logit_bias/test_logit_bias_english_to_spanish.json +++ b/integration-tests/models/__snapshots__/test_flash_logit_bias/test_logit_bias_english_to_spanish.json @@ -13,7 +13,7 @@ "usage": null } ], - "created": 1745337456, + "created": 1746486174, "id": "", "model": "Qwen/Qwen2-VL-2B-Instruct", "object": "chat.completion", diff --git 
a/integration-tests/models/__snapshots__/test_flash_logit_bias/test_logit_bias_multiple_tokens.json b/integration-tests/models/__snapshots__/test_flash_logit_bias/test_logit_bias_multiple_tokens.json index ea8f3a83..47b48462 100644 --- a/integration-tests/models/__snapshots__/test_flash_logit_bias/test_logit_bias_multiple_tokens.json +++ b/integration-tests/models/__snapshots__/test_flash_logit_bias/test_logit_bias_multiple_tokens.json @@ -13,7 +13,7 @@ "usage": null } ], - "created": 1745337878, + "created": 1746486174, "id": "", "model": "Qwen/Qwen2-VL-2B-Instruct", "object": "chat.completion", diff --git a/integration-tests/models/__snapshots__/test_flash_logit_bias/test_logit_bias_streaming.json b/integration-tests/models/__snapshots__/test_flash_logit_bias/test_logit_bias_streaming.json index 1668717e..c9288d1c 100644 --- a/integration-tests/models/__snapshots__/test_flash_logit_bias/test_logit_bias_streaming.json +++ b/integration-tests/models/__snapshots__/test_flash_logit_bias/test_logit_bias_streaming.json @@ -11,7 +11,7 @@ "logprobs": null } ], - "created": 1745337495, + "created": 1746486174, "id": "", "model": "Qwen/Qwen2-VL-2B-Instruct", "object": "chat.completion.chunk", diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index a28ef381..a383cc88 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -1267,6 +1267,15 @@ class FlashCausalLM(Model): prefix = None model = model_class(prefix, config, weights) + + if model.config.vocab_size != tokenizer.vocab_size: + logger.warning( + f"Tokenizer vocab size {tokenizer.vocab_size} does not match model vocab size {model.config.vocab_size}. Updating tokenizer vocab size." + ) + # TODO: HUGE HACK! This is a workaround for the fact that Qwen2TokenizerFast + # returns the incorrect vocab size for the 2B model. 
+ tokenizer._vocab_size = model.config.vocab_size + torch.distributed.barrier(group=self.process_group) # VLM models define the config we care about in their text_config diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py index d9feb953..36638c37 100644 --- a/server/text_generation_server/utils/logits_process.py +++ b/server/text_generation_server/utils/logits_process.py @@ -641,7 +641,8 @@ class LogitBiasProcessor(LogitsProcessor): ): assert logit_biases, "LogitBiasProcessor requires non-empty logit_biases" - vocab_size = len(tokenizer) + # use _vocab_size or fallback to tokenizer.vocab_size if not available + self.vocab_size = getattr(tokenizer, "_vocab_size", tokenizer.vocab_size) # Convert keys to integers and values to a list token_ids = torch.tensor( @@ -650,7 +651,7 @@ class LogitBiasProcessor(LogitsProcessor): bias_values = torch.tensor(list(logit_biases.values()), dtype=torch.float) # Create a tensor and directly copy bias values at the corresponding indices - self.bias_tensor = torch.zeros(vocab_size, dtype=torch.float) + self.bias_tensor = torch.zeros(self.vocab_size, dtype=torch.float) self.bias_tensor.index_put_((token_ids,), bias_values, accumulate=True) def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: @@ -669,10 +670,13 @@ class HeterogeneousLogitBiasProcessor(LogitsProcessor): tokenizer: PreTrainedTokenizerBase, device: torch.device, ): + assert logit_biases, "LogitBiasProcessor requires non-empty logit_biases" + self.tokenizer = tokenizer self.logit_biases = logit_biases - # import ipdb; ipdb.set_trace() - self.vocab_size = len(tokenizer) + + # use _vocab_size or fallback to tokenizer.vocab_size if not available + self.vocab_size = getattr(tokenizer, "_vocab_size", tokenizer.vocab_size) # Create batch_size x vocab_size bias matrix self.bias_matrix = torch.zeros( From 783ca669267f1a2060c0836d082e12a619883727 Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 6 May 2025 00:02:38 +0000 Subject: [PATCH 13/14] fix: prefer patch to be vlm specific --- server/text_generation_server/models/flash_causal_lm.py | 9 --------- server/text_generation_server/models/vlm_causal_lm.py | 9 +++++++++ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index a383cc88..a28ef381 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -1267,15 +1267,6 @@ class FlashCausalLM(Model): prefix = None model = model_class(prefix, config, weights) - - if model.config.vocab_size != tokenizer.vocab_size: - logger.warning( - f"Tokenizer vocab size {tokenizer.vocab_size} does not match model vocab size {model.config.vocab_size}. Updating tokenizer vocab size." - ) - # TODO: HUGE HACK! This is a workaround for the fact that Qwen2TokenizerFast - # returns the incorrect vocab size for the 2B model. 
- tokenizer._vocab_size = model.config.vocab_size - torch.distributed.barrier(group=self.process_group) # VLM models define the config we care about in their text_config diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 2b1e01df..42588d3b 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -414,6 +414,15 @@ class VlmCausalLM(FlashCausalLM): **kwargs, ) + if self.config.vocab_size != self.tokenizer.vocab_size: + logger.warning( + f"Tokenizer vocab size {self.tokenizer.vocab_size} does not match model vocab size {self.config.vocab_size}. Updating tokenizer vocab size." + ) + # TODO: HUGE HACK! This is a workaround to update the vocab size + # in the tokenizer. When the tokenizer is updated within the model + # the vocab size is not updated in the tokenizer. + self.tokenizer._vocab_size = self.config.vocab_size + @property def batch_type(self) -> Type[VlmCausalLMBatch]: return self.batch_class From 551ee3a365d40047de6354ec98d3f3bfd540fe85 Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 6 May 2025 00:03:17 +0000 Subject: [PATCH 14/14] fix: linter --- server/text_generation_server/models/vlm_causal_lm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 42588d3b..7bc7e2bb 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -418,7 +418,7 @@ class VlmCausalLM(FlashCausalLM): logger.warning( f"Tokenizer vocab size {self.tokenizer.vocab_size} does not match model vocab size {self.config.vocab_size}. Updating tokenizer vocab size." ) - # TODO: HUGE HACK! This is a workaround to update the vocab size + # TODO: HUGE HACK! This is a workaround to update the vocab size # in the tokenizer. When the tokenizer is updated within the model # the vocab size is not updated in the tokenizer. self.tokenizer._vocab_size = self.config.vocab_size