This commit is contained in:
OlivierDehaene 2023-03-09 10:38:11 +01:00
parent 1a2d68250a
commit a376d8bc59
3 changed files with 9 additions and 2 deletions

View File

@ -56,12 +56,15 @@ pub(crate) struct GenerateParameters {
#[schema(exclusive_minimum = 0, exclusive_maximum = 512, default = "20")]
pub max_new_tokens: u32,
#[serde(default)]
#[schema(default = "None", example = false)] #[schema(default = "null", example = false)]
pub return_full_text: Option<bool>,
#[serde(default)]
#[schema(inline, max_items = 4, example = json ! (["photographer"]))]
pub stop: Vec<String>,
#[serde(default)]
#[schema(default = "null", example = "null")]
pub truncate: Option<i32>,
#[serde(default)]
#[schema(default = "false", example = true)]
pub watermark: bool,
#[serde(default)]
@ -86,6 +89,7 @@ fn default_parameters() -> GenerateParameters {
max_new_tokens: default_max_new_tokens(),
return_full_text: None,
stop: Vec::new(),
truncate: None,
watermark: false,
details: false,
seed: None,

View File

@ -73,6 +73,7 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
max_new_tokens: 1,
return_full_text: None,
stop: Vec::new(),
truncate: None,
watermark: false,
details: false,
seed: None,

View File

@ -225,7 +225,9 @@ fn validate(
// Get the number of tokens in the input
match tokenizer.encode(request.inputs.clone(), true) {
Ok(encoding) => { Ok(mut encoding) => {
encoding.truncate()
let input_length = encoding.len();
let total_tokens = input_length + max_new_tokens as usize;