mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00
added logit_bias param to REST and GRPC
This commit is contained in:
parent
e58ad6dd66
commit
20b05bc8ba
@ -42,6 +42,7 @@ pub async fn run(
|
||||
seed: 0,
|
||||
repetition_penalty: repetition_penalty.unwrap_or(1.0),
|
||||
watermark,
|
||||
logit_bias: vec![],
|
||||
};
|
||||
|
||||
// Initialize terminal properties
|
||||
|
@ -66,6 +66,16 @@ message NextTokenChooserParameters {
|
||||
float repetition_penalty = 7;
|
||||
/// token watermarking using "A Watermark for Large Language Models"
|
||||
bool watermark = 8;
|
||||
/// bias towards certain token sequences
|
||||
repeated LogitBias logit_bias = 9;
|
||||
}
|
||||
|
||||
/// a token sequence and its bias
|
||||
message LogitBias {
|
||||
/// string to bias towards (may be more than one token)
|
||||
string string = 1;
|
||||
/// bias for the string
|
||||
float bias = 2;
|
||||
}
|
||||
|
||||
message StoppingCriteriaParameters {
|
||||
|
@ -124,6 +124,7 @@ impl Client {
|
||||
seed: 0,
|
||||
repetition_penalty: 1.2,
|
||||
watermark: true,
|
||||
logit_bias: vec![],
|
||||
}),
|
||||
stopping_parameters: Some(StoppingCriteriaParameters {
|
||||
max_new_tokens: 2,
|
||||
|
@ -10,7 +10,7 @@ pub use pb::generate::v1::HealthResponse;
|
||||
pub use pb::generate::v1::InfoResponse as ShardInfo;
|
||||
pub use pb::generate::v1::{
|
||||
Batch, CachedBatch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters,
|
||||
PrefillTokens, Request, StoppingCriteriaParameters,
|
||||
PrefillTokens, Request, StoppingCriteriaParameters, LogitBias
|
||||
};
|
||||
pub use sharded_client::ShardedClient;
|
||||
use thiserror::Error;
|
||||
|
@ -44,6 +44,7 @@ impl Health {
|
||||
seed: 0,
|
||||
repetition_penalty: 1.0,
|
||||
watermark: false,
|
||||
logit_bias: vec![],
|
||||
}),
|
||||
stopping_parameters: Some(StoppingCriteriaParameters {
|
||||
max_new_tokens: 1,
|
||||
|
@ -5,6 +5,7 @@ mod queue;
|
||||
pub mod server;
|
||||
mod validation;
|
||||
|
||||
use std::collections::BTreeMap;
|
||||
use infer::Infer;
|
||||
use queue::{Entry, Queue};
|
||||
use serde::{Deserialize, Serialize};
|
||||
@ -135,6 +136,9 @@ pub(crate) struct GenerateParameters {
|
||||
example = "null"
|
||||
)]
|
||||
pub seed: Option<u64>,
|
||||
#[serde(default)]
|
||||
#[schema(default=json!({}), example=json!({"hello": 0.5}))]
|
||||
pub logit_bias: BTreeMap<String, f32>
|
||||
}
|
||||
|
||||
fn default_max_new_tokens() -> u32 {
|
||||
@ -158,6 +162,7 @@ fn default_parameters() -> GenerateParameters {
|
||||
details: false,
|
||||
decoder_input_details: false,
|
||||
seed: None,
|
||||
logit_bias: BTreeMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -322,6 +322,7 @@ mod tests {
|
||||
seed: 0,
|
||||
repetition_penalty: 0.0,
|
||||
watermark: false,
|
||||
logit_bias: vec![],
|
||||
},
|
||||
stopping_parameters: StoppingCriteriaParameters {
|
||||
ignore_eos_token: false,
|
||||
|
@ -2,7 +2,7 @@
|
||||
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
|
||||
use crate::{GenerateParameters, GenerateRequest};
|
||||
use rand::{thread_rng, Rng};
|
||||
use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
|
||||
use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters, LogitBias};
|
||||
use thiserror::Error;
|
||||
use tokenizers::tokenizer::Tokenizer;
|
||||
use tokenizers::TruncationDirection;
|
||||
@ -142,6 +142,7 @@ impl Validation {
|
||||
seed,
|
||||
watermark,
|
||||
decoder_input_details,
|
||||
logit_bias,
|
||||
..
|
||||
} = request.parameters;
|
||||
|
||||
@ -238,6 +239,9 @@ impl Validation {
|
||||
.validate_input(request.inputs, truncate, max_new_tokens)
|
||||
.await?;
|
||||
|
||||
// transform logit_bias (received as a map) into a vector of LogitBias
|
||||
let logit_bias = logit_bias.into_iter().map(|(string, bias)| LogitBias { string, bias, }).collect();
|
||||
|
||||
let parameters = NextTokenChooserParameters {
|
||||
temperature,
|
||||
repetition_penalty,
|
||||
@ -247,6 +251,7 @@ impl Validation {
|
||||
do_sample,
|
||||
seed,
|
||||
watermark,
|
||||
logit_bias,
|
||||
};
|
||||
let stopping_parameters = StoppingCriteriaParameters {
|
||||
max_new_tokens,
|
||||
|
Loading…
Reference in New Issue
Block a user