mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-07 18:02:07 +00:00
feat: support logit bias in chat request
This commit is contained in:
parent
7253be349a
commit
fae510b8f6
@ -7,6 +7,7 @@ use grpc_metadata::InjectTelemetryContext;
|
||||
use pb::generate::v3::text_generation_service_client::TextGenerationServiceClient;
|
||||
use pb::generate::v3::*;
|
||||
use std::cmp::min;
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
use tonic::transport::{Channel, Uri};
|
||||
use tracing::instrument;
|
||||
@ -181,6 +182,7 @@ impl Client {
|
||||
watermark: true,
|
||||
grammar: String::new(),
|
||||
grammar_type: GrammarType::None as i32,
|
||||
logit_bias: HashMap::new(),
|
||||
}),
|
||||
stopping_parameters: Some(StoppingCriteriaParameters {
|
||||
max_new_tokens,
|
||||
|
@ -5,6 +5,7 @@ use crate::{ClientError, Result};
|
||||
use crate::v3::{Chunk, InfoResponse, Input};
|
||||
use async_trait::async_trait;
|
||||
use futures::future::join_all;
|
||||
use std::collections::HashMap;
|
||||
use tonic::transport::Uri;
|
||||
use tracing::instrument;
|
||||
use v3::client::{DecodeTimings, PrefillTimings};
|
||||
@ -244,6 +245,7 @@ impl Health for ShardedClient {
|
||||
watermark: false,
|
||||
grammar: String::new(),
|
||||
grammar_type: GrammarType::None as i32,
|
||||
logit_bias: HashMap::new(),
|
||||
}),
|
||||
stopping_parameters: Some(StoppingCriteriaParameters {
|
||||
max_new_tokens: 1,
|
||||
|
@ -7,6 +7,7 @@ use grpc_metadata::InjectTelemetryContext;
|
||||
use pb::generate::v3::text_generation_service_client::TextGenerationServiceClient;
|
||||
use pb::generate::v3::*;
|
||||
use std::cmp::min;
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
use tonic::transport::{Channel, Uri};
|
||||
use tracing::instrument;
|
||||
@ -181,6 +182,7 @@ impl Client {
|
||||
watermark: true,
|
||||
grammar: String::new(),
|
||||
grammar_type: GrammarType::None as i32,
|
||||
logit_bias: HashMap::new(),
|
||||
}),
|
||||
stopping_parameters: Some(StoppingCriteriaParameters {
|
||||
max_new_tokens,
|
||||
|
@ -10,6 +10,7 @@ use crate::client::{
|
||||
use crate::client::{Chunk, InfoResponse, Input};
|
||||
use async_trait::async_trait;
|
||||
use futures::future::join_all;
|
||||
use std::collections::HashMap;
|
||||
use tonic::transport::Uri;
|
||||
use tracing::instrument;
|
||||
|
||||
@ -232,6 +233,7 @@ impl Health for ShardedClient {
|
||||
watermark: false,
|
||||
grammar: String::new(),
|
||||
grammar_type: GrammarType::None as i32,
|
||||
logit_bias: HashMap::new(),
|
||||
}),
|
||||
stopping_parameters: Some(StoppingCriteriaParameters {
|
||||
max_new_tokens: 1,
|
||||
|
@ -5,6 +5,7 @@ use crate::client::{
|
||||
};
|
||||
use nohash_hasher::{BuildNoHashHasher, IntMap};
|
||||
use std::cmp::max;
|
||||
use std::collections::HashMap;
|
||||
use std::collections::VecDeque;
|
||||
use text_generation_router::infer::InferError;
|
||||
use text_generation_router::infer::InferStreamResponse;
|
||||
@ -522,6 +523,14 @@ impl From<ValidParameters> for NextTokenChooserParameters {
|
||||
watermark: value.watermark,
|
||||
grammar,
|
||||
grammar_type: grammar_type.into(),
|
||||
logit_bias: value
|
||||
.logit_bias
|
||||
.map(|bias| {
|
||||
bias.into_iter()
|
||||
.map(|(token, bias)| (token.to_string(), bias as i32))
|
||||
.collect::<HashMap<String, i32>>()
|
||||
})
|
||||
.unwrap_or_default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -47,6 +47,7 @@ pub async fn run(
|
||||
watermark,
|
||||
grammar: String::new(),
|
||||
grammar_type: GrammarType::None as i32,
|
||||
logit_bias: std::collections::HashMap::new(),
|
||||
};
|
||||
|
||||
// Initialize terminal properties
|
||||
|
@ -1,6 +1,6 @@
|
||||
from enum import Enum
|
||||
from pydantic import BaseModel, field_validator, ConfigDict
|
||||
from typing import Optional, List, Union, Any
|
||||
from typing import Optional, List, Union, Any, Dict
|
||||
|
||||
from text_generation.errors import ValidationError
|
||||
|
||||
@ -137,7 +137,7 @@ class ChatRequest(BaseModel):
|
||||
# decreasing the model's likelihood to repeat the same line verbatim.
|
||||
frequency_penalty: Optional[float] = None
|
||||
# Bias values for token selection
|
||||
logit_bias: Optional[List[float]] = None
|
||||
logit_bias: Optional[Dict[str, int]] = None
|
||||
# Whether to return log probabilities
|
||||
logprobs: Optional[bool] = None
|
||||
# Number of most likely tokens to return at each position
|
||||
|
@ -995,12 +995,12 @@
|
||||
"nullable": true
|
||||
},
|
||||
"logit_bias": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "number",
|
||||
"format": "float"
|
||||
"type": "object",
|
||||
"description": "Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens\n(specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,\nthe bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model,\nbut values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should\nresult in a ban or exclusive selection of the relevant token.",
|
||||
"additionalProperties": {
|
||||
"type": "integer",
|
||||
"format": "int32"
|
||||
},
|
||||
"description": "UNUSED\nModify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens\n(specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,\nthe bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model,\nbut values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should\nresult in a ban or exclusive selection of the relevant token.",
|
||||
"nullable": true
|
||||
},
|
||||
"logprobs": {
|
||||
@ -1589,6 +1589,17 @@
|
||||
"default": "null",
|
||||
"nullable": true
|
||||
},
|
||||
"logit_bias": {
|
||||
"type": "object",
|
||||
"description": "Modify the likelihood of specified tokens appearing in the completion.\nAccepts a hash map that maps token strings to an associated bias value.",
|
||||
"default": "null",
|
||||
"additionalProperties": {
|
||||
"type": "integer",
|
||||
"format": "int32"
|
||||
},
|
||||
"example": "{\"1923\": 100, \"1924\": -100}",
|
||||
"nullable": true
|
||||
},
|
||||
"max_new_tokens": {
|
||||
"type": "integer",
|
||||
"format": "int32",
|
||||
|
@ -0,0 +1,26 @@
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "length",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"message": {
|
||||
"content": "Hello! How can I help you today?",
|
||||
"name": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": null
|
||||
},
|
||||
"usage": null
|
||||
}
|
||||
],
|
||||
"created": 1745337495,
|
||||
"id": "",
|
||||
"model": "Qwen/Qwen2-VL-2B-Instruct",
|
||||
"object": "chat.completion",
|
||||
"system_fingerprint": "3.2.3-dev0-native",
|
||||
"usage": {
|
||||
"completion_tokens": 10,
|
||||
"prompt_tokens": 21,
|
||||
"total_tokens": 31
|
||||
}
|
||||
}
|
@ -0,0 +1,26 @@
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "length",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"message": {
|
||||
"content": "¡Hola! ¿Cómo puedo ayudarte?",
|
||||
"name": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": null
|
||||
},
|
||||
"usage": null
|
||||
}
|
||||
],
|
||||
"created": 1745337456,
|
||||
"id": "",
|
||||
"model": "Qwen/Qwen2-VL-2B-Instruct",
|
||||
"object": "chat.completion",
|
||||
"system_fingerprint": "3.2.3-dev0-native",
|
||||
"usage": {
|
||||
"completion_tokens": 10,
|
||||
"prompt_tokens": 21,
|
||||
"total_tokens": 31
|
||||
}
|
||||
}
|
@ -0,0 +1,26 @@
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "stop",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"message": {
|
||||
"content": "Chat!",
|
||||
"name": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": null
|
||||
},
|
||||
"usage": null
|
||||
}
|
||||
],
|
||||
"created": 1745337878,
|
||||
"id": "",
|
||||
"model": "Qwen/Qwen2-VL-2B-Instruct",
|
||||
"object": "chat.completion",
|
||||
"system_fingerprint": "3.2.3-dev0-native",
|
||||
"usage": {
|
||||
"completion_tokens": 3,
|
||||
"prompt_tokens": 25,
|
||||
"total_tokens": 28
|
||||
}
|
||||
}
|
@ -0,0 +1,20 @@
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"content": "",
|
||||
"role": "assistant",
|
||||
"tool_calls": null
|
||||
},
|
||||
"finish_reason": "length",
|
||||
"index": 0,
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1745337495,
|
||||
"id": "",
|
||||
"model": "Qwen/Qwen2-VL-2B-Instruct",
|
||||
"object": "chat.completion.chunk",
|
||||
"system_fingerprint": "3.2.3-dev0-native",
|
||||
"usage": null
|
||||
}
|
109
integration-tests/models/test_flash_logit_bias.py
Normal file
109
integration-tests/models/test_flash_logit_bias.py
Normal file
@ -0,0 +1,109 @@
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def logit_bias_model_handle(launcher):
|
||||
with launcher("Qwen/Qwen2-VL-2B-Instruct") as handle:
|
||||
yield handle
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
async def logit_bias_model(logit_bias_model_handle):
|
||||
await logit_bias_model_handle.health(300)
|
||||
return logit_bias_model_handle.client
|
||||
|
||||
|
||||
@pytest.mark.private
|
||||
async def test_logit_bias_english_to_spanish(logit_bias_model, response_snapshot):
|
||||
"""Test that setting negative bias on English tokens forces output to be in Spanish"""
|
||||
response = await logit_bias_model.chat(
|
||||
seed=42,
|
||||
max_tokens=10,
|
||||
logit_bias={"9707": -100}, # Bias against 'Hello' token
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "say Hello"},
|
||||
],
|
||||
},
|
||||
],
|
||||
)
|
||||
assert "¡Hola!" in response.choices[0].message.content
|
||||
assert "Hello" not in response.choices[0].message.content
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.private
|
||||
async def test_logit_bias_baseline(logit_bias_model, response_snapshot):
|
||||
"""Test baseline behavior without logit bias for comparison"""
|
||||
response = await logit_bias_model.chat(
|
||||
seed=42,
|
||||
max_tokens=10,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "say Hello"},
|
||||
],
|
||||
},
|
||||
],
|
||||
)
|
||||
assert "Hello" in response.choices[0].message.content
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.private
|
||||
async def test_logit_bias_multiple_tokens(logit_bias_model, response_snapshot):
|
||||
"""Test applying bias to multiple tokens simultaneously"""
|
||||
response = await logit_bias_model.chat(
|
||||
seed=42,
|
||||
max_tokens=15,
|
||||
logit_bias={
|
||||
"9707": -100, # Bias against 'Hello' token
|
||||
"2880": -100, # Bias against 'hi' token
|
||||
},
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Give me a one-word greeting"},
|
||||
],
|
||||
},
|
||||
],
|
||||
)
|
||||
assert "Hello" not in response.choices[0].message.content.lower()
|
||||
assert "hi" not in response.choices[0].message.content.lower()
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.private
|
||||
async def test_logit_bias_streaming(logit_bias_model, response_snapshot):
|
||||
"""Test logit bias works correctly with streaming enabled"""
|
||||
responses = await logit_bias_model.chat(
|
||||
seed=42,
|
||||
max_tokens=10,
|
||||
logit_bias={"9707": -100}, # Bias against 'Hello' token
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "say Hello"},
|
||||
],
|
||||
},
|
||||
],
|
||||
stream=True,
|
||||
)
|
||||
|
||||
count = 0
|
||||
generated = ""
|
||||
last_response = None
|
||||
|
||||
async for response in responses:
|
||||
count += 1
|
||||
generated += response.choices[0].delta.content
|
||||
last_response = response
|
||||
|
||||
assert "¡Hola!" in generated
|
||||
assert "Hello" not in generated
|
||||
assert last_response == response_snapshot
|
@ -18,6 +18,7 @@ use crate::infer::{Infer, InferError};
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::types::IntoPyDict;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use tokenizers::Encoding;
|
||||
use tracing::warn;
|
||||
use utoipa::ToSchema;
|
||||
@ -431,6 +432,16 @@ pub(crate) struct GenerateParameters {
|
||||
#[serde(default)]
|
||||
#[schema(nullable = true, default = "null", example = "null")]
|
||||
pub adapter_id: Option<String>,
|
||||
|
||||
/// Modify the likelihood of specified tokens appearing in the completion.
|
||||
/// Accepts a hash map that maps token strings to an associated bias value.
|
||||
#[serde(default)]
|
||||
#[schema(
|
||||
nullable = true,
|
||||
default = "null",
|
||||
example = "{\"1923\": 100, \"1924\": -100}"
|
||||
)]
|
||||
pub logit_bias: Option<HashMap<String, i32>>,
|
||||
}
|
||||
|
||||
fn default_parameters() -> GenerateParameters {
|
||||
@ -454,9 +465,9 @@ fn default_parameters() -> GenerateParameters {
|
||||
top_n_tokens: None,
|
||||
grammar: None,
|
||||
adapter_id: None,
|
||||
logit_bias: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)]
|
||||
#[serde(try_from = "PromptDeserializer")]
|
||||
pub struct Prompt(pub Vec<String>);
|
||||
@ -841,14 +852,13 @@ pub(crate) struct ChatRequest {
|
||||
#[schema(example = "1.0")]
|
||||
pub frequency_penalty: Option<f32>,
|
||||
|
||||
/// UNUSED
|
||||
/// Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens
|
||||
/// (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,
|
||||
/// the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model,
|
||||
/// but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should
|
||||
/// result in a ban or exclusive selection of the relevant token.
|
||||
#[serde(default)]
|
||||
pub logit_bias: Option<Vec<f32>>,
|
||||
pub logit_bias: Option<HashMap<String, i32>>,
|
||||
|
||||
/// Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each
|
||||
/// output token returned in the content of message.
|
||||
@ -954,6 +964,7 @@ impl ChatRequest {
|
||||
frequency_penalty,
|
||||
top_p,
|
||||
top_logprobs,
|
||||
logit_bias,
|
||||
..
|
||||
} = self;
|
||||
|
||||
@ -1029,6 +1040,7 @@ impl ChatRequest {
|
||||
top_n_tokens: top_logprobs,
|
||||
grammar,
|
||||
adapter_id: model.filter(|m| *m != "tgi"),
|
||||
logit_bias,
|
||||
},
|
||||
},
|
||||
using_tools,
|
||||
|
@ -798,6 +798,7 @@ pub(crate) async fn completions(
|
||||
top_n_tokens: None,
|
||||
grammar: None,
|
||||
adapter_id: model.as_ref().filter(|m| *m != "tgi").map(String::from),
|
||||
logit_bias: None,
|
||||
},
|
||||
})
|
||||
.collect();
|
||||
@ -1191,6 +1192,7 @@ pub(crate) async fn chat_completions(
|
||||
let (generate_request, using_tools): (GenerateRequest, bool) =
|
||||
chat.clone().try_into_generate(&infer)?;
|
||||
span.record("parameters", format!("{:?}", generate_request.parameters));
|
||||
println!("ChatRequest: {:#?}", generate_request);
|
||||
let logprobs = logprobs.unwrap_or_default();
|
||||
|
||||
// extract model id from request if specified
|
||||
|
@ -420,6 +420,18 @@ impl Validation {
|
||||
seed,
|
||||
watermark,
|
||||
grammar,
|
||||
logit_bias: Some(
|
||||
request
|
||||
.parameters
|
||||
.logit_bias
|
||||
.iter()
|
||||
.flat_map(|bias| {
|
||||
bias.iter()
|
||||
.map(|(k, v)| (k.parse::<u32>().unwrap(), *v as f32))
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
.collect(),
|
||||
),
|
||||
};
|
||||
let stopping_parameters = ValidStoppingParameters {
|
||||
max_new_tokens,
|
||||
@ -902,6 +914,8 @@ pub struct ValidParameters {
|
||||
pub watermark: bool,
|
||||
/// / grammar (applied if not empty)
|
||||
pub grammar: Option<ValidGrammar>,
|
||||
/// / logit bias
|
||||
pub logit_bias: Option<Vec<(u32, f32)>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
|
@ -623,3 +623,65 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
|
||||
new_fsms.append(self.fsms[i])
|
||||
self.fsms = new_fsms
|
||||
return self
|
||||
|
||||
|
||||
class HeterogeneousLogitBiasProcessor:
|
||||
"""Process logits with different logit biases for each sequence in the batch."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
logit_biases: List[Optional[dict]],
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
device: torch.device,
|
||||
):
|
||||
self.device = device
|
||||
self.tokenizer = tokenizer
|
||||
self.logit_biases = logit_biases
|
||||
self.batch_size = len(logit_biases)
|
||||
|
||||
# Pre-compute token IDs for each token string
|
||||
self.token_id_mapping = {}
|
||||
|
||||
# Create a mapping of indices that have logit biases
|
||||
self.indices_with_biases = {
|
||||
i: bias_dict
|
||||
for i, bias_dict in enumerate(self.logit_biases)
|
||||
if bias_dict is not None and len(bias_dict) > 0
|
||||
}
|
||||
|
||||
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
|
||||
# If no indices have biases, return scores unchanged
|
||||
if not self.indices_with_biases:
|
||||
return scores
|
||||
|
||||
# For each index with a bias, apply the bias to the corresponding scores
|
||||
for i, bias_dict in self.indices_with_biases.items():
|
||||
for token_str, bias_value in bias_dict.items():
|
||||
# Get token ID, either from cache or by computing it
|
||||
if token_str not in self.token_id_mapping:
|
||||
if token_str.isdigit():
|
||||
# If the token string is already a numeric ID
|
||||
token_id = int(token_str)
|
||||
else:
|
||||
# Otherwise, use the tokenizer to get the ID
|
||||
tokens = self.tokenizer.encode(
|
||||
token_str, add_special_tokens=False
|
||||
)
|
||||
token_id = tokens[0] if tokens else -1 # Use -1 for not found
|
||||
|
||||
self.token_id_mapping[token_str] = token_id
|
||||
|
||||
token_id = self.token_id_mapping[token_str]
|
||||
|
||||
# Apply bias if token ID is valid
|
||||
if 0 <= token_id < scores.size(-1):
|
||||
scores[i, token_id] += bias_value
|
||||
|
||||
return scores
|
||||
|
||||
def filter(self, indices: List[int]):
|
||||
"""Keep only the logit biases for the specified indices."""
|
||||
new_logit_biases = [self.logit_biases[i] for i in indices]
|
||||
return HeterogeneousLogitBiasProcessor(
|
||||
new_logit_biases, self.tokenizer, self.device
|
||||
)
|
||||
|
@ -15,6 +15,7 @@ from text_generation_server.utils.logits_process import (
|
||||
HeterogeneousTopPLogitsWarper,
|
||||
HeterogeneousTypicalLogitsWarper,
|
||||
HeterogeneousGrammarLogitProcessor,
|
||||
HeterogeneousLogitBiasProcessor,
|
||||
static_warper,
|
||||
)
|
||||
from text_generation_server.utils.watermark import WatermarkLogitsProcessor
|
||||
@ -38,6 +39,7 @@ class NextTokenChooser:
|
||||
grammar: str = "",
|
||||
grammar_type: GrammarType = GrammarType.GRAMMAR_TYPE_NONE,
|
||||
fsm_grammar_state: int = 0,
|
||||
logit_bias: Optional[dict] = None,
|
||||
):
|
||||
self.watermark_processor = (
|
||||
WatermarkLogitsProcessor(device=device) if watermark else None
|
||||
@ -58,6 +60,7 @@ class NextTokenChooser:
|
||||
else None
|
||||
)
|
||||
self.tokenizer = tokenizer
|
||||
self.logit_bias = logit_bias
|
||||
|
||||
has_warpers = (
|
||||
(temperature is not None and temperature != 1.0)
|
||||
@ -87,6 +90,8 @@ class NextTokenChooser:
|
||||
scores = self.frequency_processor(input_ids, scores)
|
||||
if self.grammar_processor is not None:
|
||||
scores = self.grammar_processor(scores, self.fsm_grammar_state)
|
||||
if self.logit_bias_processor is not None:
|
||||
scores = self.logit_bias_processor(input_ids, scores)
|
||||
|
||||
if self.static_warper is None:
|
||||
next_logprob = torch.log_softmax(scores, -1)
|
||||
@ -125,6 +130,7 @@ class NextTokenChooser:
|
||||
tokenizer=tokenizer,
|
||||
grammar=pb.grammar,
|
||||
grammar_type=pb.grammar_type,
|
||||
logit_bias=dict(pb.logit_bias) if pb.logit_bias else None,
|
||||
)
|
||||
|
||||
|
||||
@ -248,9 +254,14 @@ class HeterogeneousNextTokenChooser:
|
||||
grammars: List[str],
|
||||
grammar_types: List[int],
|
||||
fsm_grammar_states=List[int],
|
||||
logit_biases: List[Optional[dict]] = None,
|
||||
):
|
||||
warpers = []
|
||||
|
||||
# Initialize with empty logit biases if none provided
|
||||
if logit_biases is None:
|
||||
logit_biases = [None] * len(do_sample)
|
||||
|
||||
self.watermark_processor = (
|
||||
HeterogeneousProcessorWrapper(
|
||||
{
|
||||
@ -287,6 +298,12 @@ class HeterogeneousNextTokenChooser:
|
||||
else None
|
||||
)
|
||||
|
||||
self.logit_bias_processor = (
|
||||
HeterogeneousLogitBiasProcessor(logit_biases, tokenizer, device)
|
||||
if any([bias is not None and len(bias) > 0 for bias in logit_biases])
|
||||
else None
|
||||
)
|
||||
|
||||
if any(x != 1.0 for x in temperature):
|
||||
do_sample = [
|
||||
sample or x != 1.0 for x, sample in zip(temperature, do_sample)
|
||||
@ -322,6 +339,7 @@ class HeterogeneousNextTokenChooser:
|
||||
self.fsm_grammar_states = fsm_grammar_states
|
||||
self.grammars = grammars
|
||||
self.grammar_types = grammar_types
|
||||
self.logit_biases = logit_biases
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@ -353,6 +371,8 @@ class HeterogeneousNextTokenChooser:
|
||||
_scores = self.frequency_processor(input_ids, _scores)
|
||||
if self.grammar_processor is not None:
|
||||
_scores = self.grammar_processor(_scores, self.fsm_grammar_states)
|
||||
if self.logit_bias_processor is not None:
|
||||
_scores = self.logit_bias_processor(input_ids, _scores)
|
||||
for warper in self.warpers:
|
||||
_scores = warper(input_ids, _scores)
|
||||
_next_ids = self.choice(_scores)
|
||||
@ -444,6 +464,9 @@ class HeterogeneousNextTokenChooser:
|
||||
if self.grammar_processor is not None:
|
||||
self.grammar_processor = self.grammar_processor.filter(indices)
|
||||
|
||||
if self.logit_bias_processor is not None:
|
||||
self.logit_bias_processor = self.logit_bias_processor.filter(indices)
|
||||
|
||||
filtered_warpers = []
|
||||
for warper in self.warpers:
|
||||
filtered_warper = warper.filter(indices)
|
||||
@ -453,6 +476,7 @@ class HeterogeneousNextTokenChooser:
|
||||
|
||||
self.seeds = [self.seeds[i] for i in indices]
|
||||
self.do_sample = [self.do_sample[i] for i in indices]
|
||||
self.logit_biases = [self.logit_biases[i] for i in indices]
|
||||
|
||||
new_grammars = []
|
||||
new_fsm_grammar_states = []
|
||||
@ -500,6 +524,9 @@ class HeterogeneousNextTokenChooser:
|
||||
fsm_grammar_states=(
|
||||
fsm_grammar_states if fsm_grammar_states else [0] * len(pb)
|
||||
),
|
||||
logit_biases=[
|
||||
dict(pb_.logit_bias) if pb_.logit_bias else None for pb_ in pb
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user