mirror of https://github.com/huggingface/text-generation-inference.git

added logit_bias param to REST and GRPC

commit 20b05bc8ba (parent e58ad6dd66)

@@ -42,6 +42,7 @@ pub async fn run(
         seed: 0,
         repetition_penalty: repetition_penalty.unwrap_or(1.0),
         watermark,
+        logit_bias: vec![],
     };
 
     // Initialize terminal properties

@@ -66,6 +66,16 @@ message NextTokenChooserParameters {
     float repetition_penalty = 7;
     /// token watermarking using "A Watermark for Large Language Models"
     bool watermark = 8;
+    /// bias towards certain token sequences
+    repeated LogitBias logit_bias = 9;
+}
+
+/// a token sequence and its bias
+message LogitBias {
+    /// string to bias towards (may be more than one token)
+    string string = 1;
+    /// bias for the string
+    float bias = 2;
 }
 
 message StoppingCriteriaParameters {

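For orientation, prost turns these messages into plain Rust structs. The following is a minimal sketch of what the generated type and its construction might look like; it is not the actual generated code, which lives in the pb::generate::v1 module:

    // Sketch of the struct prost would generate from the LogitBias message above.
    #[derive(Clone, Debug, Default, PartialEq)]
    pub struct LogitBias {
        /// string to bias towards (may be more than one token)
        pub string: String,
        /// bias for the string
        pub bias: f32,
    }

    fn main() {
        // Bias the sequence "hello" upward; an empty vec (the default used at
        // the existing call sites below) means no bias is applied.
        let logit_bias = vec![LogitBias { string: "hello".to_string(), bias: 0.5 }];
        println!("{logit_bias:?}");
    }
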
@@ -124,6 +124,7 @@ impl Client {
             seed: 0,
             repetition_penalty: 1.2,
             watermark: true,
+            logit_bias: vec![],
         }),
         stopping_parameters: Some(StoppingCriteriaParameters {
             max_new_tokens: 2,

@@ -10,7 +10,7 @@ pub use pb::generate::v1::HealthResponse;
 pub use pb::generate::v1::InfoResponse as ShardInfo;
 pub use pb::generate::v1::{
     Batch, CachedBatch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters,
-    PrefillTokens, Request, StoppingCriteriaParameters,
+    PrefillTokens, Request, StoppingCriteriaParameters, LogitBias
 };
 pub use sharded_client::ShardedClient;
 use thiserror::Error;

@@ -44,6 +44,7 @@ impl Health {
             seed: 0,
             repetition_penalty: 1.0,
             watermark: false,
+            logit_bias: vec![],
         }),
         stopping_parameters: Some(StoppingCriteriaParameters {
             max_new_tokens: 1,

@@ -5,6 +5,7 @@ mod queue;
 pub mod server;
 mod validation;
 
+use std::collections::BTreeMap;
 use infer::Infer;
 use queue::{Entry, Queue};
 use serde::{Deserialize, Serialize};

@@ -113,7 +114,7 @@ pub(crate) struct GenerateParameters {
     #[schema(nullable = true, default = "null", example = false)]
     pub return_full_text: Option<bool>,
     #[serde(default)]
-    #[schema(inline, max_items = 4, example = json ! (["photographer"]))]
+    #[schema(inline, max_items = 4, example = json!(["photographer"]))]
     pub stop: Vec<String>,
     #[serde(default)]
     #[schema(nullable = true, default = "null", example = "null")]

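(The hunk above is purely cosmetic: `json ! (["photographer"])` and `json!(["photographer"])` are the same macro invocation to the Rust parser; the edit only normalizes spacing.)
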
@@ -135,6 +136,9 @@ pub(crate) struct GenerateParameters {
         example = "null"
     )]
     pub seed: Option<u64>,
+    #[serde(default)]
+    #[schema(default=json!({}), example=json!({"hello": 0.5}))]
+    pub logit_bias: BTreeMap<String, f32>
 }
 
 fn default_max_new_tokens() -> u32 {

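With the new field in place, a REST caller supplies biases as a JSON object mapping strings to weights. A hedged sketch of a request body follows; the prompt and the other parameter values are illustrative, not from the commit:

    use serde_json::json;

    fn main() {
        // Illustrative request body for the REST API; aside from `logit_bias`,
        // fields and values are placeholders.
        let body = json!({
            "inputs": "The capital of France is",
            "parameters": {
                "max_new_tokens": 20,
                // map of string (possibly more than one token) -> bias,
                // matching the schema example above
                "logit_bias": { "hello": 0.5 }
            }
        });
        println!("{body}");
    }
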
@@ -158,6 +162,7 @@ fn default_parameters() -> GenerateParameters {
         details: false,
         decoder_input_details: false,
         seed: None,
+        logit_bias: BTreeMap::new(),
     }
 }
 
@@ -322,6 +322,7 @@ mod tests {
             seed: 0,
             repetition_penalty: 0.0,
             watermark: false,
+            logit_bias: vec![],
         },
         stopping_parameters: StoppingCriteriaParameters {
             ignore_eos_token: false,

@@ -2,7 +2,7 @@
 use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
 use crate::{GenerateParameters, GenerateRequest};
 use rand::{thread_rng, Rng};
-use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
+use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters, LogitBias};
 use thiserror::Error;
 use tokenizers::tokenizer::Tokenizer;
 use tokenizers::TruncationDirection;

@@ -142,6 +142,7 @@ impl Validation {
             seed,
             watermark,
             decoder_input_details,
+            logit_bias,
             ..
         } = request.parameters;
 
@@ -238,6 +239,9 @@ impl Validation {
             .validate_input(request.inputs, truncate, max_new_tokens)
             .await?;
 
+        // transform logit_bias (received as a map) into a vector of LogitBias
+        let logit_bias = logit_bias.into_iter().map(|(string, bias)| LogitBias { string, bias, }).collect();
+
         let parameters = NextTokenChooserParameters {
             temperature,
             repetition_penalty,

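The only shape change between the two representations happens in the hunk above: the REST map becomes the gRPC repeated field. A self-contained sketch of the same transform, with a local LogitBias standing in for the generated proto type:

    use std::collections::BTreeMap;

    #[derive(Debug, PartialEq)]
    struct LogitBias {
        string: String,
        bias: f32,
    }

    fn main() {
        // REST side: a map keyed by string, as deserialized into GenerateParameters.
        let map = BTreeMap::from([("hello".to_string(), 0.5_f32)]);
        // gRPC side: a repeated LogitBias field, i.e. a vector of messages.
        let logit_bias: Vec<LogitBias> = map
            .into_iter()
            .map(|(string, bias)| LogitBias { string, bias })
            .collect();
        assert_eq!(logit_bias, vec![LogitBias { string: "hello".into(), bias: 0.5 }]);
    }
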
@@ -247,6 +251,7 @@ impl Validation {
             do_sample,
             seed,
             watermark,
+            logit_bias,
         };
         let stopping_parameters = StoppingCriteriaParameters {
             max_new_tokens,