From fae510b8f6128bca2b864c5410fca54d2c25d78c Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 22 Apr 2025 16:09:32 +0000 Subject: [PATCH] feat: support logit bias in chat request --- backends/client/src/v3/client.rs | 2 + backends/client/src/v3/sharded_client.rs | 2 + backends/v3/src/client/grpc_client.rs | 2 + backends/v3/src/client/sharded_client.rs | 2 + backends/v3/src/queue.rs | 9 ++ benchmark/src/lib.rs | 1 + clients/python/text_generation/types.py | 4 +- docs/openapi.json | 21 +++- .../test_logit_bias_baseline.json | 26 +++++ .../test_logit_bias_english_to_spanish.json | 26 +++++ .../test_logit_bias_multiple_tokens.json | 26 +++++ .../test_logit_bias_streaming.json | 20 ++++ .../models/test_flash_logit_bias.py | 109 ++++++++++++++++++ router/src/lib.rs | 18 ++- router/src/server.rs | 2 + router/src/validation.rs | 14 +++ .../utils/logits_process.py | 62 ++++++++++ server/text_generation_server/utils/tokens.py | 27 +++++ 18 files changed, 363 insertions(+), 10 deletions(-) create mode 100644 integration-tests/models/__snapshots__/test_flash_logit_bias/test_logit_bias_baseline.json create mode 100644 integration-tests/models/__snapshots__/test_flash_logit_bias/test_logit_bias_english_to_spanish.json create mode 100644 integration-tests/models/__snapshots__/test_flash_logit_bias/test_logit_bias_multiple_tokens.json create mode 100644 integration-tests/models/__snapshots__/test_flash_logit_bias/test_logit_bias_streaming.json create mode 100644 integration-tests/models/test_flash_logit_bias.py diff --git a/backends/client/src/v3/client.rs b/backends/client/src/v3/client.rs index 968c1f45..b5dd081f 100644 --- a/backends/client/src/v3/client.rs +++ b/backends/client/src/v3/client.rs @@ -7,6 +7,7 @@ use grpc_metadata::InjectTelemetryContext; use pb::generate::v3::text_generation_service_client::TextGenerationServiceClient; use pb::generate::v3::*; use std::cmp::min; +use std::collections::HashMap; use std::time::Duration; use tonic::transport::{Channel, Uri}; use tracing::instrument; @@ -181,6 +182,7 @@ impl Client { watermark: true, grammar: String::new(), grammar_type: GrammarType::None as i32, + logit_bias: HashMap::new(), }), stopping_parameters: Some(StoppingCriteriaParameters { max_new_tokens, diff --git a/backends/client/src/v3/sharded_client.rs b/backends/client/src/v3/sharded_client.rs index dc3bcdde..99688668 100644 --- a/backends/client/src/v3/sharded_client.rs +++ b/backends/client/src/v3/sharded_client.rs @@ -5,6 +5,7 @@ use crate::{ClientError, Result}; use crate::v3::{Chunk, InfoResponse, Input}; use async_trait::async_trait; use futures::future::join_all; +use std::collections::HashMap; use tonic::transport::Uri; use tracing::instrument; use v3::client::{DecodeTimings, PrefillTimings}; @@ -244,6 +245,7 @@ impl Health for ShardedClient { watermark: false, grammar: String::new(), grammar_type: GrammarType::None as i32, + logit_bias: HashMap::new(), }), stopping_parameters: Some(StoppingCriteriaParameters { max_new_tokens: 1, diff --git a/backends/v3/src/client/grpc_client.rs b/backends/v3/src/client/grpc_client.rs index f4942f64..9d5e042a 100644 --- a/backends/v3/src/client/grpc_client.rs +++ b/backends/v3/src/client/grpc_client.rs @@ -7,6 +7,7 @@ use grpc_metadata::InjectTelemetryContext; use pb::generate::v3::text_generation_service_client::TextGenerationServiceClient; use pb::generate::v3::*; use std::cmp::min; +use std::collections::HashMap; use std::time::Duration; use tonic::transport::{Channel, Uri}; use tracing::instrument; @@ -181,6 +182,7 @@ impl Client { watermark: 
true,
                 grammar: String::new(),
                 grammar_type: GrammarType::None as i32,
+                logit_bias: HashMap::new(),
             }),
             stopping_parameters: Some(StoppingCriteriaParameters {
                 max_new_tokens,
diff --git a/backends/v3/src/client/sharded_client.rs b/backends/v3/src/client/sharded_client.rs
index 4701c560..b9667dc4 100644
--- a/backends/v3/src/client/sharded_client.rs
+++ b/backends/v3/src/client/sharded_client.rs
@@ -10,6 +10,7 @@ use crate::client::{
 use crate::client::{Chunk, InfoResponse, Input};
 use async_trait::async_trait;
 use futures::future::join_all;
+use std::collections::HashMap;
 use tonic::transport::Uri;
 use tracing::instrument;
 
@@ -232,6 +233,7 @@ impl Health for ShardedClient {
             watermark: false,
             grammar: String::new(),
             grammar_type: GrammarType::None as i32,
+            logit_bias: HashMap::new(),
         }),
         stopping_parameters: Some(StoppingCriteriaParameters {
             max_new_tokens: 1,
diff --git a/backends/v3/src/queue.rs b/backends/v3/src/queue.rs
index d3bf4b9c..4c143deb 100644
--- a/backends/v3/src/queue.rs
+++ b/backends/v3/src/queue.rs
@@ -5,6 +5,7 @@ use crate::client::{
 };
 use nohash_hasher::{BuildNoHashHasher, IntMap};
 use std::cmp::max;
+use std::collections::HashMap;
 use std::collections::VecDeque;
 use text_generation_router::infer::InferError;
 use text_generation_router::infer::InferStreamResponse;
@@ -522,6 +523,14 @@ impl From<ValidParameters> for NextTokenChooserParameters {
             watermark: value.watermark,
             grammar,
             grammar_type: grammar_type.into(),
+            logit_bias: value
+                .logit_bias
+                .map(|bias| {
+                    bias.into_iter()
+                        .map(|(token, bias)| (token.to_string(), bias as i32))
+                        .collect::<HashMap<String, i32>>()
+                })
+                .unwrap_or_default(),
         }
     }
 }
diff --git a/benchmark/src/lib.rs b/benchmark/src/lib.rs
index bb4b6a77..cd1f9446 100644
--- a/benchmark/src/lib.rs
+++ b/benchmark/src/lib.rs
@@ -47,6 +47,7 @@ pub async fn run(
         watermark,
         grammar: String::new(),
         grammar_type: GrammarType::None as i32,
+        logit_bias: std::collections::HashMap::new(),
     };
 
     // Initialize terminal properties
diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py
index 6f51c153..b9e00fae 100644
--- a/clients/python/text_generation/types.py
+++ b/clients/python/text_generation/types.py
@@ -1,6 +1,6 @@
 from enum import Enum
 from pydantic import BaseModel, field_validator, ConfigDict
-from typing import Optional, List, Union, Any
+from typing import Optional, List, Union, Any, Dict
 
 from text_generation.errors import ValidationError
 
@@ -137,7 +137,7 @@ class ChatRequest(BaseModel):
     # decreasing the model's likelihood to repeat the same line verbatim.
     frequency_penalty: Optional[float] = None
     # Bias values for token selection
-    logit_bias: Optional[List[float]] = None
+    logit_bias: Optional[Dict[str, int]] = None
     # Whether to return log probabilities
     logprobs: Optional[bool] = None
     # Number of most likely tokens to return at each position
diff --git a/docs/openapi.json b/docs/openapi.json
index 84ba5885..22145191 100644
--- a/docs/openapi.json
+++ b/docs/openapi.json
@@ -995,12 +995,12 @@
           "nullable": true
         },
         "logit_bias": {
-          "type": "array",
-          "items": {
-            "type": "number",
-            "format": "float"
+          "type": "object",
+          "description": "Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens\n(specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,\nthe bias is added to the logits generated by the model prior to sampling. 
The exact effect will vary per model,\nbut values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should\nresult in a ban or exclusive selection of the relevant token.", + "additionalProperties": { + "type": "integer", + "format": "int32" }, - "description": "UNUSED\nModify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens\n(specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,\nthe bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model,\nbut values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should\nresult in a ban or exclusive selection of the relevant token.", "nullable": true }, "logprobs": { @@ -1589,6 +1589,17 @@ "default": "null", "nullable": true }, + "logit_bias": { + "type": "object", + "description": "Modify the likelihood of specified tokens appearing in the completion.\nAccepts a hash map that maps token strings to an associated bias value.", + "default": "null", + "additionalProperties": { + "type": "integer", + "format": "int32" + }, + "example": "{\"1923\": 100, \"1924\": -100}", + "nullable": true + }, "max_new_tokens": { "type": "integer", "format": "int32", diff --git a/integration-tests/models/__snapshots__/test_flash_logit_bias/test_logit_bias_baseline.json b/integration-tests/models/__snapshots__/test_flash_logit_bias/test_logit_bias_baseline.json new file mode 100644 index 00000000..b683b090 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_logit_bias/test_logit_bias_baseline.json @@ -0,0 +1,26 @@ +{ + "choices": [ + { + "finish_reason": "length", + "index": 0, + "logprobs": null, + "message": { + "content": "Hello! How can I help you today?", + "name": null, + "role": "assistant", + "tool_calls": null + }, + "usage": null + } + ], + "created": 1745337495, + "id": "", + "model": "Qwen/Qwen2-VL-2B-Instruct", + "object": "chat.completion", + "system_fingerprint": "3.2.3-dev0-native", + "usage": { + "completion_tokens": 10, + "prompt_tokens": 21, + "total_tokens": 31 + } +} diff --git a/integration-tests/models/__snapshots__/test_flash_logit_bias/test_logit_bias_english_to_spanish.json b/integration-tests/models/__snapshots__/test_flash_logit_bias/test_logit_bias_english_to_spanish.json new file mode 100644 index 00000000..3c4369f1 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_logit_bias/test_logit_bias_english_to_spanish.json @@ -0,0 +1,26 @@ +{ + "choices": [ + { + "finish_reason": "length", + "index": 0, + "logprobs": null, + "message": { + "content": "¡Hola! 
¿Cómo puedo ayudarte?", + "name": null, + "role": "assistant", + "tool_calls": null + }, + "usage": null + } + ], + "created": 1745337456, + "id": "", + "model": "Qwen/Qwen2-VL-2B-Instruct", + "object": "chat.completion", + "system_fingerprint": "3.2.3-dev0-native", + "usage": { + "completion_tokens": 10, + "prompt_tokens": 21, + "total_tokens": 31 + } +} diff --git a/integration-tests/models/__snapshots__/test_flash_logit_bias/test_logit_bias_multiple_tokens.json b/integration-tests/models/__snapshots__/test_flash_logit_bias/test_logit_bias_multiple_tokens.json new file mode 100644 index 00000000..ea8f3a83 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_logit_bias/test_logit_bias_multiple_tokens.json @@ -0,0 +1,26 @@ +{ + "choices": [ + { + "finish_reason": "stop", + "index": 0, + "logprobs": null, + "message": { + "content": "Chat!", + "name": null, + "role": "assistant", + "tool_calls": null + }, + "usage": null + } + ], + "created": 1745337878, + "id": "", + "model": "Qwen/Qwen2-VL-2B-Instruct", + "object": "chat.completion", + "system_fingerprint": "3.2.3-dev0-native", + "usage": { + "completion_tokens": 3, + "prompt_tokens": 25, + "total_tokens": 28 + } +} diff --git a/integration-tests/models/__snapshots__/test_flash_logit_bias/test_logit_bias_streaming.json b/integration-tests/models/__snapshots__/test_flash_logit_bias/test_logit_bias_streaming.json new file mode 100644 index 00000000..1668717e --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_logit_bias/test_logit_bias_streaming.json @@ -0,0 +1,20 @@ +{ + "choices": [ + { + "delta": { + "content": "", + "role": "assistant", + "tool_calls": null + }, + "finish_reason": "length", + "index": 0, + "logprobs": null + } + ], + "created": 1745337495, + "id": "", + "model": "Qwen/Qwen2-VL-2B-Instruct", + "object": "chat.completion.chunk", + "system_fingerprint": "3.2.3-dev0-native", + "usage": null +} diff --git a/integration-tests/models/test_flash_logit_bias.py b/integration-tests/models/test_flash_logit_bias.py new file mode 100644 index 00000000..5ce793b7 --- /dev/null +++ b/integration-tests/models/test_flash_logit_bias.py @@ -0,0 +1,109 @@ +import pytest + + +@pytest.fixture(scope="module") +def logit_bias_model_handle(launcher): + with launcher("Qwen/Qwen2-VL-2B-Instruct") as handle: + yield handle + + +@pytest.fixture(scope="module") +async def logit_bias_model(logit_bias_model_handle): + await logit_bias_model_handle.health(300) + return logit_bias_model_handle.client + + +@pytest.mark.private +async def test_logit_bias_english_to_spanish(logit_bias_model, response_snapshot): + """Test that setting negative bias on English tokens forces output to be in Spanish""" + response = await logit_bias_model.chat( + seed=42, + max_tokens=10, + logit_bias={"9707": -100}, # Bias against 'Hello' token + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "say Hello"}, + ], + }, + ], + ) + assert "¡Hola!" 
in response.choices[0].message.content
+    assert "Hello" not in response.choices[0].message.content
+    assert response == response_snapshot
+
+
+@pytest.mark.private
+async def test_logit_bias_baseline(logit_bias_model, response_snapshot):
+    """Test baseline behavior without logit bias for comparison"""
+    response = await logit_bias_model.chat(
+        seed=42,
+        max_tokens=10,
+        messages=[
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": "say Hello"},
+                ],
+            },
+        ],
+    )
+    assert "Hello" in response.choices[0].message.content
+    assert response == response_snapshot
+
+
+@pytest.mark.private
+async def test_logit_bias_multiple_tokens(logit_bias_model, response_snapshot):
+    """Test applying bias to multiple tokens simultaneously"""
+    response = await logit_bias_model.chat(
+        seed=42,
+        max_tokens=15,
+        logit_bias={
+            "9707": -100,  # Bias against 'Hello' token
+            "2880": -100,  # Bias against 'hi' token
+        },
+        messages=[
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": "Give me a one-word greeting"},
+                ],
+            },
+        ],
+    )
+    assert "Hello" not in response.choices[0].message.content.lower()
+    assert "hi" not in response.choices[0].message.content.lower()
+    assert response == response_snapshot
+
+
+@pytest.mark.private
+async def test_logit_bias_streaming(logit_bias_model, response_snapshot):
+    """Test logit bias works correctly with streaming enabled"""
+    responses = await logit_bias_model.chat(
+        seed=42,
+        max_tokens=10,
+        logit_bias={"9707": -100},  # Bias against 'Hello' token
+        messages=[
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": "say Hello"},
+                ],
+            },
+        ],
+        stream=True,
+    )
+
+    count = 0
+    generated = ""
+    last_response = None
+
+    async for response in responses:
+        count += 1
+        generated += response.choices[0].delta.content
+        last_response = response
+
+    assert "¡Hola!" in generated
+    assert "Hello" not in generated
+    assert last_response == response_snapshot
diff --git a/router/src/lib.rs b/router/src/lib.rs
index e5622fc2..88ab86e9 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -18,6 +18,7 @@ use crate::infer::{Infer, InferError};
 use pyo3::prelude::*;
 use pyo3::types::IntoPyDict;
 use serde::{Deserialize, Serialize};
+use std::collections::HashMap;
 use tokenizers::Encoding;
 use tracing::warn;
 use utoipa::ToSchema;
@@ -431,6 +432,16 @@ pub(crate) struct GenerateParameters {
     #[serde(default)]
     #[schema(nullable = true, default = "null", example = "null")]
     pub adapter_id: Option<String>,
+
+    /// Modify the likelihood of specified tokens appearing in the completion.
+    /// Accepts a hash map that maps token strings to an associated bias value.
+    #[serde(default)]
+    #[schema(
+        nullable = true,
+        default = "null",
+        example = "{\"1923\": 100, \"1924\": -100}"
+    )]
+    pub logit_bias: Option<HashMap<String, i32>>,
 }
 
 fn default_parameters() -> GenerateParameters {
@@ -454,9 +465,9 @@ fn default_parameters() -> GenerateParameters {
         top_n_tokens: None,
         grammar: None,
         adapter_id: None,
+        logit_bias: None,
     }
 }
-
 #[derive(Clone, Deserialize, Serialize, ToSchema, Debug)]
 #[serde(try_from = "PromptDeserializer")]
 pub struct Prompt(pub Vec<String>);
@@ -841,14 +852,13 @@ pub(crate) struct ChatRequest {
     #[schema(example = "1.0")]
     pub frequency_penalty: Option<f32>,
 
-    /// UNUSED
     /// Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens
     /// (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,
     /// the bias is added to the logits generated by the model prior to sampling. 
The exact effect will vary per model,
     /// but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should
     /// result in a ban or exclusive selection of the relevant token.
     #[serde(default)]
-    pub logit_bias: Option<Vec<f32>>,
+    pub logit_bias: Option<HashMap<String, i32>>,
 
     /// Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each
     /// output token returned in the content of message.
@@ -954,6 +964,7 @@ impl ChatRequest {
             frequency_penalty,
             top_p,
             top_logprobs,
+            logit_bias,
             ..
         } = self;
 
@@ -1029,6 +1040,7 @@ impl ChatRequest {
                     top_n_tokens: top_logprobs,
                     grammar,
                     adapter_id: model.filter(|m| *m != "tgi"),
+                    logit_bias,
                 },
             },
             using_tools,
diff --git a/router/src/server.rs b/router/src/server.rs
index 001a85e0..4f9fdc87 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -798,6 +798,7 @@ pub(crate) async fn completions(
                 top_n_tokens: None,
                 grammar: None,
                 adapter_id: model.as_ref().filter(|m| *m != "tgi").map(String::from),
+                logit_bias: None,
             },
         })
         .collect();
@@ -1191,6 +1192,7 @@ pub(crate) async fn chat_completions(
     let (generate_request, using_tools): (GenerateRequest, bool) =
         chat.clone().try_into_generate(&infer)?;
     span.record("parameters", format!("{:?}", generate_request.parameters));
+    println!("ChatRequest: {:#?}", generate_request);
     let logprobs = logprobs.unwrap_or_default();
 
     // extract model id from request if specified
diff --git a/router/src/validation.rs b/router/src/validation.rs
index 28c7f2f8..6379b145 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -420,6 +420,18 @@ impl Validation {
             seed,
             watermark,
             grammar,
+            logit_bias: Some(
+                request
+                    .parameters
+                    .logit_bias
+                    .iter()
+                    .flat_map(|bias| {
+                        bias.iter()
+                            .map(|(k, v)| (k.parse::<u32>().unwrap(), *v as f32))
+                            .collect::<Vec<(u32, f32)>>()
+                    })
+                    .collect(),
+            ),
         };
         let stopping_parameters = ValidStoppingParameters {
             max_new_tokens,
@@ -902,6 +914,8 @@ pub struct ValidParameters {
     pub watermark: bool,
     /// / grammar (applied if not empty)
     pub grammar: Option<ValidGrammar>,
+    /// / logit bias
+    pub logit_bias: Option<Vec<(u32, f32)>>,
 }
 
 #[derive(Debug, Clone)]
diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py
index 64a285b9..b0dfe571 100644
--- a/server/text_generation_server/utils/logits_process.py
+++ b/server/text_generation_server/utils/logits_process.py
@@ -623,3 +623,65 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
             new_fsms.append(self.fsms[i])
         self.fsms = new_fsms
         return self
+
+
+class HeterogeneousLogitBiasProcessor:
+    """Process logits with different logit biases for each sequence in the batch."""
+
+    def __init__(
+        self,
+        logit_biases: List[Optional[dict]],
+        tokenizer: PreTrainedTokenizerBase,
+        device: torch.device,
+    ):
+        self.device = device
+        self.tokenizer = tokenizer
+        self.logit_biases = logit_biases
+        self.batch_size = len(logit_biases)
+
+        # Pre-compute token IDs for each token string
+        self.token_id_mapping = {}
+
+        # Create a mapping of indices that have logit biases
+        self.indices_with_biases = {
+            i: bias_dict
+            for i, bias_dict in enumerate(self.logit_biases)
+            if bias_dict is not None and len(bias_dict) > 0
+        }
+
+    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
+        # If no indices have biases, return scores unchanged
+        if not self.indices_with_biases:
+            return scores
+
+        # For each index with a bias, apply the bias to the corresponding scores
+        for i, bias_dict in self.indices_with_biases.items():
+            for 
token_str, bias_value in bias_dict.items(): + # Get token ID, either from cache or by computing it + if token_str not in self.token_id_mapping: + if token_str.isdigit(): + # If the token string is already a numeric ID + token_id = int(token_str) + else: + # Otherwise, use the tokenizer to get the ID + tokens = self.tokenizer.encode( + token_str, add_special_tokens=False + ) + token_id = tokens[0] if tokens else -1 # Use -1 for not found + + self.token_id_mapping[token_str] = token_id + + token_id = self.token_id_mapping[token_str] + + # Apply bias if token ID is valid + if 0 <= token_id < scores.size(-1): + scores[i, token_id] += bias_value + + return scores + + def filter(self, indices: List[int]): + """Keep only the logit biases for the specified indices.""" + new_logit_biases = [self.logit_biases[i] for i in indices] + return HeterogeneousLogitBiasProcessor( + new_logit_biases, self.tokenizer, self.device + ) diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py index 9ab49665..8c916bfd 100644 --- a/server/text_generation_server/utils/tokens.py +++ b/server/text_generation_server/utils/tokens.py @@ -15,6 +15,7 @@ from text_generation_server.utils.logits_process import ( HeterogeneousTopPLogitsWarper, HeterogeneousTypicalLogitsWarper, HeterogeneousGrammarLogitProcessor, + HeterogeneousLogitBiasProcessor, static_warper, ) from text_generation_server.utils.watermark import WatermarkLogitsProcessor @@ -38,6 +39,7 @@ class NextTokenChooser: grammar: str = "", grammar_type: GrammarType = GrammarType.GRAMMAR_TYPE_NONE, fsm_grammar_state: int = 0, + logit_bias: Optional[dict] = None, ): self.watermark_processor = ( WatermarkLogitsProcessor(device=device) if watermark else None @@ -58,6 +60,7 @@ class NextTokenChooser: else None ) self.tokenizer = tokenizer + self.logit_bias = logit_bias has_warpers = ( (temperature is not None and temperature != 1.0) @@ -87,6 +90,8 @@ class NextTokenChooser: scores = self.frequency_processor(input_ids, scores) if self.grammar_processor is not None: scores = self.grammar_processor(scores, self.fsm_grammar_state) + if self.logit_bias_processor is not None: + scores = self.logit_bias_processor(input_ids, scores) if self.static_warper is None: next_logprob = torch.log_softmax(scores, -1) @@ -125,6 +130,7 @@ class NextTokenChooser: tokenizer=tokenizer, grammar=pb.grammar, grammar_type=pb.grammar_type, + logit_bias=dict(pb.logit_bias) if pb.logit_bias else None, ) @@ -248,9 +254,14 @@ class HeterogeneousNextTokenChooser: grammars: List[str], grammar_types: List[int], fsm_grammar_states=List[int], + logit_biases: List[Optional[dict]] = None, ): warpers = [] + # Initialize with empty logit biases if none provided + if logit_biases is None: + logit_biases = [None] * len(do_sample) + self.watermark_processor = ( HeterogeneousProcessorWrapper( { @@ -287,6 +298,12 @@ class HeterogeneousNextTokenChooser: else None ) + self.logit_bias_processor = ( + HeterogeneousLogitBiasProcessor(logit_biases, tokenizer, device) + if any([bias is not None and len(bias) > 0 for bias in logit_biases]) + else None + ) + if any(x != 1.0 for x in temperature): do_sample = [ sample or x != 1.0 for x, sample in zip(temperature, do_sample) @@ -322,6 +339,7 @@ class HeterogeneousNextTokenChooser: self.fsm_grammar_states = fsm_grammar_states self.grammars = grammars self.grammar_types = grammar_types + self.logit_biases = logit_biases def __call__( self, @@ -353,6 +371,8 @@ class HeterogeneousNextTokenChooser: _scores = 
self.frequency_processor(input_ids, _scores) if self.grammar_processor is not None: _scores = self.grammar_processor(_scores, self.fsm_grammar_states) + if self.logit_bias_processor is not None: + _scores = self.logit_bias_processor(input_ids, _scores) for warper in self.warpers: _scores = warper(input_ids, _scores) _next_ids = self.choice(_scores) @@ -444,6 +464,9 @@ class HeterogeneousNextTokenChooser: if self.grammar_processor is not None: self.grammar_processor = self.grammar_processor.filter(indices) + if self.logit_bias_processor is not None: + self.logit_bias_processor = self.logit_bias_processor.filter(indices) + filtered_warpers = [] for warper in self.warpers: filtered_warper = warper.filter(indices) @@ -453,6 +476,7 @@ class HeterogeneousNextTokenChooser: self.seeds = [self.seeds[i] for i in indices] self.do_sample = [self.do_sample[i] for i in indices] + self.logit_biases = [self.logit_biases[i] for i in indices] new_grammars = [] new_fsm_grammar_states = [] @@ -500,6 +524,9 @@ class HeterogeneousNextTokenChooser: fsm_grammar_states=( fsm_grammar_states if fsm_grammar_states else [0] * len(pb) ), + logit_biases=[ + dict(pb_.logit_bias) if pb_.logit_bias else None for pb_ in pb + ], )
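
A quick client-side illustration of what this patch enables (not part of the diff above): once a server with these changes is running, the new `logit_bias` field can be sent through TGI's OpenAI-compatible chat route. The host, port, and model name below are assumptions for illustration only; the token id "9707" is the 'Hello' token used by the integration tests.

```python
# Minimal usage sketch against an assumed local TGI instance with this patch applied.
import requests

resp = requests.post(
    "http://localhost:8080/v1/chat/completions",  # assumed address of the server
    json={
        "model": "tgi",
        "messages": [{"role": "user", "content": "say Hello"}],
        "max_tokens": 10,
        "seed": 42,
        # Map of token id (as a string) -> bias in [-100, 100]; -100 effectively bans the token.
        "logit_bias": {"9707": -100},
    },
    timeout=60,
)
print(resp.json()["choices"][0]["message"]["content"])
```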
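The bias rule itself is purely additive: each bias value is added to the corresponding logit before warping and sampling, which is why -100/+100 act as a ban or near-forced selection. The snippet below is a standalone sketch of that rule, independent of the server classes, assuming biases have already been resolved to integer token ids; the shapes and toy values are illustrative only.

```python
# Standalone sketch of the additive logit-bias rule (not the server implementation).
import torch


def apply_logit_bias(scores: torch.Tensor, biases: list) -> torch.Tensor:
    # scores: [batch_size, vocab_size]; biases: one optional {token_id: bias} dict per row
    for row, bias_dict in enumerate(biases):
        if not bias_dict:
            continue
        for token_id, bias in bias_dict.items():
            if 0 <= token_id < scores.size(-1):
                scores[row, token_id] += bias
    return scores


scores = torch.zeros(1, 8)
print(apply_logit_bias(scores, [{3: -100.0, 5: 2.5}]))
# Row 0: token 3 is effectively banned, token 5 is nudged upward before sampling.
```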