Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-06-07 18:02:07 +00:00)
feat: support logit bias in chat request
This commit is contained in:
parent 7253be349a
commit fae510b8f6
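
For context, a minimal sketch of how a client could exercise the new parameter through the OpenAI-compatible chat route once this change is deployed. The base URL, port, and token ID below are illustrative assumptions, not part of this commit:

import requests

# Hypothetical local TGI endpoint serving a model with this change applied.
url = "http://localhost:8080/v1/chat/completions"

payload = {
    "model": "tgi",
    "messages": [{"role": "user", "content": "say Hello"}],
    "max_tokens": 10,
    # Maps token IDs (as strings) to integer biases; -100 effectively bans a token.
    "logit_bias": {"9707": -100},
}

response = requests.post(url, json=payload, timeout=60)
response.raise_for_status()
print(response.json()["choices"][0]["message"]["content"])

The shape of logit_bias here mirrors the new HashMap<String, i32> field that the diff below adds to ChatRequest.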
@@ -7,6 +7,7 @@ use grpc_metadata::InjectTelemetryContext;
 use pb::generate::v3::text_generation_service_client::TextGenerationServiceClient;
 use pb::generate::v3::*;
 use std::cmp::min;
+use std::collections::HashMap;
 use std::time::Duration;
 use tonic::transport::{Channel, Uri};
 use tracing::instrument;
@@ -181,6 +182,7 @@ impl Client {
            watermark: true,
            grammar: String::new(),
            grammar_type: GrammarType::None as i32,
+           logit_bias: HashMap::new(),
        }),
        stopping_parameters: Some(StoppingCriteriaParameters {
            max_new_tokens,
@@ -5,6 +5,7 @@ use crate::{ClientError, Result};
 use crate::v3::{Chunk, InfoResponse, Input};
 use async_trait::async_trait;
 use futures::future::join_all;
+use std::collections::HashMap;
 use tonic::transport::Uri;
 use tracing::instrument;
 use v3::client::{DecodeTimings, PrefillTimings};
@@ -244,6 +245,7 @@ impl Health for ShardedClient {
            watermark: false,
            grammar: String::new(),
            grammar_type: GrammarType::None as i32,
+           logit_bias: HashMap::new(),
        }),
        stopping_parameters: Some(StoppingCriteriaParameters {
            max_new_tokens: 1,
@@ -7,6 +7,7 @@ use grpc_metadata::InjectTelemetryContext;
 use pb::generate::v3::text_generation_service_client::TextGenerationServiceClient;
 use pb::generate::v3::*;
 use std::cmp::min;
+use std::collections::HashMap;
 use std::time::Duration;
 use tonic::transport::{Channel, Uri};
 use tracing::instrument;
@@ -181,6 +182,7 @@ impl Client {
            watermark: true,
            grammar: String::new(),
            grammar_type: GrammarType::None as i32,
+           logit_bias: HashMap::new(),
        }),
        stopping_parameters: Some(StoppingCriteriaParameters {
            max_new_tokens,
@@ -10,6 +10,7 @@ use crate::client::{
 use crate::client::{Chunk, InfoResponse, Input};
 use async_trait::async_trait;
 use futures::future::join_all;
+use std::collections::HashMap;
 use tonic::transport::Uri;
 use tracing::instrument;

@@ -232,6 +233,7 @@ impl Health for ShardedClient {
            watermark: false,
            grammar: String::new(),
            grammar_type: GrammarType::None as i32,
+           logit_bias: HashMap::new(),
        }),
        stopping_parameters: Some(StoppingCriteriaParameters {
            max_new_tokens: 1,
@@ -5,6 +5,7 @@ use crate::client::{
 };
 use nohash_hasher::{BuildNoHashHasher, IntMap};
 use std::cmp::max;
+use std::collections::HashMap;
 use std::collections::VecDeque;
 use text_generation_router::infer::InferError;
 use text_generation_router::infer::InferStreamResponse;
@@ -522,6 +523,14 @@ impl From<ValidParameters> for NextTokenChooserParameters {
            watermark: value.watermark,
            grammar,
            grammar_type: grammar_type.into(),
+           logit_bias: value
+               .logit_bias
+               .map(|bias| {
+                   bias.into_iter()
+                       .map(|(token, bias)| (token.to_string(), bias as i32))
+                       .collect::<HashMap<String, i32>>()
+               })
+               .unwrap_or_default(),
        }
    }
 }
@@ -47,6 +47,7 @@ pub async fn run(
        watermark,
        grammar: String::new(),
        grammar_type: GrammarType::None as i32,
+       logit_bias: std::collections::HashMap::new(),
    };

    // Initialize terminal properties
@@ -1,6 +1,6 @@
 from enum import Enum
 from pydantic import BaseModel, field_validator, ConfigDict
-from typing import Optional, List, Union, Any
+from typing import Optional, List, Union, Any, Dict

 from text_generation.errors import ValidationError

@@ -137,7 +137,7 @@ class ChatRequest(BaseModel):
     # decreasing the model's likelihood to repeat the same line verbatim.
     frequency_penalty: Optional[float] = None
     # Bias values for token selection
-    logit_bias: Optional[List[float]] = None
+    logit_bias: Optional[Dict[str, int]] = None
     # Whether to return log probabilities
     logprobs: Optional[bool] = None
     # Number of most likely tokens to return at each position
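
For illustration, a hypothetical snippet of how the updated client-side field is populated after this change; the surrounding ChatRequest and Message fields are assumed from the same types module and are not shown in this hunk:

from text_generation.types import ChatRequest, Message

# logit_bias is now a mapping of token-ID strings to integer biases
# rather than a flat list of floats.
request = ChatRequest(
    model="tgi",
    messages=[Message(role="user", content="say Hello")],
    max_tokens=10,
    logit_bias={"9707": -100},
)
print(request.logit_bias)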
@@ -995,12 +995,12 @@
          "nullable": true
        },
        "logit_bias": {
-         "type": "array",
-         "items": {
-           "type": "number",
-           "format": "float"
+         "type": "object",
+         "description": "Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens\n(specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,\nthe bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model,\nbut values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should\nresult in a ban or exclusive selection of the relevant token.",
+         "additionalProperties": {
+           "type": "integer",
+           "format": "int32"
          },
-         "description": "UNUSED\nModify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens\n(specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,\nthe bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model,\nbut values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should\nresult in a ban or exclusive selection of the relevant token.",
          "nullable": true
        },
        "logprobs": {
@@ -1589,6 +1589,17 @@
          "default": "null",
          "nullable": true
        },
+       "logit_bias": {
+         "type": "object",
+         "description": "Modify the likelihood of specified tokens appearing in the completion.\nAccepts a hash map that maps token strings to an associated bias value.",
+         "default": "null",
+         "additionalProperties": {
+           "type": "integer",
+           "format": "int32"
+         },
+         "example": "{\"1923\": 100, \"1924\": -100}",
+         "nullable": true
+       },
        "max_new_tokens": {
          "type": "integer",
          "format": "int32",
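
A similar, hedged sketch for the plain generate route, matching the GenerateParameters schema entry added above; the server URL is an assumption and the token IDs reuse the schema's own example values:

import requests

payload = {
    "inputs": "say Hello",
    "parameters": {
        "max_new_tokens": 10,
        # Token-ID string -> integer bias, as described by the schema above.
        "logit_bias": {"1923": 100, "1924": -100},
    },
}

resp = requests.post("http://localhost:8080/generate", json=payload, timeout=60)
resp.raise_for_status()
print(resp.json()["generated_text"])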
@@ -0,0 +1,26 @@
+{
+  "choices": [
+    {
+      "finish_reason": "length",
+      "index": 0,
+      "logprobs": null,
+      "message": {
+        "content": "Hello! How can I help you today?",
+        "name": null,
+        "role": "assistant",
+        "tool_calls": null
+      },
+      "usage": null
+    }
+  ],
+  "created": 1745337495,
+  "id": "",
+  "model": "Qwen/Qwen2-VL-2B-Instruct",
+  "object": "chat.completion",
+  "system_fingerprint": "3.2.3-dev0-native",
+  "usage": {
+    "completion_tokens": 10,
+    "prompt_tokens": 21,
+    "total_tokens": 31
+  }
+}
@@ -0,0 +1,26 @@
+{
+  "choices": [
+    {
+      "finish_reason": "length",
+      "index": 0,
+      "logprobs": null,
+      "message": {
+        "content": "¡Hola! ¿Cómo puedo ayudarte?",
+        "name": null,
+        "role": "assistant",
+        "tool_calls": null
+      },
+      "usage": null
+    }
+  ],
+  "created": 1745337456,
+  "id": "",
+  "model": "Qwen/Qwen2-VL-2B-Instruct",
+  "object": "chat.completion",
+  "system_fingerprint": "3.2.3-dev0-native",
+  "usage": {
+    "completion_tokens": 10,
+    "prompt_tokens": 21,
+    "total_tokens": 31
+  }
+}
@@ -0,0 +1,26 @@
+{
+  "choices": [
+    {
+      "finish_reason": "stop",
+      "index": 0,
+      "logprobs": null,
+      "message": {
+        "content": "Chat!",
+        "name": null,
+        "role": "assistant",
+        "tool_calls": null
+      },
+      "usage": null
+    }
+  ],
+  "created": 1745337878,
+  "id": "",
+  "model": "Qwen/Qwen2-VL-2B-Instruct",
+  "object": "chat.completion",
+  "system_fingerprint": "3.2.3-dev0-native",
+  "usage": {
+    "completion_tokens": 3,
+    "prompt_tokens": 25,
+    "total_tokens": 28
+  }
+}
@@ -0,0 +1,20 @@
+{
+  "choices": [
+    {
+      "delta": {
+        "content": "",
+        "role": "assistant",
+        "tool_calls": null
+      },
+      "finish_reason": "length",
+      "index": 0,
+      "logprobs": null
+    }
+  ],
+  "created": 1745337495,
+  "id": "",
+  "model": "Qwen/Qwen2-VL-2B-Instruct",
+  "object": "chat.completion.chunk",
+  "system_fingerprint": "3.2.3-dev0-native",
+  "usage": null
+}
integration-tests/models/test_flash_logit_bias.py (new file, 109 lines)
@@ -0,0 +1,109 @@
+import pytest
+
+
+@pytest.fixture(scope="module")
+def logit_bias_model_handle(launcher):
+    with launcher("Qwen/Qwen2-VL-2B-Instruct") as handle:
+        yield handle
+
+
+@pytest.fixture(scope="module")
+async def logit_bias_model(logit_bias_model_handle):
+    await logit_bias_model_handle.health(300)
+    return logit_bias_model_handle.client
+
+
+@pytest.mark.private
+async def test_logit_bias_english_to_spanish(logit_bias_model, response_snapshot):
+    """Test that setting negative bias on English tokens forces output to be in Spanish"""
+    response = await logit_bias_model.chat(
+        seed=42,
+        max_tokens=10,
+        logit_bias={"9707": -100},  # Bias against 'Hello' token
+        messages=[
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": "say Hello"},
+                ],
+            },
+        ],
+    )
+    assert "¡Hola!" in response.choices[0].message.content
+    assert "Hello" not in response.choices[0].message.content
+    assert response == response_snapshot
+
+
+@pytest.mark.private
+async def test_logit_bias_baseline(logit_bias_model, response_snapshot):
+    """Test baseline behavior without logit bias for comparison"""
+    response = await logit_bias_model.chat(
+        seed=42,
+        max_tokens=10,
+        messages=[
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": "say Hello"},
+                ],
+            },
+        ],
+    )
+    assert "Hello" in response.choices[0].message.content
+    assert response == response_snapshot
+
+
+@pytest.mark.private
+async def test_logit_bias_multiple_tokens(logit_bias_model, response_snapshot):
+    """Test applying bias to multiple tokens simultaneously"""
+    response = await logit_bias_model.chat(
+        seed=42,
+        max_tokens=15,
+        logit_bias={
+            "9707": -100,  # Bias against 'Hello' token
+            "2880": -100,  # Bias against 'hi' token
+        },
+        messages=[
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": "Give me a one-word greeting"},
+                ],
+            },
+        ],
+    )
+    assert "Hello" not in response.choices[0].message.content.lower()
+    assert "hi" not in response.choices[0].message.content.lower()
+    assert response == response_snapshot
+
+
+@pytest.mark.private
+async def test_logit_bias_streaming(logit_bias_model, response_snapshot):
+    """Test logit bias works correctly with streaming enabled"""
+    responses = await logit_bias_model.chat(
+        seed=42,
+        max_tokens=10,
+        logit_bias={"9707": -100},  # Bias against 'Hello' token
+        messages=[
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": "say Hello"},
+                ],
+            },
+        ],
+        stream=True,
+    )
+
+    count = 0
+    generated = ""
+    last_response = None
+
+    async for response in responses:
+        count += 1
+        generated += response.choices[0].delta.content
+        last_response = response
+
+    assert "¡Hola!" in generated
+    assert "Hello" not in generated
+    assert last_response == response_snapshot
@@ -18,6 +18,7 @@ use crate::infer::{Infer, InferError};
 use pyo3::prelude::*;
 use pyo3::types::IntoPyDict;
 use serde::{Deserialize, Serialize};
+use std::collections::HashMap;
 use tokenizers::Encoding;
 use tracing::warn;
 use utoipa::ToSchema;
@@ -431,6 +432,16 @@ pub(crate) struct GenerateParameters {
    #[serde(default)]
    #[schema(nullable = true, default = "null", example = "null")]
    pub adapter_id: Option<String>,
+
+   /// Modify the likelihood of specified tokens appearing in the completion.
+   /// Accepts a hash map that maps token strings to an associated bias value.
+   #[serde(default)]
+   #[schema(
+       nullable = true,
+       default = "null",
+       example = "{\"1923\": 100, \"1924\": -100}"
+   )]
+   pub logit_bias: Option<HashMap<String, i32>>,
 }

 fn default_parameters() -> GenerateParameters {
@@ -454,9 +465,9 @@ fn default_parameters() -> GenerateParameters {
        top_n_tokens: None,
        grammar: None,
        adapter_id: None,
+       logit_bias: None,
    }
 }

 #[derive(Clone, Deserialize, Serialize, ToSchema, Debug)]
 #[serde(try_from = "PromptDeserializer")]
 pub struct Prompt(pub Vec<String>);
@@ -841,14 +852,13 @@ pub(crate) struct ChatRequest {
    #[schema(example = "1.0")]
    pub frequency_penalty: Option<f32>,

-   /// UNUSED
    /// Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens
    /// (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,
    /// the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model,
    /// but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should
    /// result in a ban or exclusive selection of the relevant token.
    #[serde(default)]
-   pub logit_bias: Option<Vec<f32>>,
+   pub logit_bias: Option<HashMap<String, i32>>,

    /// Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each
    /// output token returned in the content of message.
@@ -954,6 +964,7 @@ impl ChatRequest {
            frequency_penalty,
            top_p,
            top_logprobs,
+           logit_bias,
            ..
        } = self;

@@ -1029,6 +1040,7 @@ impl ChatRequest {
                top_n_tokens: top_logprobs,
                grammar,
                adapter_id: model.filter(|m| *m != "tgi"),
+               logit_bias,
            },
        },
        using_tools,
@@ -798,6 +798,7 @@ pub(crate) async fn completions(
            top_n_tokens: None,
            grammar: None,
            adapter_id: model.as_ref().filter(|m| *m != "tgi").map(String::from),
+           logit_bias: None,
        },
    })
    .collect();
@@ -1191,6 +1192,7 @@ pub(crate) async fn chat_completions(
    let (generate_request, using_tools): (GenerateRequest, bool) =
        chat.clone().try_into_generate(&infer)?;
    span.record("parameters", format!("{:?}", generate_request.parameters));
+   println!("ChatRequest: {:#?}", generate_request);
    let logprobs = logprobs.unwrap_or_default();

    // extract model id from request if specified
@@ -420,6 +420,18 @@ impl Validation {
            seed,
            watermark,
            grammar,
+           logit_bias: Some(
+               request
+                   .parameters
+                   .logit_bias
+                   .iter()
+                   .flat_map(|bias| {
+                       bias.iter()
+                           .map(|(k, v)| (k.parse::<u32>().unwrap(), *v as f32))
+                           .collect::<Vec<_>>()
+                   })
+                   .collect(),
+           ),
        };
        let stopping_parameters = ValidStoppingParameters {
            max_new_tokens,
@@ -902,6 +914,8 @@ pub struct ValidParameters {
    pub watermark: bool,
    /// / grammar (applied if not empty)
    pub grammar: Option<ValidGrammar>,
+   /// / logit bias
+   pub logit_bias: Option<Vec<(u32, f32)>>,
 }

 #[derive(Debug, Clone)]
@@ -623,3 +623,65 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
             new_fsms.append(self.fsms[i])
         self.fsms = new_fsms
         return self
+
+
+class HeterogeneousLogitBiasProcessor:
+    """Process logits with different logit biases for each sequence in the batch."""
+
+    def __init__(
+        self,
+        logit_biases: List[Optional[dict]],
+        tokenizer: PreTrainedTokenizerBase,
+        device: torch.device,
+    ):
+        self.device = device
+        self.tokenizer = tokenizer
+        self.logit_biases = logit_biases
+        self.batch_size = len(logit_biases)
+
+        # Pre-compute token IDs for each token string
+        self.token_id_mapping = {}
+
+        # Create a mapping of indices that have logit biases
+        self.indices_with_biases = {
+            i: bias_dict
+            for i, bias_dict in enumerate(self.logit_biases)
+            if bias_dict is not None and len(bias_dict) > 0
+        }
+
+    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
+        # If no indices have biases, return scores unchanged
+        if not self.indices_with_biases:
+            return scores
+
+        # For each index with a bias, apply the bias to the corresponding scores
+        for i, bias_dict in self.indices_with_biases.items():
+            for token_str, bias_value in bias_dict.items():
+                # Get token ID, either from cache or by computing it
+                if token_str not in self.token_id_mapping:
+                    if token_str.isdigit():
+                        # If the token string is already a numeric ID
+                        token_id = int(token_str)
+                    else:
+                        # Otherwise, use the tokenizer to get the ID
+                        tokens = self.tokenizer.encode(
+                            token_str, add_special_tokens=False
+                        )
+                        token_id = tokens[0] if tokens else -1  # Use -1 for not found
+
+                    self.token_id_mapping[token_str] = token_id
+
+                token_id = self.token_id_mapping[token_str]
+
+                # Apply bias if token ID is valid
+                if 0 <= token_id < scores.size(-1):
+                    scores[i, token_id] += bias_value
+
+        return scores
+
+    def filter(self, indices: List[int]):
+        """Keep only the logit biases for the specified indices."""
+        new_logit_biases = [self.logit_biases[i] for i in indices]
+        return HeterogeneousLogitBiasProcessor(
+            new_logit_biases, self.tokenizer, self.device
+        )
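
A small, self-contained sketch of how the processor added above behaves on a toy batch. Because the bias keys here are numeric token-ID strings, the tokenizer is never consulted, so a placeholder stands in for it; the batch shape and values are assumptions for illustration only:

import torch

# Two sequences over a 5-token vocabulary; bias token id 3 down by 100
# for the first sequence only. HeterogeneousLogitBiasProcessor is the
# class introduced in the hunk above.
processor = HeterogeneousLogitBiasProcessor(
    logit_biases=[{"3": -100}, None],
    tokenizer=None,  # not used when every key is a numeric token ID
    device=torch.device("cpu"),
)

input_ids = torch.zeros((2, 4), dtype=torch.long)
scores = torch.zeros((2, 5))

biased = processor(input_ids, scores)
assert biased[0, 3].item() == -100.0  # biased token pushed down
assert torch.all(biased[1] == 0)  # other sequence untouched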
@@ -15,6 +15,7 @@ from text_generation_server.utils.logits_process import (
     HeterogeneousTopPLogitsWarper,
     HeterogeneousTypicalLogitsWarper,
     HeterogeneousGrammarLogitProcessor,
+    HeterogeneousLogitBiasProcessor,
     static_warper,
 )
 from text_generation_server.utils.watermark import WatermarkLogitsProcessor
@@ -38,6 +39,7 @@ class NextTokenChooser:
         grammar: str = "",
         grammar_type: GrammarType = GrammarType.GRAMMAR_TYPE_NONE,
         fsm_grammar_state: int = 0,
+        logit_bias: Optional[dict] = None,
     ):
         self.watermark_processor = (
             WatermarkLogitsProcessor(device=device) if watermark else None
@@ -58,6 +60,7 @@ class NextTokenChooser:
             else None
         )
         self.tokenizer = tokenizer
+        self.logit_bias = logit_bias

         has_warpers = (
             (temperature is not None and temperature != 1.0)
@@ -87,6 +90,8 @@ class NextTokenChooser:
             scores = self.frequency_processor(input_ids, scores)
         if self.grammar_processor is not None:
             scores = self.grammar_processor(scores, self.fsm_grammar_state)
+        if self.logit_bias_processor is not None:
+            scores = self.logit_bias_processor(input_ids, scores)

         if self.static_warper is None:
             next_logprob = torch.log_softmax(scores, -1)
@@ -125,6 +130,7 @@ class NextTokenChooser:
             tokenizer=tokenizer,
             grammar=pb.grammar,
             grammar_type=pb.grammar_type,
+            logit_bias=dict(pb.logit_bias) if pb.logit_bias else None,
         )


@@ -248,9 +254,14 @@ class HeterogeneousNextTokenChooser:
         grammars: List[str],
         grammar_types: List[int],
         fsm_grammar_states=List[int],
+        logit_biases: List[Optional[dict]] = None,
     ):
         warpers = []
+
+        # Initialize with empty logit biases if none provided
+        if logit_biases is None:
+            logit_biases = [None] * len(do_sample)

         self.watermark_processor = (
             HeterogeneousProcessorWrapper(
                 {
@@ -287,6 +298,12 @@ class HeterogeneousNextTokenChooser:
             else None
         )
+
+        self.logit_bias_processor = (
+            HeterogeneousLogitBiasProcessor(logit_biases, tokenizer, device)
+            if any([bias is not None and len(bias) > 0 for bias in logit_biases])
+            else None
+        )

         if any(x != 1.0 for x in temperature):
             do_sample = [
                 sample or x != 1.0 for x, sample in zip(temperature, do_sample)
@@ -322,6 +339,7 @@ class HeterogeneousNextTokenChooser:
         self.fsm_grammar_states = fsm_grammar_states
         self.grammars = grammars
         self.grammar_types = grammar_types
+        self.logit_biases = logit_biases

     def __call__(
         self,
@@ -353,6 +371,8 @@ class HeterogeneousNextTokenChooser:
                 _scores = self.frequency_processor(input_ids, _scores)
             if self.grammar_processor is not None:
                 _scores = self.grammar_processor(_scores, self.fsm_grammar_states)
+            if self.logit_bias_processor is not None:
+                _scores = self.logit_bias_processor(input_ids, _scores)
             for warper in self.warpers:
                 _scores = warper(input_ids, _scores)
             _next_ids = self.choice(_scores)
@@ -444,6 +464,9 @@ class HeterogeneousNextTokenChooser:
         if self.grammar_processor is not None:
             self.grammar_processor = self.grammar_processor.filter(indices)

+        if self.logit_bias_processor is not None:
+            self.logit_bias_processor = self.logit_bias_processor.filter(indices)
+
         filtered_warpers = []
         for warper in self.warpers:
             filtered_warper = warper.filter(indices)
@@ -453,6 +476,7 @@ class HeterogeneousNextTokenChooser:

         self.seeds = [self.seeds[i] for i in indices]
         self.do_sample = [self.do_sample[i] for i in indices]
+        self.logit_biases = [self.logit_biases[i] for i in indices]

         new_grammars = []
         new_fsm_grammar_states = []
@@ -500,6 +524,9 @@ class HeterogeneousNextTokenChooser:
             fsm_grammar_states=(
                 fsm_grammar_states if fsm_grammar_states else [0] * len(pb)
             ),
+            logit_biases=[
+                dict(pb_.logit_bias) if pb_.logit_bias else None for pb_ in pb
+            ],
         )
