Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-06-19 07:42:06 +00:00)

Merge 551ee3a365 into 3752143b39
This commit is contained in: commit 0381aba864
.github/workflows/client-tests.yaml (vendored, 26 lines removed)
@@ -1,26 +0,0 @@
name: Python Client Tests

on:
  pull_request:
    paths:
      - ".github/workflows/client-tests.yaml"
      - "clients/python/**"

jobs:
  run_tests:
    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v2
      - name: Set up Python
        uses: actions/setup-python@v1
        with:
          python-version: 3.9
      - name: Install
        run: |
          cd clients/python && pip install .
      - name: Run tests
        run: |
          pip install pytest pytest-asyncio
          export HF_TOKEN=${{ secrets.HF_TOKEN }}
          make python-client-tests
@@ -7,6 +7,7 @@ use grpc_metadata::InjectTelemetryContext;
use pb::generate::v3::text_generation_service_client::TextGenerationServiceClient;
use pb::generate::v3::*;
use std::cmp::min;
use std::collections::HashMap;
use std::time::Duration;
use tonic::transport::{Channel, Uri};
use tracing::instrument;
@@ -181,6 +182,7 @@ impl Client {
                    watermark: true,
                    grammar: String::new(),
                    grammar_type: GrammarType::None as i32,
                    logit_bias: HashMap::new(),
                }),
                stopping_parameters: Some(StoppingCriteriaParameters {
                    max_new_tokens,
@@ -5,6 +5,7 @@ use crate::{ClientError, Result};
use crate::v3::{Chunk, InfoResponse, Input};
use async_trait::async_trait;
use futures::future::join_all;
use std::collections::HashMap;
use tonic::transport::Uri;
use tracing::instrument;
use v3::client::{DecodeTimings, PrefillTimings};
@@ -244,6 +245,7 @@ impl Health for ShardedClient {
                    watermark: false,
                    grammar: String::new(),
                    grammar_type: GrammarType::None as i32,
                    logit_bias: HashMap::new(),
                }),
                stopping_parameters: Some(StoppingCriteriaParameters {
                    max_new_tokens: 1,
@@ -429,6 +429,7 @@ mod tests {
            frequency_penalty: 0.0,
            watermark: false,
            grammar: None,
            logit_bias: None,
        },
        stopping_parameters: ValidStoppingParameters {
            ignore_eos_token: false,
@@ -7,6 +7,7 @@ use grpc_metadata::InjectTelemetryContext;
use pb::generate::v3::text_generation_service_client::TextGenerationServiceClient;
use pb::generate::v3::*;
use std::cmp::min;
use std::collections::HashMap;
use std::time::Duration;
use tonic::transport::{Channel, Uri};
use tracing::instrument;
@@ -181,6 +182,7 @@ impl Client {
                    watermark: true,
                    grammar: String::new(),
                    grammar_type: GrammarType::None as i32,
                    logit_bias: HashMap::new(),
                }),
                stopping_parameters: Some(StoppingCriteriaParameters {
                    max_new_tokens,
@@ -10,6 +10,7 @@ use crate::client::{
use crate::client::{Chunk, InfoResponse, Input};
use async_trait::async_trait;
use futures::future::join_all;
use std::collections::HashMap;
use tonic::transport::Uri;
use tracing::instrument;

@@ -232,6 +233,7 @@ impl Health for ShardedClient {
                    watermark: false,
                    grammar: String::new(),
                    grammar_type: GrammarType::None as i32,
                    logit_bias: HashMap::new(),
                }),
                stopping_parameters: Some(StoppingCriteriaParameters {
                    max_new_tokens: 1,
@@ -5,6 +5,7 @@ use crate::client::{
};
use nohash_hasher::{BuildNoHashHasher, IntMap};
use std::cmp::max;
use std::collections::HashMap;
use std::collections::VecDeque;
use text_generation_router::infer::InferError;
use text_generation_router::infer::InferStreamResponse;
@@ -542,6 +543,14 @@ impl From<ValidParameters> for NextTokenChooserParameters {
            watermark: value.watermark,
            grammar,
            grammar_type: grammar_type.into(),
            logit_bias: value
                .logit_bias
                .map(|bias| {
                    bias.into_iter()
                        .map(|(token, bias)| (token.to_string(), bias as i32))
                        .collect::<HashMap<String, i32>>()
                })
                .unwrap_or_default(),
        }
    }
}
@@ -588,6 +597,7 @@ mod tests {
            frequency_penalty: 0.0,
            watermark: false,
            grammar: None,
            logit_bias: None,
        },
        stopping_parameters: ValidStoppingParameters {
            ignore_eos_token: false,
@@ -47,6 +47,7 @@ pub async fn run(
        watermark,
        grammar: String::new(),
        grammar_type: GrammarType::None as i32,
        logit_bias: std::collections::HashMap::new(),
    };

    // Initialize terminal properties
@@ -1,6 +1,6 @@
from enum import Enum
from pydantic import BaseModel, field_validator, ConfigDict
-from typing import Optional, List, Union, Any
+from typing import Optional, List, Union, Any, Dict

from text_generation.errors import ValidationError

@@ -137,7 +137,7 @@ class ChatRequest(BaseModel):
    # decreasing the model's likelihood to repeat the same line verbatim.
    frequency_penalty: Optional[float] = None
    # Bias values for token selection
-    logit_bias: Optional[List[float]] = None
+    logit_bias: Optional[Dict[str, int]] = None
    # Whether to return log probabilities
    logprobs: Optional[bool] = None
    # Number of most likely tokens to return at each position
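With the client-side field now typed as Dict[str, int], a chat request can bias individual token IDs directly. A minimal sketch (the endpoint URL is an assumption, and the exact chat() signature is inferred from the updated ChatRequest model and the integration tests added later in this PR; token ID "9707" is the 'Hello' token of the Qwen2 tokenizer used there):

import asyncio

from text_generation import AsyncClient


async def main():
    # Assumes a TGI server running locally on port 8080.
    client = AsyncClient("http://localhost:8080")
    response = await client.chat(
        max_tokens=10,
        logit_bias={"9707": -100},  # strongly discourage this token
        messages=[{"role": "user", "content": "say Hello"}],
    )
    print(response.choices[0].message.content)


asyncio.run(main())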
@@ -995,12 +995,12 @@
        "nullable": true
      },
      "logit_bias": {
-        "type": "array",
-        "items": {
-          "type": "number",
-          "format": "float"
-        },
-        "description": "UNUSED\nModify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens\n(specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,\nthe bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model,\nbut values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should\nresult in a ban or exclusive selection of the relevant token.",
+        "type": "object",
+        "description": "Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens\n(specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,\nthe bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model,\nbut values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should\nresult in a ban or exclusive selection of the relevant token.",
+        "additionalProperties": {
+          "type": "integer",
+          "format": "int32"
+        },
        "nullable": true
      },
      "logprobs": {
@@ -1589,6 +1589,17 @@
        "default": "null",
        "nullable": true
      },
      "logit_bias": {
        "type": "object",
        "description": "Modify the likelihood of specified tokens appearing in the completion.\nAccepts a hash map that maps token strings to an associated bias value.",
        "default": "null",
        "additionalProperties": {
          "type": "integer",
          "format": "int32"
        },
        "example": "{\"1923\": 100, \"1924\": -100}",
        "nullable": true
      },
      "max_new_tokens": {
        "type": "integer",
        "format": "int32",
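The new logit_bias entry in the GenerateParameters schema can also be exercised through the plain /generate route. A minimal sketch, assuming a local TGI server; the token IDs are taken from the schema example above and are purely illustrative:

import requests

payload = {
    "inputs": "say Hello",
    "parameters": {
        "max_new_tokens": 10,
        # token-ID string -> bias, as in the schema example {"1923": 100, "1924": -100}
        "logit_bias": {"1923": 100, "1924": -100},
    },
}
resp = requests.post("http://localhost:8080/generate", json=payload, timeout=60)
print(resp.json()["generated_text"])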
@@ -0,0 +1,26 @@
{
  "choices": [
    {
      "finish_reason": "length",
      "index": 0,
      "logprobs": null,
      "message": {
        "content": "Hello! How can I help you today?",
        "name": null,
        "role": "assistant",
        "tool_calls": null
      },
      "usage": null
    }
  ],
  "created": 1745337495,
  "id": "",
  "model": "Qwen/Qwen2-VL-2B-Instruct",
  "object": "chat.completion",
  "system_fingerprint": "3.2.3-dev0-native",
  "usage": {
    "completion_tokens": 10,
    "prompt_tokens": 21,
    "total_tokens": 31
  }
}
@@ -0,0 +1,26 @@
{
  "choices": [
    {
      "finish_reason": "length",
      "index": 0,
      "logprobs": null,
      "message": {
        "content": "¡Hola! ¿Cómo puedo ayudarte?",
        "name": null,
        "role": "assistant",
        "tool_calls": null
      },
      "usage": null
    }
  ],
  "created": 1746486174,
  "id": "",
  "model": "Qwen/Qwen2-VL-2B-Instruct",
  "object": "chat.completion",
  "system_fingerprint": "3.2.3-dev0-native",
  "usage": {
    "completion_tokens": 10,
    "prompt_tokens": 21,
    "total_tokens": 31
  }
}
@@ -0,0 +1,26 @@
{
  "choices": [
    {
      "finish_reason": "stop",
      "index": 0,
      "logprobs": null,
      "message": {
        "content": "Chat!",
        "name": null,
        "role": "assistant",
        "tool_calls": null
      },
      "usage": null
    }
  ],
  "created": 1746486174,
  "id": "",
  "model": "Qwen/Qwen2-VL-2B-Instruct",
  "object": "chat.completion",
  "system_fingerprint": "3.2.3-dev0-native",
  "usage": {
    "completion_tokens": 3,
    "prompt_tokens": 25,
    "total_tokens": 28
  }
}
@@ -0,0 +1,20 @@
{
  "choices": [
    {
      "delta": {
        "content": "",
        "role": "assistant",
        "tool_calls": null
      },
      "finish_reason": "length",
      "index": 0,
      "logprobs": null
    }
  ],
  "created": 1746486174,
  "id": "",
  "model": "Qwen/Qwen2-VL-2B-Instruct",
  "object": "chat.completion.chunk",
  "system_fingerprint": "3.2.3-dev0-native",
  "usage": null
}
integration-tests/models/test_flash_logit_bias.py (new file, 109 lines)
@@ -0,0 +1,109 @@
import pytest


@pytest.fixture(scope="module")
def logit_bias_model_handle(launcher):
    with launcher("Qwen/Qwen2-VL-2B-Instruct") as handle:
        yield handle


@pytest.fixture(scope="module")
async def logit_bias_model(logit_bias_model_handle):
    await logit_bias_model_handle.health(300)
    return logit_bias_model_handle.client


@pytest.mark.private
async def test_logit_bias_english_to_spanish(logit_bias_model, response_snapshot):
    """Test that setting negative bias on English tokens forces output to be in Spanish"""
    response = await logit_bias_model.chat(
        seed=42,
        max_tokens=10,
        logit_bias={"9707": -100},  # Bias against 'Hello' token
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "say Hello"},
                ],
            },
        ],
    )
    assert "¡Hola!" in response.choices[0].message.content
    assert "Hello" not in response.choices[0].message.content
    assert response == response_snapshot


@pytest.mark.private
async def test_logit_bias_baseline(logit_bias_model, response_snapshot):
    """Test baseline behavior without logit bias for comparison"""
    response = await logit_bias_model.chat(
        seed=42,
        max_tokens=10,
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "say Hello"},
                ],
            },
        ],
    )
    assert "Hello" in response.choices[0].message.content
    assert response == response_snapshot


@pytest.mark.private
async def test_logit_bias_multiple_tokens(logit_bias_model, response_snapshot):
    """Test applying bias to multiple tokens simultaneously"""
    response = await logit_bias_model.chat(
        seed=42,
        max_tokens=15,
        logit_bias={
            "9707": -100,  # Bias against 'Hello' token
            "2880": -100,  # Bias against 'hi' token
        },
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Give me a one-word greeting"},
                ],
            },
        ],
    )
    # compare lowercased substrings against the lowercased content
    assert "hello" not in response.choices[0].message.content.lower()
    assert "hi" not in response.choices[0].message.content.lower()
    assert response == response_snapshot


@pytest.mark.private
async def test_logit_bias_streaming(logit_bias_model, response_snapshot):
    """Test logit bias works correctly with streaming enabled"""
    responses = await logit_bias_model.chat(
        seed=42,
        max_tokens=10,
        logit_bias={"9707": -100},  # Bias against 'Hello' token
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "say Hello"},
                ],
            },
        ],
        stream=True,
    )

    count = 0
    generated = ""
    last_response = None

    async for response in responses:
        count += 1
        generated += response.choices[0].delta.content
        last_response = response

    assert "¡Hola!" in generated
    assert "Hello" not in generated
    assert last_response == response_snapshot
@@ -104,6 +104,8 @@ message NextTokenChooserParameters {
    string grammar = 10;
    /// grammar type
    GrammarType grammar_type = 11;
    /// logit bias dictionary mapping token string to bias value
    map<string, int32> logit_bias = 12;
}

message StoppingCriteriaParameters {
@@ -18,6 +18,7 @@ use crate::infer::{Infer, InferError};
use pyo3::prelude::*;
use pyo3::types::IntoPyDict;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tokenizers::Encoding;
use tracing::warn;
use utoipa::ToSchema;
@@ -431,6 +432,16 @@ pub(crate) struct GenerateParameters {
    #[serde(default)]
    #[schema(nullable = true, default = "null", example = "null")]
    pub adapter_id: Option<String>,

    /// Modify the likelihood of specified tokens appearing in the completion.
    /// Accepts a hash map that maps token strings to an associated bias value.
    #[serde(default)]
    #[schema(
        nullable = true,
        default = "null",
        example = "{\"1923\": 100, \"1924\": -100}"
    )]
    pub logit_bias: Option<HashMap<String, i32>>,
}

fn default_parameters() -> GenerateParameters {
@@ -454,9 +465,9 @@ fn default_parameters() -> GenerateParameters {
        top_n_tokens: None,
        grammar: None,
        adapter_id: None,
        logit_bias: None,
    }
}

#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)]
#[serde(try_from = "PromptDeserializer")]
pub struct Prompt(pub Vec<String>);
@@ -841,14 +852,13 @@ pub(crate) struct ChatRequest {
    #[schema(example = "1.0")]
    pub frequency_penalty: Option<f32>,

-    /// UNUSED
    /// Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens
    /// (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,
    /// the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model,
    /// but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should
    /// result in a ban or exclusive selection of the relevant token.
    #[serde(default)]
-    pub logit_bias: Option<Vec<f32>>,
+    pub logit_bias: Option<HashMap<String, i32>>,

    /// Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each
    /// output token returned in the content of message.
@@ -954,6 +964,7 @@ impl ChatRequest {
            frequency_penalty,
            top_p,
            top_logprobs,
            logit_bias,
            ..
        } = self;

@@ -1029,6 +1040,7 @@ impl ChatRequest {
                top_n_tokens: top_logprobs,
                grammar,
                adapter_id: model.filter(|m| *m != "tgi"),
                logit_bias,
            },
        },
        using_tools,
@@ -798,6 +798,7 @@ pub(crate) async fn completions(
                top_n_tokens: None,
                grammar: None,
                adapter_id: model.as_ref().filter(|m| *m != "tgi").map(String::from),
                logit_bias: None,
            },
        })
        .collect();
@@ -34,6 +34,7 @@ pub struct Validation {
    max_input_length: usize,
    max_total_tokens: usize,
    disable_grammar_support: bool,
    vocab_size: u32,
    /// Channel to communicate with the background tokenization task
    sender: mpsc::UnboundedSender<TokenizerRequest>,
}
@@ -88,6 +89,19 @@ impl Validation {
            validation_sender
        };

        let vocab_size = match &tokenizer {
            Tokenizer::Python { tokenizer_name, .. } => {
                warn!(
                    "Tokenizer {} is not supported for validation",
                    tokenizer_name
                );
                0
            }
            Tokenizer::Rust(tokenizer) => tokenizer.get_vocab_size(false),
        }
        .try_into()
        .unwrap_or(0);

        Self {
            max_best_of,
            sender,
@@ -96,6 +110,7 @@ impl Validation {
            max_input_length,
            max_total_tokens,
            disable_grammar_support,
            vocab_size,
        }
    }

@@ -409,6 +424,37 @@ impl Validation {
            None => None,
        };

        // Validate logit bias and convert to a vector of (token_id, bias_value)
        let logit_bias = request
            .parameters
            .logit_bias
            .as_ref()
            .filter(|bias_map| !bias_map.is_empty())
            .map(|bias_map| {
                bias_map
                    .iter()
                    .map(|(token_str, &bias_value)| {
                        let token_id: u32 = token_str.parse().map_err(|_| {
                            ValidationError::LogitBiasInvalid(format!(
                                "Token ID {token_str} is not a valid number."
                            ))
                        })?;

                        if token_id >= self.vocab_size {
                            return Err(ValidationError::LogitBiasInvalid(format!(
                                "Token ID {token_id} is out of range (0..{}).",
                                self.vocab_size - 1
                            )));
                        }

                        Ok((token_id, bias_value as f32))
                    })
                    .collect::<Result<Vec<_>, _>>()
            })
            // convert Option<Result<T, E>> to Result<Option<T>, E> to throw
            // if any of the token IDs are invalid
            .transpose()?;

        let parameters = ValidParameters {
            temperature,
            repetition_penalty,
@@ -420,6 +466,7 @@ impl Validation {
            seed,
            watermark,
            grammar,
            logit_bias,
        };
        let stopping_parameters = ValidStoppingParameters {
            max_new_tokens,
@@ -902,6 +949,8 @@ pub struct ValidParameters {
    pub watermark: bool,
    /// / grammar (applied if not empty)
    pub grammar: Option<ValidGrammar>,
    /// / logit bias
    pub logit_bias: Option<Vec<(u32, f32)>>,
}

#[derive(Debug, Clone)]
@@ -997,6 +1046,8 @@ pub enum ValidationError {
    FailedFetchImage(#[from] reqwest::Error),
    #[error("{0} modality is not supported")]
    UnsupportedModality(&'static str),
    #[error("logit_bias is not valid: {0}")]
    LogitBiasInvalid(String),
}

#[cfg(test)]
@@ -725,6 +725,15 @@ class VlmCausalLM(FlashCausalLM):
            **kwargs,
        )

        if self.config.vocab_size != self.tokenizer.vocab_size:
            logger.warning(
                f"Tokenizer vocab size {self.tokenizer.vocab_size} does not match model vocab size {self.config.vocab_size}. Updating tokenizer vocab size."
            )
            # TODO: HUGE HACK! This is a workaround to update the vocab size
            # in the tokenizer. When the tokenizer is updated within the model
            # the vocab size is not updated in the tokenizer.
            self.tokenizer._vocab_size = self.config.vocab_size

    @property
    def batch_type(self) -> Type[VlmCausalLMBatch]:
        return self.batch_class
@@ -623,3 +623,85 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
            new_fsms.append(self.fsms[i])
        self.fsms = new_fsms
        return self


class LogitBiasProcessor(LogitsProcessor):
    """
    `LogitBiasProcessor` creates a bias tensor from a dictionary of token IDs and their
    corresponding bias values. Biases are applied to the logits during each forward pass.

    Supports token IDs provided as strings (e.g., {"9707": -100}).
    """

    def __init__(
        self,
        logit_biases: dict,
        tokenizer: PreTrainedTokenizerBase,
        device: torch.device,
    ):
        assert logit_biases, "LogitBiasProcessor requires non-empty logit_biases"

        # use _vocab_size or fall back to tokenizer.vocab_size if not available
        self.vocab_size = getattr(tokenizer, "_vocab_size", tokenizer.vocab_size)

        # Convert keys to integers and values to a list
        token_ids = torch.tensor(
            [int(k) for k in logit_biases.keys()], dtype=torch.long
        )
        bias_values = torch.tensor(list(logit_biases.values()), dtype=torch.float)

        # Create a tensor and directly copy bias values at the corresponding indices
        self.bias_tensor = torch.zeros(self.vocab_size, dtype=torch.float)
        self.bias_tensor.index_put_((token_ids,), bias_values, accumulate=True)

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        scores.add_(self.bias_tensor.to(device=scores.device, dtype=scores.dtype))
        return scores


class HeterogeneousLogitBiasProcessor(LogitsProcessor):
    """
    Process logits with different logit biases for each sequence in the batch.
    """

    def __init__(
        self,
        logit_biases: List[Optional[dict]],
        tokenizer: PreTrainedTokenizerBase,
        device: torch.device,
    ):
        assert logit_biases, "HeterogeneousLogitBiasProcessor requires non-empty logit_biases"

        self.tokenizer = tokenizer
        self.logit_biases = logit_biases
        self.device = device

        # use _vocab_size or fall back to tokenizer.vocab_size if not available
        self.vocab_size = getattr(tokenizer, "_vocab_size", tokenizer.vocab_size)

        # Create a batch_size x vocab_size bias matrix
        self.bias_matrix = torch.zeros(
            (len(logit_biases), self.vocab_size), dtype=torch.float, device=device
        )

        # For each logit bias dictionary, convert keys to integers and values to a list;
        # sequences without a bias dictionary are skipped.
        for i, logit_bias in enumerate(logit_biases):
            if not logit_bias:
                continue
            token_ids = torch.tensor(
                [int(k) for k in logit_bias.keys()], dtype=torch.long
            ).to(device=device)
            bias_values = torch.tensor(list(logit_bias.values()), dtype=torch.float).to(
                device=device
            )
            # Copy bias values directly at the corresponding indices
            self.bias_matrix[i].index_put_((token_ids,), bias_values, accumulate=True)

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        scores.add_(self.bias_matrix.to(device=scores.device, dtype=scores.dtype))
        return scores

    def filter(self, indices):
        new_logit_biases = [self.logit_biases[i] for i in indices]
        if not any(bias and len(bias) > 0 for bias in new_logit_biases):
            return None
        return HeterogeneousLogitBiasProcessor(
            new_logit_biases, self.tokenizer, self.device
        )
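A minimal usage sketch of the LogitBiasProcessor defined above. The dummy tokenizer is a stand-in (only vocab_size is read by the processor), and the module path follows the import added in the next hunk:

import torch

from text_generation_server.utils.logits_process import LogitBiasProcessor


class _DummyTokenizer:
    # Stand-in for a PreTrainedTokenizerBase; only vocab_size is consulted here.
    vocab_size = 32000


processor = LogitBiasProcessor(
    logit_biases={"9707": -100},
    tokenizer=_DummyTokenizer(),
    device=torch.device("cpu"),
)
scores = torch.zeros(1, _DummyTokenizer.vocab_size)
scores = processor(input_ids=torch.tensor([[1]]), scores=scores)
assert scores[0, 9707].item() == -100.0  # bias is added directly to the logits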
@@ -7,6 +7,7 @@ from text_generation_server.pb.generate_pb2 import FinishReason, GrammarType
from text_generation_server.utils.logits_process import (
    FrequencyPenaltyLogitsProcessor,
    GrammarLogitProcessor,
    LogitBiasProcessor,
    HeterogeneousProcessorWrapper,
    HeterogeneousRepetitionPenaltyLogitsProcessor,
    HeterogeneousFrequencyPenaltyLogitsProcessor,
@@ -15,6 +16,7 @@ from text_generation_server.utils.logits_process import (
    HeterogeneousTopPLogitsWarper,
    HeterogeneousTypicalLogitsWarper,
    HeterogeneousGrammarLogitProcessor,
    HeterogeneousLogitBiasProcessor,
    static_warper,
)
from text_generation_server.utils.watermark import WatermarkLogitsProcessor
@@ -38,6 +40,7 @@ class NextTokenChooser:
        grammar: str = "",
        grammar_type: GrammarType = GrammarType.GRAMMAR_TYPE_NONE,
        fsm_grammar_state: int = 0,
        logit_bias: Optional[dict] = None,
    ):
        self.watermark_processor = (
            WatermarkLogitsProcessor(device=device) if watermark else None
@@ -57,6 +60,11 @@ class NextTokenChooser:
            if grammar != ""
            else None
        )
        self.logit_bias_processor = (
            LogitBiasProcessor(logit_bias, tokenizer, device)
            if logit_bias is not None and len(logit_bias) > 0
            else None
        )
        self.tokenizer = tokenizer

        has_warpers = (
@@ -87,6 +95,8 @@ class NextTokenChooser:
            scores = self.frequency_processor(input_ids, scores)
        if self.grammar_processor is not None:
            scores = self.grammar_processor(scores, self.fsm_grammar_state)
        if self.logit_bias_processor is not None:
            scores = self.logit_bias_processor(input_ids, scores)

        if self.static_warper is None:
            next_logprob = torch.log_softmax(scores, -1)
@@ -125,6 +135,7 @@ class NextTokenChooser:
            tokenizer=tokenizer,
            grammar=pb.grammar,
            grammar_type=pb.grammar_type,
            logit_bias=pb.logit_bias,
        )

@@ -248,6 +259,7 @@ class HeterogeneousNextTokenChooser:
        grammars: List[str],
        grammar_types: List[int],
        fsm_grammar_states=List[int],
        logit_biases: List[Optional[dict]] = None,
    ):
        warpers = []

@@ -287,6 +299,12 @@ class HeterogeneousNextTokenChooser:
            else None
        )

        self.logit_bias_processor = (
            HeterogeneousLogitBiasProcessor(logit_biases, tokenizer, device)
            if any(logit_bias for logit_bias in logit_biases)
            else None
        )

        if any(x != 1.0 for x in temperature):
            do_sample = [
                sample or x != 1.0 for x, sample in zip(temperature, do_sample)
@@ -322,6 +340,7 @@ class HeterogeneousNextTokenChooser:
        self.fsm_grammar_states = fsm_grammar_states
        self.grammars = grammars
        self.grammar_types = grammar_types
        self.logit_biases = logit_biases

    def __call__(
        self,
@@ -353,6 +372,8 @@ class HeterogeneousNextTokenChooser:
                _scores = self.frequency_processor(input_ids, _scores)
            if self.grammar_processor is not None:
                _scores = self.grammar_processor(_scores, self.fsm_grammar_states)
            if self.logit_bias_processor is not None:
                _scores = self.logit_bias_processor(input_ids, _scores)
            for warper in self.warpers:
                _scores = warper(input_ids, _scores)
            _next_ids = self.choice(_scores)
@@ -444,6 +465,9 @@ class HeterogeneousNextTokenChooser:
        if self.grammar_processor is not None:
            self.grammar_processor = self.grammar_processor.filter(indices)

        if self.logit_bias_processor is not None:
            self.logit_bias_processor = self.logit_bias_processor.filter(indices)

        filtered_warpers = []
        for warper in self.warpers:
            filtered_warper = warper.filter(indices)
@@ -453,6 +477,7 @@ class HeterogeneousNextTokenChooser:

        self.seeds = [self.seeds[i] for i in indices]
        self.do_sample = [self.do_sample[i] for i in indices]
        self.logit_biases = [self.logit_biases[i] for i in indices]

        new_grammars = []
        new_fsm_grammar_states = []
@@ -500,6 +525,7 @@ class HeterogeneousNextTokenChooser:
            fsm_grammar_states=(
                fsm_grammar_states if fsm_grammar_states else [0] * len(pb)
            ),
            logit_biases=[pb_.logit_bias for pb_ in pb],
        )