diff --git a/.github/workflows/client-tests.yaml b/.github/workflows/client-tests.yaml
deleted file mode 100644
index ff2928c4..00000000
--- a/.github/workflows/client-tests.yaml
+++ /dev/null
@@ -1,26 +0,0 @@
-name: Python Client Tests
-
-on:
-  pull_request:
-    paths:
-      - ".github/workflows/client-tests.yaml"
-      - "clients/python/**"
-
-jobs:
-  run_tests:
-    runs-on: ubuntu-latest
-
-    steps:
-      - uses: actions/checkout@v2
-      - name: Set up Python
-        uses: actions/setup-python@v1
-        with:
-          python-version: 3.9
-      - name: Install
-        run: |
-          cd clients/python && pip install .
-      - name: Run tests
-        run: |
-          pip install pytest pytest-asyncio
-          export HF_TOKEN=${{ secrets.HF_TOKEN }}
-          make python-client-tests
diff --git a/backends/client/src/v3/client.rs b/backends/client/src/v3/client.rs
index 968c1f45..b5dd081f 100644
--- a/backends/client/src/v3/client.rs
+++ b/backends/client/src/v3/client.rs
@@ -7,6 +7,7 @@ use grpc_metadata::InjectTelemetryContext;
 use pb::generate::v3::text_generation_service_client::TextGenerationServiceClient;
 use pb::generate::v3::*;
 use std::cmp::min;
+use std::collections::HashMap;
 use std::time::Duration;
 use tonic::transport::{Channel, Uri};
 use tracing::instrument;
@@ -181,6 +182,7 @@ impl Client {
                     watermark: true,
                     grammar: String::new(),
                     grammar_type: GrammarType::None as i32,
+                    logit_bias: HashMap::new(),
                 }),
                 stopping_parameters: Some(StoppingCriteriaParameters {
                     max_new_tokens,
diff --git a/backends/client/src/v3/sharded_client.rs b/backends/client/src/v3/sharded_client.rs
index dc3bcdde..99688668 100644
--- a/backends/client/src/v3/sharded_client.rs
+++ b/backends/client/src/v3/sharded_client.rs
@@ -5,6 +5,7 @@ use crate::{ClientError, Result};
 use crate::v3::{Chunk, InfoResponse, Input};
 use async_trait::async_trait;
 use futures::future::join_all;
+use std::collections::HashMap;
 use tonic::transport::Uri;
 use tracing::instrument;
 use v3::client::{DecodeTimings, PrefillTimings};
@@ -244,6 +245,7 @@ impl Health for ShardedClient {
                 watermark: false,
                 grammar: String::new(),
                 grammar_type: GrammarType::None as i32,
+                logit_bias: HashMap::new(),
             }),
             stopping_parameters: Some(StoppingCriteriaParameters {
                 max_new_tokens: 1,
diff --git a/backends/v2/src/queue.rs b/backends/v2/src/queue.rs
index c9a9335d..0e40fdf2 100644
--- a/backends/v2/src/queue.rs
+++ b/backends/v2/src/queue.rs
@@ -429,6 +429,7 @@ mod tests {
                 frequency_penalty: 0.0,
                 watermark: false,
                 grammar: None,
+                logit_bias: None,
             },
             stopping_parameters: ValidStoppingParameters {
                 ignore_eos_token: false,
diff --git a/backends/v3/src/client/grpc_client.rs b/backends/v3/src/client/grpc_client.rs
index f4942f64..9d5e042a 100644
--- a/backends/v3/src/client/grpc_client.rs
+++ b/backends/v3/src/client/grpc_client.rs
@@ -7,6 +7,7 @@ use grpc_metadata::InjectTelemetryContext;
 use pb::generate::v3::text_generation_service_client::TextGenerationServiceClient;
 use pb::generate::v3::*;
 use std::cmp::min;
+use std::collections::HashMap;
 use std::time::Duration;
 use tonic::transport::{Channel, Uri};
 use tracing::instrument;
@@ -181,6 +182,7 @@ impl Client {
                     watermark: true,
                     grammar: String::new(),
                     grammar_type: GrammarType::None as i32,
+                    logit_bias: HashMap::new(),
                 }),
                 stopping_parameters: Some(StoppingCriteriaParameters {
                     max_new_tokens,
diff --git a/backends/v3/src/client/sharded_client.rs b/backends/v3/src/client/sharded_client.rs
index 4701c560..b9667dc4 100644
--- a/backends/v3/src/client/sharded_client.rs
+++ b/backends/v3/src/client/sharded_client.rs
@@ -10,6 +10,7 @@ use crate::client::{
 use crate::client::{Chunk, InfoResponse, Input};
 use async_trait::async_trait;
 use futures::future::join_all;
+use std::collections::HashMap;
 use tonic::transport::Uri;
 use tracing::instrument;
@@ -232,6 +233,7 @@ impl Health for ShardedClient {
                 watermark: false,
                 grammar: String::new(),
                 grammar_type: GrammarType::None as i32,
+                logit_bias: HashMap::new(),
             }),
             stopping_parameters: Some(StoppingCriteriaParameters {
                 max_new_tokens: 1,
diff --git a/backends/v3/src/queue.rs b/backends/v3/src/queue.rs
index 8cfee3a5..4f0bbdc0 100644
--- a/backends/v3/src/queue.rs
+++ b/backends/v3/src/queue.rs
@@ -5,6 +5,7 @@ use crate::client::{
 };
 use nohash_hasher::{BuildNoHashHasher, IntMap};
 use std::cmp::max;
+use std::collections::HashMap;
 use std::collections::VecDeque;
 use text_generation_router::infer::InferError;
 use text_generation_router::infer::InferStreamResponse;
@@ -542,6 +543,14 @@ impl From<ValidParameters> for NextTokenChooserParameters {
             watermark: value.watermark,
             grammar,
             grammar_type: grammar_type.into(),
+            logit_bias: value
+                .logit_bias
+                .map(|bias| {
+                    bias.into_iter()
+                        .map(|(token, bias)| (token.to_string(), bias as i32))
+                        .collect::<HashMap<String, i32>>()
+                })
+                .unwrap_or_default(),
         }
     }
 }
@@ -588,6 +597,7 @@ mod tests {
                 frequency_penalty: 0.0,
                 watermark: false,
                 grammar: None,
+                logit_bias: None,
             },
             stopping_parameters: ValidStoppingParameters {
                 ignore_eos_token: false,
diff --git a/benchmark/src/lib.rs b/benchmark/src/lib.rs
index bb4b6a77..cd1f9446 100644
--- a/benchmark/src/lib.rs
+++ b/benchmark/src/lib.rs
@@ -47,6 +47,7 @@ pub async fn run(
         watermark,
         grammar: String::new(),
         grammar_type: GrammarType::None as i32,
+        logit_bias: std::collections::HashMap::new(),
     };

     // Initialize terminal properties
diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py
index 6f51c153..b9e00fae 100644
--- a/clients/python/text_generation/types.py
+++ b/clients/python/text_generation/types.py
@@ -1,6 +1,6 @@
 from enum import Enum
 from pydantic import BaseModel, field_validator, ConfigDict
-from typing import Optional, List, Union, Any
+from typing import Optional, List, Union, Any, Dict

 from text_generation.errors import ValidationError

@@ -137,7 +137,7 @@ class ChatRequest(BaseModel):
     # decreasing the model's likelihood to repeat the same line verbatim.
     frequency_penalty: Optional[float] = None
     # Bias values for token selection
-    logit_bias: Optional[List[float]] = None
+    logit_bias: Optional[Dict[str, int]] = None
     # Whether to return log probabilities
     logprobs: Optional[bool] = None
     # Number of most likely tokens to return at each position
diff --git a/docs/openapi.json b/docs/openapi.json
index ff63c3da..e309a847 100644
--- a/docs/openapi.json
+++ b/docs/openapi.json
@@ -995,12 +995,12 @@
           "nullable": true
         },
         "logit_bias": {
-          "type": "array",
-          "items": {
-            "type": "number",
-            "format": "float"
+          "type": "object",
+          "description": "Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens\n(specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,\nthe bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model,\nbut values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should\nresult in a ban or exclusive selection of the relevant token.",
+          "additionalProperties": {
+            "type": "integer",
+            "format": "int32"
           },
-          "description": "UNUSED\nModify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens\n(specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,\nthe bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model,\nbut values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should\nresult in a ban or exclusive selection of the relevant token.",
           "nullable": true
         },
         "logprobs": {
@@ -1589,6 +1589,17 @@
           "default": "null",
           "nullable": true
         },
+        "logit_bias": {
+          "type": "object",
+          "description": "Modify the likelihood of specified tokens appearing in the completion.\nAccepts a hash map that maps token strings to an associated bias value.",
+          "default": "null",
+          "additionalProperties": {
+            "type": "integer",
+            "format": "int32"
+          },
+          "example": "{\"1923\": 100, \"1924\": -100}",
+          "nullable": true
+        },
         "max_new_tokens": {
           "type": "integer",
           "format": "int32",
diff --git a/integration-tests/models/__snapshots__/test_flash_logit_bias/test_logit_bias_baseline.json b/integration-tests/models/__snapshots__/test_flash_logit_bias/test_logit_bias_baseline.json
new file mode 100644
index 00000000..b683b090
--- /dev/null
+++ b/integration-tests/models/__snapshots__/test_flash_logit_bias/test_logit_bias_baseline.json
@@ -0,0 +1,26 @@
+{
+  "choices": [
+    {
+      "finish_reason": "length",
+      "index": 0,
+      "logprobs": null,
+      "message": {
+        "content": "Hello! How can I help you today?",
+        "name": null,
+        "role": "assistant",
+        "tool_calls": null
+      },
+      "usage": null
+    }
+  ],
+  "created": 1745337495,
+  "id": "",
+  "model": "Qwen/Qwen2-VL-2B-Instruct",
+  "object": "chat.completion",
+  "system_fingerprint": "3.2.3-dev0-native",
+  "usage": {
+    "completion_tokens": 10,
+    "prompt_tokens": 21,
+    "total_tokens": 31
+  }
+}
diff --git a/integration-tests/models/__snapshots__/test_flash_logit_bias/test_logit_bias_english_to_spanish.json b/integration-tests/models/__snapshots__/test_flash_logit_bias/test_logit_bias_english_to_spanish.json
new file mode 100644
index 00000000..98a90c07
--- /dev/null
+++ b/integration-tests/models/__snapshots__/test_flash_logit_bias/test_logit_bias_english_to_spanish.json
@@ -0,0 +1,26 @@
+{
+  "choices": [
+    {
+      "finish_reason": "length",
+      "index": 0,
+      "logprobs": null,
+      "message": {
+        "content": "¡Hola! ¿Cómo puedo ayudarte?",
+        "name": null,
+        "role": "assistant",
+        "tool_calls": null
+      },
+      "usage": null
+    }
+  ],
+  "created": 1746486174,
+  "id": "",
+  "model": "Qwen/Qwen2-VL-2B-Instruct",
+  "object": "chat.completion",
+  "system_fingerprint": "3.2.3-dev0-native",
+  "usage": {
+    "completion_tokens": 10,
+    "prompt_tokens": 21,
+    "total_tokens": 31
+  }
+}
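Note: the updated `ChatRequest.logit_bias` schema above is easiest to read from the client side. A minimal sketch of the request shape it describes, sent to TGI's OpenAI-compatible route with plain `requests`; the server URL and token ID are placeholders (token IDs depend on the model's tokenizer), not values taken from this PR.

```python
# Sketch only: exercise the logit_bias field on /v1/chat/completions.
# http://localhost:8080 and token ID "9707" are illustrative assumptions.
import requests

resp = requests.post(
    "http://localhost:8080/v1/chat/completions",
    json={
        "model": "tgi",
        "messages": [{"role": "user", "content": "say Hello"}],
        # token ID (as a string) -> integer bias in [-100, 100]; -100 effectively bans the token
        "logit_bias": {"9707": -100},
        "max_tokens": 10,
        "seed": 42,
    },
)
print(resp.json()["choices"][0]["message"]["content"])
```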
diff --git a/integration-tests/models/__snapshots__/test_flash_logit_bias/test_logit_bias_multiple_tokens.json b/integration-tests/models/__snapshots__/test_flash_logit_bias/test_logit_bias_multiple_tokens.json
new file mode 100644
index 00000000..47b48462
--- /dev/null
+++ b/integration-tests/models/__snapshots__/test_flash_logit_bias/test_logit_bias_multiple_tokens.json
@@ -0,0 +1,26 @@
+{
+  "choices": [
+    {
+      "finish_reason": "stop",
+      "index": 0,
+      "logprobs": null,
+      "message": {
+        "content": "Chat!",
+        "name": null,
+        "role": "assistant",
+        "tool_calls": null
+      },
+      "usage": null
+    }
+  ],
+  "created": 1746486174,
+  "id": "",
+  "model": "Qwen/Qwen2-VL-2B-Instruct",
+  "object": "chat.completion",
+  "system_fingerprint": "3.2.3-dev0-native",
+  "usage": {
+    "completion_tokens": 3,
+    "prompt_tokens": 25,
+    "total_tokens": 28
+  }
+}
diff --git a/integration-tests/models/__snapshots__/test_flash_logit_bias/test_logit_bias_streaming.json b/integration-tests/models/__snapshots__/test_flash_logit_bias/test_logit_bias_streaming.json
new file mode 100644
index 00000000..c9288d1c
--- /dev/null
+++ b/integration-tests/models/__snapshots__/test_flash_logit_bias/test_logit_bias_streaming.json
@@ -0,0 +1,20 @@
+{
+  "choices": [
+    {
+      "delta": {
+        "content": "",
+        "role": "assistant",
+        "tool_calls": null
+      },
+      "finish_reason": "length",
+      "index": 0,
+      "logprobs": null
+    }
+  ],
+  "created": 1746486174,
+  "id": "",
+  "model": "Qwen/Qwen2-VL-2B-Instruct",
+  "object": "chat.completion.chunk",
+  "system_fingerprint": "3.2.3-dev0-native",
+  "usage": null
+}
diff --git a/integration-tests/models/test_flash_logit_bias.py b/integration-tests/models/test_flash_logit_bias.py
new file mode 100644
index 00000000..5ce793b7
--- /dev/null
+++ b/integration-tests/models/test_flash_logit_bias.py
@@ -0,0 +1,109 @@
+import pytest
+
+
+@pytest.fixture(scope="module")
+def logit_bias_model_handle(launcher):
+    with launcher("Qwen/Qwen2-VL-2B-Instruct") as handle:
+        yield handle
+
+
+@pytest.fixture(scope="module")
+async def logit_bias_model(logit_bias_model_handle):
+    await logit_bias_model_handle.health(300)
+    return logit_bias_model_handle.client
+
+
+@pytest.mark.private
+async def test_logit_bias_english_to_spanish(logit_bias_model, response_snapshot):
+    """Test that setting negative bias on English tokens forces output to be in Spanish"""
+    response = await logit_bias_model.chat(
+        seed=42,
+        max_tokens=10,
+        logit_bias={"9707": -100},  # Bias against 'Hello' token
+        messages=[
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": "say Hello"},
+                ],
+            },
+        ],
+    )
+    assert "¡Hola!" in response.choices[0].message.content
+    assert "Hello" not in response.choices[0].message.content
+    assert response == response_snapshot
+
+
+@pytest.mark.private
+async def test_logit_bias_baseline(logit_bias_model, response_snapshot):
+    """Test baseline behavior without logit bias for comparison"""
+    response = await logit_bias_model.chat(
+        seed=42,
+        max_tokens=10,
+        messages=[
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": "say Hello"},
+                ],
+            },
+        ],
+    )
+    assert "Hello" in response.choices[0].message.content
+    assert response == response_snapshot
+
+
+@pytest.mark.private
+async def test_logit_bias_multiple_tokens(logit_bias_model, response_snapshot):
+    """Test applying bias to multiple tokens simultaneously"""
+    response = await logit_bias_model.chat(
+        seed=42,
+        max_tokens=15,
+        logit_bias={
+            "9707": -100,  # Bias against 'Hello' token
+            "2880": -100,  # Bias against 'hi' token
+        },
+        messages=[
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": "Give me a one-word greeting"},
+                ],
+            },
+        ],
+    )
+    assert "hello" not in response.choices[0].message.content.lower()
+    assert "hi" not in response.choices[0].message.content.lower()
+    assert response == response_snapshot
+
+
+@pytest.mark.private
+async def test_logit_bias_streaming(logit_bias_model, response_snapshot):
+    """Test logit bias works correctly with streaming enabled"""
+    responses = await logit_bias_model.chat(
+        seed=42,
+        max_tokens=10,
+        logit_bias={"9707": -100},  # Bias against 'Hello' token
+        messages=[
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": "say Hello"},
+                ],
+            },
+        ],
+        stream=True,
+    )
+
+    count = 0
+    generated = ""
+    last_response = None
+
+    async for response in responses:
+        count += 1
+        generated += response.choices[0].delta.content
+        last_response = response
+
+    assert "¡Hola!" in generated
+    assert "Hello" not in generated
+    assert last_response == response_snapshot
diff --git a/proto/v3/generate.proto b/proto/v3/generate.proto
index 02980b6f..1018e709 100644
--- a/proto/v3/generate.proto
+++ b/proto/v3/generate.proto
@@ -104,6 +104,8 @@ message NextTokenChooserParameters {
   string grammar = 10;
   /// grammar type
   GrammarType grammar_type = 11;
+  /// logit bias dictionary mapping token string to bias value
+  map<string, int32> logit_bias = 12;
 }

 message StoppingCriteriaParameters {
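For orientation, the new `map<string, int32>` field surfaces on the Python shard as a plain dict-like protobuf map. A small sketch, assuming the v3 stubs are generated as `text_generation_server.pb.generate_pb2` (as the imports in `tokens.py` below suggest); the token ID is illustrative.

```python
# Illustrative only: the router's HashMap<String, i32> arrives as a proto3 map,
# which behaves like a dict on the Python side.
from text_generation_server.pb import generate_pb2

params = generate_pb2.NextTokenChooserParameters(
    logit_bias={"9707": -100},  # token ID (string) -> bias value
)
print(dict(params.logit_bias))  # {'9707': -100}
```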
diff --git a/router/src/lib.rs b/router/src/lib.rs
index e5622fc2..88ab86e9 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -18,6 +18,7 @@ use crate::infer::{Infer, InferError};
 use pyo3::prelude::*;
 use pyo3::types::IntoPyDict;
 use serde::{Deserialize, Serialize};
+use std::collections::HashMap;
 use tokenizers::Encoding;
 use tracing::warn;
 use utoipa::ToSchema;
@@ -431,6 +432,16 @@ pub(crate) struct GenerateParameters {
     #[serde(default)]
     #[schema(nullable = true, default = "null", example = "null")]
     pub adapter_id: Option<String>,
+
+    /// Modify the likelihood of specified tokens appearing in the completion.
+    /// Accepts a hash map that maps token strings to an associated bias value.
+    #[serde(default)]
+    #[schema(
+        nullable = true,
+        default = "null",
+        example = "{\"1923\": 100, \"1924\": -100}"
+    )]
+    pub logit_bias: Option<HashMap<String, i32>>,
 }

 fn default_parameters() -> GenerateParameters {
@@ -454,9 +465,9 @@ fn default_parameters() -> GenerateParameters {
         top_n_tokens: None,
         grammar: None,
         adapter_id: None,
+        logit_bias: None,
     }
 }
-
 #[derive(Clone, Deserialize, Serialize, ToSchema, Debug)]
 #[serde(try_from = "PromptDeserializer")]
 pub struct Prompt(pub Vec<String>);
@@ -841,14 +852,13 @@ pub(crate) struct ChatRequest {
     #[schema(example = "1.0")]
     pub frequency_penalty: Option<f32>,

-    /// UNUSED
     /// Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens
     /// (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,
     /// the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model,
     /// but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should
    /// result in a ban or exclusive selection of the relevant token.
     #[serde(default)]
-    pub logit_bias: Option<Vec<f32>>,
+    pub logit_bias: Option<HashMap<String, i32>>,

     /// Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each
     /// output token returned in the content of message.
@@ -954,6 +964,7 @@ impl ChatRequest {
             frequency_penalty,
             top_p,
             top_logprobs,
+            logit_bias,
             ..
         } = self;
@@ -1029,6 +1040,7 @@ impl ChatRequest {
                     top_n_tokens: top_logprobs,
                     grammar,
                     adapter_id: model.filter(|m| *m != "tgi"),
+                    logit_bias,
                 },
             },
             using_tools,
diff --git a/router/src/server.rs b/router/src/server.rs
index 5fbe0403..14bcb86a 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -798,6 +798,7 @@ pub(crate) async fn completions(
                 top_n_tokens: None,
                 grammar: None,
                 adapter_id: model.as_ref().filter(|m| *m != "tgi").map(String::from),
+                logit_bias: None,
             },
         })
         .collect();
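Because `GenerateParameters` now carries the field as well, the native `/generate` route accepts it too. A sketch of that payload shape; the URL and token IDs are placeholders (the IDs simply mirror the OpenAPI example above and are tokenizer-specific).

```python
# Sketch: same bias mechanism via the native /generate endpoint (GenerateParameters.logit_bias).
import requests

resp = requests.post(
    "http://localhost:8080/generate",  # assumed local TGI instance
    json={
        "inputs": "say Hello",
        "parameters": {
            "max_new_tokens": 10,
            # token ID (string) -> integer bias, as in the schema example
            "logit_bias": {"1923": 100, "1924": -100},
        },
    },
)
print(resp.json()["generated_text"])
```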
+ )) + })?; + + if token_id >= self.vocab_size { + return Err(ValidationError::LogitBiasInvalid(format!( + "Token ID {token_id} is out of range (0..{}).", + self.vocab_size - 1 + ))); + } + + Ok((token_id, bias_value as f32)) + }) + .collect::, _>>() + }) + // convert Option> to Result, E> to throw + // if any of the token IDs are invalid + .transpose()?; + let parameters = ValidParameters { temperature, repetition_penalty, @@ -420,6 +466,7 @@ impl Validation { seed, watermark, grammar, + logit_bias, }; let stopping_parameters = ValidStoppingParameters { max_new_tokens, @@ -902,6 +949,8 @@ pub struct ValidParameters { pub watermark: bool, /// / grammar (applied if not empty) pub grammar: Option, + /// / logit bias + pub logit_bias: Option>, } #[derive(Debug, Clone)] @@ -997,6 +1046,8 @@ pub enum ValidationError { FailedFetchImage(#[from] reqwest::Error), #[error("{0} modality is not supported")] UnsupportedModality(&'static str), + #[error("logit_bias is not valid: {0}")] + LogitBiasInvalid(String), } #[cfg(test)] diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index b76dbe68..4fed5092 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -725,6 +725,15 @@ class VlmCausalLM(FlashCausalLM): **kwargs, ) + if self.config.vocab_size != self.tokenizer.vocab_size: + logger.warning( + f"Tokenizer vocab size {self.tokenizer.vocab_size} does not match model vocab size {self.config.vocab_size}. Updating tokenizer vocab size." + ) + # TODO: HUGE HACK! This is a workaround to update the vocab size + # in the tokenizer. When the tokenizer is updated within the model + # the vocab size is not updated in the tokenizer. + self.tokenizer._vocab_size = self.config.vocab_size + @property def batch_type(self) -> Type[VlmCausalLMBatch]: return self.batch_class diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py index 64a285b9..36638c37 100644 --- a/server/text_generation_server/utils/logits_process.py +++ b/server/text_generation_server/utils/logits_process.py @@ -623,3 +623,85 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor): new_fsms.append(self.fsms[i]) self.fsms = new_fsms return self + + +class LogitBiasProcessor(LogitsProcessor): + """ + `LogitBiasProcessor` creates a bias tensor from a dictionary of token IDs and their + corresponding bias values. Bias are applied to the logits during each forward pass. + + Supports token IDs provided as strings (e.g., {"9707": -100}). 
+ """ + + def __init__( + self, + logit_biases: dict, + tokenizer: PreTrainedTokenizerBase, + device: torch.device, + ): + assert logit_biases, "LogitBiasProcessor requires non-empty logit_biases" + + # use _vocab_size or fallback to tokenizer.vocab_size if not available + self.vocab_size = getattr(tokenizer, "_vocab_size", tokenizer.vocab_size) + + # Convert keys to integers and values to a list + token_ids = torch.tensor( + [int(k) for k in logit_biases.keys()], dtype=torch.long + ) + bias_values = torch.tensor(list(logit_biases.values()), dtype=torch.float) + + # Create a tensor and directly copy bias values at the corresponding indices + self.bias_tensor = torch.zeros(self.vocab_size, dtype=torch.float) + self.bias_tensor.index_put_((token_ids,), bias_values, accumulate=True) + + def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: + scores.add_(self.bias_tensor.to(device=scores.device, dtype=scores.dtype)) + return scores + + +class HeterogeneousLogitBiasProcessor(LogitsProcessor): + """ + Process logits with different logit biases for each sequence in the batch. + """ + + def __init__( + self, + logit_biases: List[Optional[dict]], + tokenizer: PreTrainedTokenizerBase, + device: torch.device, + ): + assert logit_biases, "LogitBiasProcessor requires non-empty logit_biases" + + self.tokenizer = tokenizer + self.logit_biases = logit_biases + + # use _vocab_size or fallback to tokenizer.vocab_size if not available + self.vocab_size = getattr(tokenizer, "_vocab_size", tokenizer.vocab_size) + + # Create batch_size x vocab_size bias matrix + self.bias_matrix = torch.zeros( + (len(logit_biases), self.vocab_size), dtype=torch.float, device=device + ) + + # for each logit bias dictionary, convert keys to integers and values to a list + for i, logit_bias in enumerate(logit_biases): + token_ids = torch.tensor( + [int(k) for k in logit_bias.keys()], dtype=torch.long + ).to(device=device) + bias_values = torch.tensor(list(logit_bias.values()), dtype=torch.float).to( + device=device + ) + # Create a tensor and directly copy bias values at the corresponding indices + self.bias_matrix[i].index_put_((token_ids,), bias_values, accumulate=True) + + def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: + scores.add_(self.bias_matrix.to(device=scores.device, dtype=scores.dtype)) + return scores + + def filter(self, indices): + new_logit_biases = [self.logit_biases[i] for i in indices] + if not any(bias and len(bias) > 0 for bias in new_logit_biases): + return None + return HeterogeneousLogitBiasProcessor( + new_logit_biases, self.tokenizer, self.device + ) diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py index 9ab49665..e4a2698e 100644 --- a/server/text_generation_server/utils/tokens.py +++ b/server/text_generation_server/utils/tokens.py @@ -7,6 +7,7 @@ from text_generation_server.pb.generate_pb2 import FinishReason, GrammarType from text_generation_server.utils.logits_process import ( FrequencyPenaltyLogitsProcessor, GrammarLogitProcessor, + LogitBiasProcessor, HeterogeneousProcessorWrapper, HeterogeneousRepetitionPenaltyLogitsProcessor, HeterogeneousFrequencyPenaltyLogitsProcessor, @@ -15,6 +16,7 @@ from text_generation_server.utils.logits_process import ( HeterogeneousTopPLogitsWarper, HeterogeneousTypicalLogitsWarper, HeterogeneousGrammarLogitProcessor, + HeterogeneousLogitBiasProcessor, static_warper, ) from text_generation_server.utils.watermark import WatermarkLogitsProcessor @@ -38,6 
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index 9ab49665..e4a2698e 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -7,6 +7,7 @@ from text_generation_server.pb.generate_pb2 import FinishReason, GrammarType
 from text_generation_server.utils.logits_process import (
     FrequencyPenaltyLogitsProcessor,
     GrammarLogitProcessor,
+    LogitBiasProcessor,
     HeterogeneousProcessorWrapper,
     HeterogeneousRepetitionPenaltyLogitsProcessor,
     HeterogeneousFrequencyPenaltyLogitsProcessor,
@@ -15,6 +16,7 @@ from text_generation_server.utils.logits_process import (
     HeterogeneousTopPLogitsWarper,
     HeterogeneousTypicalLogitsWarper,
     HeterogeneousGrammarLogitProcessor,
+    HeterogeneousLogitBiasProcessor,
     static_warper,
 )
 from text_generation_server.utils.watermark import WatermarkLogitsProcessor
@@ -38,6 +40,7 @@ class NextTokenChooser:
         grammar: str = "",
         grammar_type: GrammarType = GrammarType.GRAMMAR_TYPE_NONE,
         fsm_grammar_state: int = 0,
+        logit_bias: Optional[dict] = None,
     ):
         self.watermark_processor = (
             WatermarkLogitsProcessor(device=device) if watermark else None
@@ -57,6 +60,11 @@ class NextTokenChooser:
             if grammar != ""
             else None
         )
+        self.logit_bias_processor = (
+            LogitBiasProcessor(logit_bias, tokenizer, device)
+            if logit_bias is not None and len(logit_bias) > 0
+            else None
+        )
         self.tokenizer = tokenizer

         has_warpers = (
@@ -87,6 +95,8 @@ class NextTokenChooser:
             scores = self.frequency_processor(input_ids, scores)
         if self.grammar_processor is not None:
             scores = self.grammar_processor(scores, self.fsm_grammar_state)
+        if self.logit_bias_processor is not None:
+            scores = self.logit_bias_processor(input_ids, scores)

         if self.static_warper is None:
             next_logprob = torch.log_softmax(scores, -1)
@@ -125,6 +135,7 @@ class NextTokenChooser:
             tokenizer=tokenizer,
             grammar=pb.grammar,
             grammar_type=pb.grammar_type,
+            logit_bias=pb.logit_bias,
         )
@@ -248,6 +259,7 @@ class HeterogeneousNextTokenChooser:
         grammars: List[str],
         grammar_types: List[int],
         fsm_grammar_states=List[int],
+        logit_biases: List[Optional[dict]] = None,
     ):
         warpers = []
@@ -287,6 +299,12 @@ class HeterogeneousNextTokenChooser:
             else None
         )

+        self.logit_bias_processor = (
+            HeterogeneousLogitBiasProcessor(logit_biases, tokenizer, device)
+            if any(logit_bias for logit_bias in logit_biases)
+            else None
+        )
+
         if any(x != 1.0 for x in temperature):
             do_sample = [
                 sample or x != 1.0 for x, sample in zip(temperature, do_sample)
@@ -322,6 +340,7 @@ class HeterogeneousNextTokenChooser:
         self.fsm_grammar_states = fsm_grammar_states
         self.grammars = grammars
         self.grammar_types = grammar_types
+        self.logit_biases = logit_biases

     def __call__(
         self,
@@ -353,6 +372,8 @@ class HeterogeneousNextTokenChooser:
                 _scores = self.frequency_processor(input_ids, _scores)
             if self.grammar_processor is not None:
                 _scores = self.grammar_processor(_scores, self.fsm_grammar_states)
+            if self.logit_bias_processor is not None:
+                _scores = self.logit_bias_processor(input_ids, _scores)
             for warper in self.warpers:
                 _scores = warper(input_ids, _scores)
             _next_ids = self.choice(_scores)
@@ -444,6 +465,9 @@ class HeterogeneousNextTokenChooser:
         if self.grammar_processor is not None:
             self.grammar_processor = self.grammar_processor.filter(indices)

+        if self.logit_bias_processor is not None:
+            self.logit_bias_processor = self.logit_bias_processor.filter(indices)
+
         filtered_warpers = []
         for warper in self.warpers:
             filtered_warper = warper.filter(indices)
@@ -453,6 +477,7 @@ class HeterogeneousNextTokenChooser:

         self.seeds = [self.seeds[i] for i in indices]
         self.do_sample = [self.do_sample[i] for i in indices]
+        self.logit_biases = [self.logit_biases[i] for i in indices]

         new_grammars = []
         new_fsm_grammar_states = []
@@ -500,6 +525,7 @@ class HeterogeneousNextTokenChooser:
             fsm_grammar_states=(
                 fsm_grammar_states if fsm_grammar_states else [0] * len(pb)
            ),
+            logit_biases=[pb_.logit_bias for pb_ in pb],
        )
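The streaming path exercised by test_logit_bias_streaming can also be reproduced from a standard client. A sketch using the OpenAI Python SDK against the OpenAI-compatible route; the base URL and token ID are placeholder assumptions, not part of this PR.

```python
# Sketch: stream a biased chat completion from a local TGI instance (assumed URL).
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8080/v1", api_key="-")

stream = client.chat.completions.create(
    model="tgi",
    messages=[{"role": "user", "content": "say Hello"}],
    logit_bias={"9707": -100},  # illustrative: bias against one token of this tokenizer
    max_tokens=10,
    seed=42,
    stream=True,
)

generated = ""
for chunk in stream:
    delta = chunk.choices[0].delta.content
    if delta:
        generated += delta
print(generated)
```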