commit 0381aba864
drbh, 2025-06-15 00:49:23 +02:00 (committed by GitHub)
No known key found for this signature in database; GPG Key ID: B5690EEEBB952194
22 changed files with 431 additions and 36 deletions

View File

@@ -1,26 +0,0 @@
name: Python Client Tests
on:
pull_request:
paths:
- ".github/workflows/client-tests.yaml"
- "clients/python/**"
jobs:
run_tests:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v1
with:
python-version: 3.9
- name: Install
run: |
cd clients/python && pip install .
- name: Run tests
run: |
pip install pytest pytest-asyncio
export HF_TOKEN=${{ secrets.HF_TOKEN }}
make python-client-tests

View File

@@ -7,6 +7,7 @@ use grpc_metadata::InjectTelemetryContext;
use pb::generate::v3::text_generation_service_client::TextGenerationServiceClient;
use pb::generate::v3::*;
use std::cmp::min;
use std::collections::HashMap;
use std::time::Duration;
use tonic::transport::{Channel, Uri};
use tracing::instrument;
@@ -181,6 +182,7 @@ impl Client {
watermark: true,
grammar: String::new(),
grammar_type: GrammarType::None as i32,
logit_bias: HashMap::new(),
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens,

View File

@@ -5,6 +5,7 @@ use crate::{ClientError, Result};
use crate::v3::{Chunk, InfoResponse, Input};
use async_trait::async_trait;
use futures::future::join_all;
use std::collections::HashMap;
use tonic::transport::Uri;
use tracing::instrument;
use v3::client::{DecodeTimings, PrefillTimings};
@@ -244,6 +245,7 @@ impl Health for ShardedClient {
watermark: false,
grammar: String::new(),
grammar_type: GrammarType::None as i32,
logit_bias: HashMap::new(),
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: 1,

View File

@@ -429,6 +429,7 @@ mod tests {
frequency_penalty: 0.0,
watermark: false,
grammar: None,
logit_bias: None,
},
stopping_parameters: ValidStoppingParameters {
ignore_eos_token: false,

View File

@@ -7,6 +7,7 @@ use grpc_metadata::InjectTelemetryContext;
use pb::generate::v3::text_generation_service_client::TextGenerationServiceClient;
use pb::generate::v3::*;
use std::cmp::min;
use std::collections::HashMap;
use std::time::Duration;
use tonic::transport::{Channel, Uri};
use tracing::instrument;
@@ -181,6 +182,7 @@ impl Client {
watermark: true,
grammar: String::new(),
grammar_type: GrammarType::None as i32,
logit_bias: HashMap::new(),
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens,

View File

@@ -10,6 +10,7 @@ use crate::client::{
use crate::client::{Chunk, InfoResponse, Input};
use async_trait::async_trait;
use futures::future::join_all;
use std::collections::HashMap;
use tonic::transport::Uri;
use tracing::instrument;
@@ -232,6 +233,7 @@ impl Health for ShardedClient {
watermark: false,
grammar: String::new(),
grammar_type: GrammarType::None as i32,
logit_bias: HashMap::new(),
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: 1,

View File

@@ -5,6 +5,7 @@ use crate::client::{
};
use nohash_hasher::{BuildNoHashHasher, IntMap};
use std::cmp::max;
use std::collections::HashMap;
use std::collections::VecDeque;
use text_generation_router::infer::InferError;
use text_generation_router::infer::InferStreamResponse;
@@ -542,6 +543,14 @@ impl From<ValidParameters> for NextTokenChooserParameters {
watermark: value.watermark,
grammar,
grammar_type: grammar_type.into(),
logit_bias: value
.logit_bias
.map(|bias| {
bias.into_iter()
.map(|(token, bias)| (token.to_string(), bias as i32))
.collect::<HashMap<String, i32>>()
})
.unwrap_or_default(),
}
}
}
@@ -588,6 +597,7 @@ mod tests {
frequency_penalty: 0.0,
watermark: false,
grammar: None,
logit_bias: None,
},
stopping_parameters: ValidStoppingParameters {
ignore_eos_token: false,

View File

@@ -47,6 +47,7 @@ pub async fn run(
watermark,
grammar: String::new(),
grammar_type: GrammarType::None as i32,
logit_bias: std::collections::HashMap::new(),
};
// Initialize terminal properties

View File

@@ -1,6 +1,6 @@
from enum import Enum
from pydantic import BaseModel, field_validator, ConfigDict
-from typing import Optional, List, Union, Any
+from typing import Optional, List, Union, Any, Dict
from text_generation.errors import ValidationError
@@ -137,7 +137,7 @@ class ChatRequest(BaseModel):
# decreasing the model's likelihood to repeat the same line verbatim.
frequency_penalty: Optional[float] = None
# Bias values for token selection
-logit_bias: Optional[List[float]] = None
+logit_bias: Optional[Dict[str, int]] = None
# Whether to return log probabilities
logprobs: Optional[bool] = None
# Number of most likely tokens to return at each position
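For context (not part of the diff): a minimal sketch of exercising the new logit_bias mapping through the Python client, assuming a TGI server on localhost:8080 and a model for which token 9707 decodes to "Hello" (as in the integration tests further down).

import asyncio

from text_generation import AsyncClient


async def main():
    client = AsyncClient("http://localhost:8080")
    response = await client.chat(
        messages=[{"role": "user", "content": "say Hello"}],
        logit_bias={"9707": -100},  # push generation away from token 9707
        max_tokens=10,
        seed=42,
    )
    print(response.choices[0].message.content)


asyncio.run(main())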

View File

@@ -995,12 +995,12 @@
"nullable": true
},
"logit_bias": {
-"type": "array",
-"items": {
-"type": "number",
-"format": "float"
-},
-"description": "UNUSED\nModify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens\n(specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,\nthe bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model,\nbut values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should\nresult in a ban or exclusive selection of the relevant token.",
+"type": "object",
+"description": "Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens\n(specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,\nthe bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model,\nbut values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should\nresult in a ban or exclusive selection of the relevant token.",
+"additionalProperties": {
+"type": "integer",
+"format": "int32"
+},
"nullable": true
},
"logprobs": {
@@ -1589,6 +1589,17 @@
"default": "null",
"nullable": true
},
"logit_bias": {
"type": "object",
"description": "Modify the likelihood of specified tokens appearing in the completion.\nAccepts a hash map that maps token strings to an associated bias value.",
"default": "null",
"additionalProperties": {
"type": "integer",
"format": "int32"
},
"example": "{\"1923\": 100, \"1924\": -100}",
"nullable": true
},
"max_new_tokens": { "max_new_tokens": {
"type": "integer", "type": "integer",
"format": "int32", "format": "int32",

View File

@@ -0,0 +1,26 @@
{
"choices": [
{
"finish_reason": "length",
"index": 0,
"logprobs": null,
"message": {
"content": "Hello! How can I help you today?",
"name": null,
"role": "assistant",
"tool_calls": null
},
"usage": null
}
],
"created": 1745337495,
"id": "",
"model": "Qwen/Qwen2-VL-2B-Instruct",
"object": "chat.completion",
"system_fingerprint": "3.2.3-dev0-native",
"usage": {
"completion_tokens": 10,
"prompt_tokens": 21,
"total_tokens": 31
}
}

View File

@@ -0,0 +1,26 @@
{
"choices": [
{
"finish_reason": "length",
"index": 0,
"logprobs": null,
"message": {
"content": "¡Hola! ¿Cómo puedo ayudarte?",
"name": null,
"role": "assistant",
"tool_calls": null
},
"usage": null
}
],
"created": 1746486174,
"id": "",
"model": "Qwen/Qwen2-VL-2B-Instruct",
"object": "chat.completion",
"system_fingerprint": "3.2.3-dev0-native",
"usage": {
"completion_tokens": 10,
"prompt_tokens": 21,
"total_tokens": 31
}
}

View File

@@ -0,0 +1,26 @@
{
"choices": [
{
"finish_reason": "stop",
"index": 0,
"logprobs": null,
"message": {
"content": "Chat!",
"name": null,
"role": "assistant",
"tool_calls": null
},
"usage": null
}
],
"created": 1746486174,
"id": "",
"model": "Qwen/Qwen2-VL-2B-Instruct",
"object": "chat.completion",
"system_fingerprint": "3.2.3-dev0-native",
"usage": {
"completion_tokens": 3,
"prompt_tokens": 25,
"total_tokens": 28
}
}

View File

@@ -0,0 +1,20 @@
{
"choices": [
{
"delta": {
"content": "",
"role": "assistant",
"tool_calls": null
},
"finish_reason": "length",
"index": 0,
"logprobs": null
}
],
"created": 1746486174,
"id": "",
"model": "Qwen/Qwen2-VL-2B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "3.2.3-dev0-native",
"usage": null
}

View File

@@ -0,0 +1,109 @@
import pytest


@pytest.fixture(scope="module")
def logit_bias_model_handle(launcher):
    with launcher("Qwen/Qwen2-VL-2B-Instruct") as handle:
        yield handle


@pytest.fixture(scope="module")
async def logit_bias_model(logit_bias_model_handle):
    await logit_bias_model_handle.health(300)
    return logit_bias_model_handle.client


@pytest.mark.private
async def test_logit_bias_english_to_spanish(logit_bias_model, response_snapshot):
    """Test that setting negative bias on English tokens forces output to be in Spanish"""
    response = await logit_bias_model.chat(
        seed=42,
        max_tokens=10,
        logit_bias={"9707": -100},  # Bias against 'Hello' token
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "say Hello"},
                ],
            },
        ],
    )
    assert "¡Hola!" in response.choices[0].message.content
    assert "Hello" not in response.choices[0].message.content
    assert response == response_snapshot


@pytest.mark.private
async def test_logit_bias_baseline(logit_bias_model, response_snapshot):
    """Test baseline behavior without logit bias for comparison"""
    response = await logit_bias_model.chat(
        seed=42,
        max_tokens=10,
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "say Hello"},
                ],
            },
        ],
    )
    assert "Hello" in response.choices[0].message.content
    assert response == response_snapshot


@pytest.mark.private
async def test_logit_bias_multiple_tokens(logit_bias_model, response_snapshot):
    """Test applying bias to multiple tokens simultaneously"""
    response = await logit_bias_model.chat(
        seed=42,
        max_tokens=15,
        logit_bias={
            "9707": -100,  # Bias against 'Hello' token
            "2880": -100,  # Bias against 'hi' token
        },
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Give me a one-word greeting"},
                ],
            },
        ],
    )
    assert "hello" not in response.choices[0].message.content.lower()
    assert "hi" not in response.choices[0].message.content.lower()
    assert response == response_snapshot


@pytest.mark.private
async def test_logit_bias_streaming(logit_bias_model, response_snapshot):
    """Test logit bias works correctly with streaming enabled"""
    responses = await logit_bias_model.chat(
        seed=42,
        max_tokens=10,
        logit_bias={"9707": -100},  # Bias against 'Hello' token
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "say Hello"},
                ],
            },
        ],
        stream=True,
    )

    count = 0
    generated = ""
    last_response = None
    async for response in responses:
        count += 1
        generated += response.choices[0].delta.content
        last_response = response

    assert "¡Hola!" in generated
    assert "Hello" not in generated
    assert last_response == response_snapshot

View File

@@ -104,6 +104,8 @@ message NextTokenChooserParameters {
string grammar = 10;
/// grammar type
GrammarType grammar_type = 11;
/// logit bias dictionary mapping token string to bias value
map<string, int32> logit_bias = 12;
}
message StoppingCriteriaParameters {
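Illustration only (assumes regenerated Python stubs, not part of the diff): the new map field accepts a plain dict of token-ID strings to int32 bias values.

from text_generation_server.pb import generate_pb2

# Other fields keep their proto3 defaults; only the new map field is set here.
params = generate_pb2.NextTokenChooserParameters(logit_bias={"9707": -100})
print(params.logit_bias["9707"])  # -100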

View File

@@ -18,6 +18,7 @@ use crate::infer::{Infer, InferError};
use pyo3::prelude::*;
use pyo3::types::IntoPyDict;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tokenizers::Encoding;
use tracing::warn;
use utoipa::ToSchema;
@@ -431,6 +432,16 @@ pub(crate) struct GenerateParameters {
#[serde(default)]
#[schema(nullable = true, default = "null", example = "null")]
pub adapter_id: Option<String>,
/// Modify the likelihood of specified tokens appearing in the completion.
/// Accepts a hash map that maps token strings to an associated bias value.
#[serde(default)]
#[schema(
nullable = true,
default = "null",
example = "{\"1923\": 100, \"1924\": -100}"
)]
pub logit_bias: Option<HashMap<String, i32>>,
}
fn default_parameters() -> GenerateParameters {
@@ -454,9 +465,9 @@ fn default_parameters() -> GenerateParameters {
top_n_tokens: None,
grammar: None,
adapter_id: None,
logit_bias: None,
}
}
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)]
#[serde(try_from = "PromptDeserializer")]
pub struct Prompt(pub Vec<String>);
@@ -841,14 +852,13 @@ pub(crate) struct ChatRequest {
#[schema(example = "1.0")]
pub frequency_penalty: Option<f32>,
-/// UNUSED
/// Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens
/// (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,
/// the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model,
/// but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should
/// result in a ban or exclusive selection of the relevant token.
#[serde(default)]
-pub logit_bias: Option<Vec<f32>>,
+pub logit_bias: Option<HashMap<String, i32>>,
/// Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each
/// output token returned in the content of message.
@@ -954,6 +964,7 @@ impl ChatRequest {
frequency_penalty,
top_p,
top_logprobs,
logit_bias,
..
} = self;
@@ -1029,6 +1040,7 @@ impl ChatRequest {
top_n_tokens: top_logprobs,
grammar,
adapter_id: model.filter(|m| *m != "tgi"),
logit_bias,
},
},
using_tools,

View File

@@ -798,6 +798,7 @@ pub(crate) async fn completions(
top_n_tokens: None,
grammar: None,
adapter_id: model.as_ref().filter(|m| *m != "tgi").map(String::from),
logit_bias: None,
},
})
.collect();

View File

@@ -34,6 +34,7 @@ pub struct Validation {
max_input_length: usize,
max_total_tokens: usize,
disable_grammar_support: bool,
vocab_size: u32,
/// Channel to communicate with the background tokenization task
sender: mpsc::UnboundedSender<TokenizerRequest>,
}
@@ -88,6 +89,19 @@ impl Validation {
validation_sender
};
let vocab_size = match &tokenizer {
Tokenizer::Python { tokenizer_name, .. } => {
warn!(
"Tokenizer {} is not supported for validation",
tokenizer_name
);
0
}
Tokenizer::Rust(tokenizer) => tokenizer.get_vocab_size(false),
}
.try_into()
.unwrap_or(0);
Self {
max_best_of,
sender,
@@ -96,6 +110,7 @@ impl Validation {
max_input_length,
max_total_tokens,
disable_grammar_support,
vocab_size,
}
}
@@ -409,6 +424,37 @@ impl Validation {
None => None,
};
// Validate logit bias and convert to a vector of (token_id, bias_value)
let logit_bias = request
.parameters
.logit_bias
.as_ref()
.filter(|bias_map| !bias_map.is_empty())
.map(|bias_map| {
bias_map
.iter()
.map(|(token_str, &bias_value)| {
let token_id: u32 = token_str.parse().map_err(|_| {
ValidationError::LogitBiasInvalid(format!(
"Token ID {token_str} is not a valid number."
))
})?;
if token_id >= self.vocab_size {
return Err(ValidationError::LogitBiasInvalid(format!(
"Token ID {token_id} is out of range (0..{}).",
self.vocab_size - 1
)));
}
Ok((token_id, bias_value as f32))
})
.collect::<Result<Vec<_>, _>>()
})
// convert Option<Result<T, E>> to Result<Option<T>, E> to throw
// if any of the token IDs are invalid
.transpose()?;
let parameters = ValidParameters {
temperature,
repetition_penalty,
@@ -420,6 +466,7 @@ impl Validation {
seed,
watermark,
grammar,
logit_bias,
};
let stopping_parameters = ValidStoppingParameters {
max_new_tokens,
@@ -902,6 +949,8 @@ pub struct ValidParameters {
pub watermark: bool,
/// / grammar (applied if not empty)
pub grammar: Option<ValidGrammar>,
/// / logit bias
pub logit_bias: Option<Vec<(u32, f32)>>,
}
#[derive(Debug, Clone)]
@@ -997,6 +1046,8 @@ pub enum ValidationError {
FailedFetchImage(#[from] reqwest::Error),
#[error("{0} modality is not supported")]
UnsupportedModality(&'static str),
#[error("logit_bias is not valid: {0}")]
LogitBiasInvalid(String),
}
#[cfg(test)]
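From a client's point of view, the new checks mean malformed bias keys are rejected before inference. A hedged sketch against the /generate route follows (URL and exact error body are assumptions; the message text comes from the LogitBiasInvalid variant above).

import requests

bad_payload = {
    "inputs": "say Hello",
    "parameters": {
        "max_new_tokens": 10,
        # keys must be numeric token IDs below the tokenizer vocab size
        "logit_bias": {"not-a-token-id": -100},
    },
}
resp = requests.post("http://localhost:8080/generate", json=bad_payload)
print(resp.status_code)  # expected: a 4xx validation error
print(resp.text)  # expected to mention "logit_bias is not valid"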

View File

@@ -725,6 +725,15 @@ class VlmCausalLM(FlashCausalLM):
**kwargs,
)
if self.config.vocab_size != self.tokenizer.vocab_size:
logger.warning(
f"Tokenizer vocab size {self.tokenizer.vocab_size} does not match model vocab size {self.config.vocab_size}. Updating tokenizer vocab size."
)
# TODO: HUGE HACK! This is a workaround to update the vocab size
# in the tokenizer. When the tokenizer is updated within the model
# the vocab size is not updated in the tokenizer.
self.tokenizer._vocab_size = self.config.vocab_size
@property
def batch_type(self) -> Type[VlmCausalLMBatch]:
return self.batch_class
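For background (not part of the diff): the mismatch this warning guards against can be inspected directly. Whether the two numbers differ depends on the checkpoint, so treat this as an illustrative check only.

from transformers import AutoConfig, AutoTokenizer

config = AutoConfig.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
# The embedding table (config.vocab_size) may be padded beyond the tokenizer's
# reported vocab, which is why bias tensors are sized from the model config.
print(config.vocab_size, tokenizer.vocab_size)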

View File

@@ -623,3 +623,85 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
new_fsms.append(self.fsms[i])
self.fsms = new_fsms
return self
class LogitBiasProcessor(LogitsProcessor):
    """
    `LogitBiasProcessor` creates a bias tensor from a dictionary of token IDs and their
    corresponding bias values. Biases are applied to the logits during each forward pass.

    Supports token IDs provided as strings (e.g., {"9707": -100}).
    """

    def __init__(
        self,
        logit_biases: dict,
        tokenizer: PreTrainedTokenizerBase,
        device: torch.device,
    ):
        assert logit_biases, "LogitBiasProcessor requires non-empty logit_biases"

        # use _vocab_size or fallback to tokenizer.vocab_size if not available
        self.vocab_size = getattr(tokenizer, "_vocab_size", tokenizer.vocab_size)

        # Convert keys to integers and values to a list
        token_ids = torch.tensor(
            [int(k) for k in logit_biases.keys()], dtype=torch.long
        )
        bias_values = torch.tensor(list(logit_biases.values()), dtype=torch.float)

        # Create a tensor and directly copy bias values at the corresponding indices
        self.bias_tensor = torch.zeros(self.vocab_size, dtype=torch.float)
        self.bias_tensor.index_put_((token_ids,), bias_values, accumulate=True)

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        scores.add_(self.bias_tensor.to(device=scores.device, dtype=scores.dtype))
        return scores


class HeterogeneousLogitBiasProcessor(LogitsProcessor):
    """
    Process logits with different logit biases for each sequence in the batch.
    """

    def __init__(
        self,
        logit_biases: List[Optional[dict]],
        tokenizer: PreTrainedTokenizerBase,
        device: torch.device,
    ):
        assert logit_biases, "HeterogeneousLogitBiasProcessor requires non-empty logit_biases"
        self.tokenizer = tokenizer
        self.logit_biases = logit_biases
        self.device = device  # kept so filter() can rebuild the processor
        # use _vocab_size or fallback to tokenizer.vocab_size if not available
        self.vocab_size = getattr(tokenizer, "_vocab_size", tokenizer.vocab_size)

        # Create batch_size x vocab_size bias matrix
        self.bias_matrix = torch.zeros(
            (len(logit_biases), self.vocab_size), dtype=torch.float, device=device
        )

        # for each logit bias dictionary, convert keys to integers and values to a list
        for i, logit_bias in enumerate(logit_biases):
            if not logit_bias:
                # requests without a bias keep an all-zero row
                continue
            token_ids = torch.tensor(
                [int(k) for k in logit_bias.keys()], dtype=torch.long
            ).to(device=device)
            bias_values = torch.tensor(list(logit_bias.values()), dtype=torch.float).to(
                device=device
            )
            # Create a tensor and directly copy bias values at the corresponding indices
            self.bias_matrix[i].index_put_((token_ids,), bias_values, accumulate=True)

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        scores.add_(self.bias_matrix.to(device=scores.device, dtype=scores.dtype))
        return scores

    def filter(self, indices):
        new_logit_biases = [self.logit_biases[i] for i in indices]
        if not any(bias and len(bias) > 0 for bias in new_logit_biases):
            return None
        return HeterogeneousLogitBiasProcessor(
            new_logit_biases, self.tokenizer, self.device
        )
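A standalone sketch of how LogitBiasProcessor behaves (not part of the diff; the tokenizer name is an assumption): the bias tensor is built over the vocabulary once and added to each step's logits.

import torch
from transformers import AutoTokenizer

from text_generation_server.utils.logits_process import LogitBiasProcessor

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
processor = LogitBiasProcessor({"9707": -100}, tokenizer, torch.device("cpu"))

scores = torch.zeros(1, tokenizer.vocab_size)  # fake logits for one sequence
scores = processor(input_ids=None, scores=scores)
assert scores[0, 9707].item() == -100.0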

View File

@@ -7,6 +7,7 @@ from text_generation_server.pb.generate_pb2 import FinishReason, GrammarType
from text_generation_server.utils.logits_process import (
FrequencyPenaltyLogitsProcessor,
GrammarLogitProcessor,
LogitBiasProcessor,
HeterogeneousProcessorWrapper,
HeterogeneousRepetitionPenaltyLogitsProcessor,
HeterogeneousFrequencyPenaltyLogitsProcessor,
@@ -15,6 +16,7 @@ from text_generation_server.utils.logits_process import (
HeterogeneousTopPLogitsWarper,
HeterogeneousTypicalLogitsWarper,
HeterogeneousGrammarLogitProcessor,
HeterogeneousLogitBiasProcessor,
static_warper,
)
from text_generation_server.utils.watermark import WatermarkLogitsProcessor
@@ -38,6 +40,7 @@ class NextTokenChooser:
grammar: str = "",
grammar_type: GrammarType = GrammarType.GRAMMAR_TYPE_NONE,
fsm_grammar_state: int = 0,
logit_bias: Optional[dict] = None,
):
self.watermark_processor = (
WatermarkLogitsProcessor(device=device) if watermark else None
@@ -57,6 +60,11 @@ class NextTokenChooser:
if grammar != ""
else None
)
self.logit_bias_processor = (
LogitBiasProcessor(logit_bias, tokenizer, device)
if logit_bias is not None and len(logit_bias) > 0
else None
)
self.tokenizer = tokenizer
has_warpers = (
@@ -87,6 +95,8 @@ class NextTokenChooser:
scores = self.frequency_processor(input_ids, scores)
if self.grammar_processor is not None:
scores = self.grammar_processor(scores, self.fsm_grammar_state)
if self.logit_bias_processor is not None:
scores = self.logit_bias_processor(input_ids, scores)
if self.static_warper is None:
next_logprob = torch.log_softmax(scores, -1)
@@ -125,6 +135,7 @@ class NextTokenChooser:
tokenizer=tokenizer,
grammar=pb.grammar,
grammar_type=pb.grammar_type,
logit_bias=pb.logit_bias,
)
@@ -248,6 +259,7 @@ class HeterogeneousNextTokenChooser:
grammars: List[str],
grammar_types: List[int],
fsm_grammar_states=List[int],
logit_biases: List[Optional[dict]] = None,
):
warpers = []
@@ -287,6 +299,12 @@ class HeterogeneousNextTokenChooser:
else None
)
self.logit_bias_processor = (
HeterogeneousLogitBiasProcessor(logit_biases, tokenizer, device)
if any(logit_bias for logit_bias in logit_biases)
else None
)
if any(x != 1.0 for x in temperature):
do_sample = [
sample or x != 1.0 for x, sample in zip(temperature, do_sample)
@@ -322,6 +340,7 @@ class HeterogeneousNextTokenChooser:
self.fsm_grammar_states = fsm_grammar_states
self.grammars = grammars
self.grammar_types = grammar_types
self.logit_biases = logit_biases
def __call__(
self,
@@ -353,6 +372,8 @@ class HeterogeneousNextTokenChooser:
_scores = self.frequency_processor(input_ids, _scores)
if self.grammar_processor is not None:
_scores = self.grammar_processor(_scores, self.fsm_grammar_states)
if self.logit_bias_processor is not None:
_scores = self.logit_bias_processor(input_ids, _scores)
for warper in self.warpers:
_scores = warper(input_ids, _scores)
_next_ids = self.choice(_scores)
@@ -444,6 +465,9 @@ class HeterogeneousNextTokenChooser:
if self.grammar_processor is not None:
self.grammar_processor = self.grammar_processor.filter(indices)
if self.logit_bias_processor is not None:
self.logit_bias_processor = self.logit_bias_processor.filter(indices)
filtered_warpers = []
for warper in self.warpers:
filtered_warper = warper.filter(indices)
@@ -453,6 +477,7 @@ class HeterogeneousNextTokenChooser:
self.seeds = [self.seeds[i] for i in indices]
self.do_sample = [self.do_sample[i] for i in indices]
self.logit_biases = [self.logit_biases[i] for i in indices]
new_grammars = []
new_fsm_grammar_states = []
@@ -500,6 +525,7 @@ class HeterogeneousNextTokenChooser:
fsm_grammar_states=(
fsm_grammar_states if fsm_grammar_states else [0] * len(pb)
),
logit_biases=[pb_.logit_bias for pb_ in pb],
)