feat: support logit bias in chat request

Author: drbh
Date: 2025-04-22 16:09:32 +00:00
parent 7253be349a
commit fae510b8f6
18 changed files with 363 additions and 10 deletions

View File

@@ -7,6 +7,7 @@ use grpc_metadata::InjectTelemetryContext;
use pb::generate::v3::text_generation_service_client::TextGenerationServiceClient;
use pb::generate::v3::*;
use std::cmp::min;
use std::collections::HashMap;
use std::time::Duration;
use tonic::transport::{Channel, Uri};
use tracing::instrument;
@@ -181,6 +182,7 @@ impl Client {
watermark: true,
grammar: String::new(),
grammar_type: GrammarType::None as i32,
logit_bias: HashMap::new(),
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens,

View File

@@ -5,6 +5,7 @@ use crate::{ClientError, Result};
use crate::v3::{Chunk, InfoResponse, Input};
use async_trait::async_trait;
use futures::future::join_all;
use std::collections::HashMap;
use tonic::transport::Uri;
use tracing::instrument;
use v3::client::{DecodeTimings, PrefillTimings};
@@ -244,6 +245,7 @@ impl Health for ShardedClient {
watermark: false,
grammar: String::new(),
grammar_type: GrammarType::None as i32,
logit_bias: HashMap::new(),
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: 1,

View File

@@ -7,6 +7,7 @@ use grpc_metadata::InjectTelemetryContext;
use pb::generate::v3::text_generation_service_client::TextGenerationServiceClient;
use pb::generate::v3::*;
use std::cmp::min;
use std::collections::HashMap;
use std::time::Duration;
use tonic::transport::{Channel, Uri};
use tracing::instrument;
@@ -181,6 +182,7 @@ impl Client {
watermark: true,
grammar: String::new(),
grammar_type: GrammarType::None as i32,
logit_bias: HashMap::new(),
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens,

View File

@@ -10,6 +10,7 @@ use crate::client::{
use crate::client::{Chunk, InfoResponse, Input};
use async_trait::async_trait;
use futures::future::join_all;
use std::collections::HashMap;
use tonic::transport::Uri;
use tracing::instrument;
@@ -232,6 +233,7 @@ impl Health for ShardedClient {
watermark: false,
grammar: String::new(),
grammar_type: GrammarType::None as i32,
logit_bias: HashMap::new(),
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: 1,

View File

@@ -5,6 +5,7 @@ use crate::client::{
};
use nohash_hasher::{BuildNoHashHasher, IntMap};
use std::cmp::max;
use std::collections::HashMap;
use std::collections::VecDeque;
use text_generation_router::infer::InferError;
use text_generation_router::infer::InferStreamResponse;
@@ -522,6 +523,14 @@ impl From<ValidParameters> for NextTokenChooserParameters {
watermark: value.watermark,
grammar,
grammar_type: grammar_type.into(),
logit_bias: value
.logit_bias
.map(|bias| {
bias.into_iter()
.map(|(token, bias)| (token.to_string(), bias as i32))
.collect::<HashMap<String, i32>>()
})
.unwrap_or_default(),
}
}
}

View File

@@ -47,6 +47,7 @@ pub async fn run(
watermark,
grammar: String::new(),
grammar_type: GrammarType::None as i32,
logit_bias: std::collections::HashMap::new(),
};
// Initialize terminal properties

View File

@@ -1,6 +1,6 @@
from enum import Enum
from pydantic import BaseModel, field_validator, ConfigDict
-from typing import Optional, List, Union, Any
+from typing import Optional, List, Union, Any, Dict
from text_generation.errors import ValidationError
@@ -137,7 +137,7 @@ class ChatRequest(BaseModel):
# decreasing the model's likelihood to repeat the same line verbatim.
frequency_penalty: Optional[float] = None
# Bias values for token selection
-logit_bias: Optional[List[float]] = None
+logit_bias: Optional[Dict[str, int]] = None
# Whether to return log probabilities
logprobs: Optional[bool] = None
# Number of most likely tokens to return at each position

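For reference, a minimal sketch of the updated client-side model in use, assuming the `ChatRequest` shown above is importable from `text_generation.types` and that `model` and `messages` are its only other required fields; pydantic rejects non-integer bias values:

```python
# Sketch: logit_bias is a map from token-ID strings to integer biases.
# The import path and required fields are assumptions based on the client
# module shown in the diff above.
from text_generation.types import ChatRequest

req = ChatRequest(
    model="tgi",
    messages=[{"role": "user", "content": "say Hello"}],
    max_tokens=10,
    logit_bias={"9707": -100},  # token ID (string) -> bias in [-100, 100]
)
print(req.logit_bias)  # {'9707': -100}
```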
View File

@@ -995,12 +995,12 @@
"nullable": true
},
"logit_bias": {
-"type": "array",
-"items": {
-"type": "number",
-"format": "float"
-},
-"description": "UNUSED\nModify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens\n(specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,\nthe bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model,\nbut values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should\nresult in a ban or exclusive selection of the relevant token.",
+"type": "object",
+"description": "Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens\n(specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,\nthe bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model,\nbut values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should\nresult in a ban or exclusive selection of the relevant token.",
+"additionalProperties": {
+"type": "integer",
+"format": "int32"
+},
"nullable": true
},
"logprobs": {
@@ -1589,6 +1589,17 @@
"default": "null",
"nullable": true
},
"logit_bias": {
"type": "object",
"description": "Modify the likelihood of specified tokens appearing in the completion.\nAccepts a hash map that maps token strings to an associated bias value.",
"default": "null",
"additionalProperties": {
"type": "integer",
"format": "int32"
},
"example": "{\"1923\": 100, \"1924\": -100}",
"nullable": true
},
"max_new_tokens": { "max_new_tokens": {
"type": "integer", "type": "integer",
"format": "int32", "format": "int32",

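With the schema change above, the bias map can be sent straight through the OpenAI-compatible chat route. A minimal sketch, assuming a TGI server on localhost:8080 and using token ID 9707 ("Hello" in the Qwen2 tokenizer, per the integration tests below):

```python
# Sketch of a chat completion request carrying logit_bias.
# URL and token ID are illustrative; keys are token IDs as strings,
# values are integer biases from -100 to 100 added to the logits.
import requests

resp = requests.post(
    "http://localhost:8080/v1/chat/completions",
    json={
        "model": "tgi",
        "messages": [{"role": "user", "content": "say Hello"}],
        "max_tokens": 10,
        "logit_bias": {"9707": -100},
    },
)
print(resp.json()["choices"][0]["message"]["content"])
```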
View File

@@ -0,0 +1,26 @@
{
"choices": [
{
"finish_reason": "length",
"index": 0,
"logprobs": null,
"message": {
"content": "Hello! How can I help you today?",
"name": null,
"role": "assistant",
"tool_calls": null
},
"usage": null
}
],
"created": 1745337495,
"id": "",
"model": "Qwen/Qwen2-VL-2B-Instruct",
"object": "chat.completion",
"system_fingerprint": "3.2.3-dev0-native",
"usage": {
"completion_tokens": 10,
"prompt_tokens": 21,
"total_tokens": 31
}
}

View File

@@ -0,0 +1,26 @@
{
"choices": [
{
"finish_reason": "length",
"index": 0,
"logprobs": null,
"message": {
"content": "¡Hola! ¿Cómo puedo ayudarte?",
"name": null,
"role": "assistant",
"tool_calls": null
},
"usage": null
}
],
"created": 1745337456,
"id": "",
"model": "Qwen/Qwen2-VL-2B-Instruct",
"object": "chat.completion",
"system_fingerprint": "3.2.3-dev0-native",
"usage": {
"completion_tokens": 10,
"prompt_tokens": 21,
"total_tokens": 31
}
}

View File

@@ -0,0 +1,26 @@
{
"choices": [
{
"finish_reason": "stop",
"index": 0,
"logprobs": null,
"message": {
"content": "Chat!",
"name": null,
"role": "assistant",
"tool_calls": null
},
"usage": null
}
],
"created": 1745337878,
"id": "",
"model": "Qwen/Qwen2-VL-2B-Instruct",
"object": "chat.completion",
"system_fingerprint": "3.2.3-dev0-native",
"usage": {
"completion_tokens": 3,
"prompt_tokens": 25,
"total_tokens": 28
}
}

View File

@@ -0,0 +1,20 @@
{
"choices": [
{
"delta": {
"content": "",
"role": "assistant",
"tool_calls": null
},
"finish_reason": "length",
"index": 0,
"logprobs": null
}
],
"created": 1745337495,
"id": "",
"model": "Qwen/Qwen2-VL-2B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "3.2.3-dev0-native",
"usage": null
}

View File

@@ -0,0 +1,109 @@
import pytest


@pytest.fixture(scope="module")
def logit_bias_model_handle(launcher):
    with launcher("Qwen/Qwen2-VL-2B-Instruct") as handle:
        yield handle


@pytest.fixture(scope="module")
async def logit_bias_model(logit_bias_model_handle):
    await logit_bias_model_handle.health(300)
    return logit_bias_model_handle.client


@pytest.mark.private
async def test_logit_bias_english_to_spanish(logit_bias_model, response_snapshot):
    """Test that setting negative bias on English tokens forces output to be in Spanish"""
    response = await logit_bias_model.chat(
        seed=42,
        max_tokens=10,
        logit_bias={"9707": -100},  # Bias against 'Hello' token
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "say Hello"},
                ],
            },
        ],
    )
    assert "¡Hola!" in response.choices[0].message.content
    assert "Hello" not in response.choices[0].message.content
    assert response == response_snapshot


@pytest.mark.private
async def test_logit_bias_baseline(logit_bias_model, response_snapshot):
    """Test baseline behavior without logit bias for comparison"""
    response = await logit_bias_model.chat(
        seed=42,
        max_tokens=10,
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "say Hello"},
                ],
            },
        ],
    )
    assert "Hello" in response.choices[0].message.content
    assert response == response_snapshot


@pytest.mark.private
async def test_logit_bias_multiple_tokens(logit_bias_model, response_snapshot):
    """Test applying bias to multiple tokens simultaneously"""
    response = await logit_bias_model.chat(
        seed=42,
        max_tokens=15,
        logit_bias={
            "9707": -100,  # Bias against 'Hello' token
            "2880": -100,  # Bias against 'hi' token
        },
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Give me a one-word greeting"},
                ],
            },
        ],
    )
    assert "Hello" not in response.choices[0].message.content.lower()
    assert "hi" not in response.choices[0].message.content.lower()
    assert response == response_snapshot


@pytest.mark.private
async def test_logit_bias_streaming(logit_bias_model, response_snapshot):
    """Test logit bias works correctly with streaming enabled"""
    responses = await logit_bias_model.chat(
        seed=42,
        max_tokens=10,
        logit_bias={"9707": -100},  # Bias against 'Hello' token
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "say Hello"},
                ],
            },
        ],
        stream=True,
    )

    count = 0
    generated = ""
    last_response = None
    async for response in responses:
        count += 1
        generated += response.choices[0].delta.content
        last_response = response

    assert "¡Hola!" in generated
    assert "Hello" not in generated
    assert last_response == response_snapshot

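The tests above hard-code token IDs such as 9707 for "Hello". Since the bias map is keyed by token ID, a quick way to look up IDs for a given checkpoint is the model's own tokenizer; a sketch assuming the `transformers` library:

```python
# Look up token IDs to use as logit_bias keys. IDs are tokenizer-specific,
# so the values printed here are only meaningful for this checkpoint.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
for text in ["Hello", "hi"]:
    ids = tok.encode(text, add_special_tokens=False)
    print(text, ids)  # e.g. Hello -> [9707] for this tokenizer
```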
View File

@@ -18,6 +18,7 @@ use crate::infer::{Infer, InferError};
use pyo3::prelude::*;
use pyo3::types::IntoPyDict;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tokenizers::Encoding;
use tracing::warn;
use utoipa::ToSchema;
@@ -431,6 +432,16 @@ pub(crate) struct GenerateParameters {
#[serde(default)]
#[schema(nullable = true, default = "null", example = "null")]
pub adapter_id: Option<String>,
/// Modify the likelihood of specified tokens appearing in the completion.
/// Accepts a hash map that maps token strings to an associated bias value.
#[serde(default)]
#[schema(
nullable = true,
default = "null",
example = "{\"1923\": 100, \"1924\": -100}"
)]
pub logit_bias: Option<HashMap<String, i32>>,
}
fn default_parameters() -> GenerateParameters {
@@ -454,9 +465,9 @@ fn default_parameters() -> GenerateParameters {
top_n_tokens: None,
grammar: None,
adapter_id: None,
logit_bias: None,
}
}
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)]
#[serde(try_from = "PromptDeserializer")]
pub struct Prompt(pub Vec<String>);
@@ -841,14 +852,13 @@ pub(crate) struct ChatRequest {
#[schema(example = "1.0")]
pub frequency_penalty: Option<f32>,
-/// UNUSED
/// Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens
/// (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,
/// the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model,
/// but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should
/// result in a ban or exclusive selection of the relevant token.
#[serde(default)]
-pub logit_bias: Option<Vec<f32>>,
+pub logit_bias: Option<HashMap<String, i32>>,
/// Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each
/// output token returned in the content of message.
@@ -954,6 +964,7 @@ impl ChatRequest {
frequency_penalty,
top_p,
top_logprobs,
logit_bias,
..
} = self;
@@ -1029,6 +1040,7 @@ impl ChatRequest {
top_n_tokens: top_logprobs,
grammar,
adapter_id: model.filter(|m| *m != "tgi"),
logit_bias,
},
},
using_tools,

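`GenerateParameters` gains the same field, so the native `/generate` route accepts it as well. A minimal sketch, with the endpoint URL and token IDs taken from the schema example above rather than a real tokenizer:

```python
# Sketch: logit_bias on the native generate route via GenerateParameters.
import requests

resp = requests.post(
    "http://localhost:8080/generate",
    json={
        "inputs": "say Hello",
        "parameters": {
            "max_new_tokens": 10,
            "logit_bias": {"1923": 100, "1924": -100},  # illustrative IDs
        },
    },
)
print(resp.json()["generated_text"])
```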
View File

@@ -798,6 +798,7 @@ pub(crate) async fn completions(
top_n_tokens: None,
grammar: None,
adapter_id: model.as_ref().filter(|m| *m != "tgi").map(String::from),
logit_bias: None,
},
})
.collect();
@@ -1191,6 +1192,7 @@ pub(crate) async fn chat_completions(
let (generate_request, using_tools): (GenerateRequest, bool) =
chat.clone().try_into_generate(&infer)?;
span.record("parameters", format!("{:?}", generate_request.parameters));
println!("ChatRequest: {:#?}", generate_request);
let logprobs = logprobs.unwrap_or_default();
// extract model id from request if specified

View File

@@ -420,6 +420,18 @@ impl Validation {
seed,
watermark,
grammar,
logit_bias: Some(
request
.parameters
.logit_bias
.iter()
.flat_map(|bias| {
bias.iter()
.map(|(k, v)| (k.parse::<u32>().unwrap(), *v as f32))
.collect::<Vec<_>>()
})
.collect(),
),
};
let stopping_parameters = ValidStoppingParameters {
max_new_tokens,
@@ -902,6 +914,8 @@ pub struct ValidParameters {
pub watermark: bool,
/// / grammar (applied if not empty)
pub grammar: Option<ValidGrammar>,
/// / logit bias
pub logit_bias: Option<Vec<(u32, f32)>>,
}
#[derive(Debug, Clone)]

View File

@@ -623,3 +623,65 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
            new_fsms.append(self.fsms[i])
        self.fsms = new_fsms
        return self


class HeterogeneousLogitBiasProcessor:
    """Process logits with different logit biases for each sequence in the batch."""

    def __init__(
        self,
        logit_biases: List[Optional[dict]],
        tokenizer: PreTrainedTokenizerBase,
        device: torch.device,
    ):
        self.device = device
        self.tokenizer = tokenizer
        self.logit_biases = logit_biases
        self.batch_size = len(logit_biases)

        # Pre-compute token IDs for each token string
        self.token_id_mapping = {}

        # Create a mapping of indices that have logit biases
        self.indices_with_biases = {
            i: bias_dict
            for i, bias_dict in enumerate(self.logit_biases)
            if bias_dict is not None and len(bias_dict) > 0
        }

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        # If no indices have biases, return scores unchanged
        if not self.indices_with_biases:
            return scores

        # For each index with a bias, apply the bias to the corresponding scores
        for i, bias_dict in self.indices_with_biases.items():
            for token_str, bias_value in bias_dict.items():
                # Get token ID, either from cache or by computing it
                if token_str not in self.token_id_mapping:
                    if token_str.isdigit():
                        # If the token string is already a numeric ID
                        token_id = int(token_str)
                    else:
                        # Otherwise, use the tokenizer to get the ID
                        tokens = self.tokenizer.encode(
                            token_str, add_special_tokens=False
                        )
                        token_id = tokens[0] if tokens else -1  # Use -1 for not found
                    self.token_id_mapping[token_str] = token_id

                token_id = self.token_id_mapping[token_str]

                # Apply bias if token ID is valid
                if 0 <= token_id < scores.size(-1):
                    scores[i, token_id] += bias_value

        return scores

    def filter(self, indices: List[int]):
        """Keep only the logit biases for the specified indices."""
        new_logit_biases = [self.logit_biases[i] for i in indices]
        return HeterogeneousLogitBiasProcessor(
            new_logit_biases, self.tokenizer, self.device
        )

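A minimal sketch exercising the processor above on a fake batch of logits. It assumes a server checkout so the import resolves; with numeric keys the tokenizer is never consulted, so `None` stands in for it here:

```python
import torch

from text_generation_server.utils.logits_process import (
    HeterogeneousLogitBiasProcessor,
)

scores = torch.zeros(2, 8)  # batch of 2 sequences over a toy vocab of 8
processor = HeterogeneousLogitBiasProcessor(
    logit_biases=[{"3": -100}, None],  # bias only the first sequence
    tokenizer=None,  # unused when all keys are numeric token IDs
    device=torch.device("cpu"),
)
scores = processor(input_ids=None, scores=scores)
print(scores[0, 3].item(), scores[1, 3].item())  # -100.0 vs 0.0
```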
View File

@@ -15,6 +15,7 @@ from text_generation_server.utils.logits_process import (
HeterogeneousTopPLogitsWarper,
HeterogeneousTypicalLogitsWarper,
HeterogeneousGrammarLogitProcessor,
HeterogeneousLogitBiasProcessor,
static_warper,
)
from text_generation_server.utils.watermark import WatermarkLogitsProcessor
@@ -38,6 +39,7 @@ class NextTokenChooser:
grammar: str = "",
grammar_type: GrammarType = GrammarType.GRAMMAR_TYPE_NONE,
fsm_grammar_state: int = 0,
logit_bias: Optional[dict] = None,
):
self.watermark_processor = (
WatermarkLogitsProcessor(device=device) if watermark else None
@@ -58,6 +60,7 @@
else None
)
self.tokenizer = tokenizer
self.logit_bias = logit_bias
has_warpers = (
(temperature is not None and temperature != 1.0)
@@ -87,6 +90,8 @@
scores = self.frequency_processor(input_ids, scores)
if self.grammar_processor is not None:
scores = self.grammar_processor(scores, self.fsm_grammar_state)
if self.logit_bias_processor is not None:
scores = self.logit_bias_processor(input_ids, scores)
if self.static_warper is None:
next_logprob = torch.log_softmax(scores, -1)
@@ -125,6 +130,7 @@
tokenizer=tokenizer,
grammar=pb.grammar,
grammar_type=pb.grammar_type,
logit_bias=dict(pb.logit_bias) if pb.logit_bias else None,
)
@@ -248,9 +254,14 @@ class HeterogeneousNextTokenChooser:
grammars: List[str],
grammar_types: List[int],
fsm_grammar_states=List[int],
logit_biases: List[Optional[dict]] = None,
):
warpers = []
# Initialize with empty logit biases if none provided
if logit_biases is None:
logit_biases = [None] * len(do_sample)
self.watermark_processor = (
HeterogeneousProcessorWrapper(
{
@@ -287,6 +298,12 @@
else None
)
self.logit_bias_processor = (
HeterogeneousLogitBiasProcessor(logit_biases, tokenizer, device)
if any([bias is not None and len(bias) > 0 for bias in logit_biases])
else None
)
if any(x != 1.0 for x in temperature):
do_sample = [
sample or x != 1.0 for x, sample in zip(temperature, do_sample)
@@ -322,6 +339,7 @@
self.fsm_grammar_states = fsm_grammar_states
self.grammars = grammars
self.grammar_types = grammar_types
self.logit_biases = logit_biases
def __call__(
self,
@@ -353,6 +371,8 @@
_scores = self.frequency_processor(input_ids, _scores)
if self.grammar_processor is not None:
_scores = self.grammar_processor(_scores, self.fsm_grammar_states)
if self.logit_bias_processor is not None:
_scores = self.logit_bias_processor(input_ids, _scores)
for warper in self.warpers:
_scores = warper(input_ids, _scores)
_next_ids = self.choice(_scores)
@@ -444,6 +464,9 @@
if self.grammar_processor is not None:
self.grammar_processor = self.grammar_processor.filter(indices)
if self.logit_bias_processor is not None:
self.logit_bias_processor = self.logit_bias_processor.filter(indices)
filtered_warpers = []
for warper in self.warpers:
filtered_warper = warper.filter(indices)
@@ -453,6 +476,7 @@
self.seeds = [self.seeds[i] for i in indices]
self.do_sample = [self.do_sample[i] for i in indices]
self.logit_biases = [self.logit_biases[i] for i in indices]
new_grammars = []
new_fsm_grammar_states = []
@@ -500,6 +524,9 @@
fsm_grammar_states=(
fsm_grammar_states if fsm_grammar_states else [0] * len(pb)
),
logit_biases=[
dict(pb_.logit_bias) if pb_.logit_bias else None for pb_ in pb
],
) )
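As the schema text above describes, the bias is added to the logits before sampling, so a value of -100 effectively bans a token while small values only nudge its probability. A tiny sketch of that effect with illustrative logits:

```python
import torch

logits = torch.tensor([2.0, 1.0, 0.5])
print(torch.softmax(logits, dim=-1))  # baseline distribution

logits[0] += -100  # strong negative bias on token 0
print(torch.softmax(logits, dim=-1))  # token 0's probability is ~0
```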