diff --git a/docs/openapi.json b/docs/openapi.json index 4a1ab6dd..72b073b1 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -367,7 +367,7 @@ "type": "integer", "format": "int32", "example": 1, - "minimum": 0.0 + "minimum": 0 }, "prefill": { "type": "array", @@ -380,13 +380,22 @@ "format": "int64", "example": 42, "nullable": true, - "minimum": 0.0 + "minimum": 0 }, "tokens": { "type": "array", "items": { "$ref": "#/components/schemas/Token" } + }, + "top_tokens": { + "type": "array", + "items": { + "type": "array", + "items": { + "$ref": "#/components/schemas/Token" + } + } } } }, @@ -432,7 +441,7 @@ "type": "integer", "format": "int32", "example": 1, - "minimum": 0.0 + "minimum": 0 }, "prefill": { "type": "array", @@ -445,13 +454,22 @@ "format": "int64", "example": 42, "nullable": true, - "minimum": 0.0 + "minimum": 0 }, "tokens": { "type": "array", "items": { "$ref": "#/components/schemas/Token" } + }, + "top_tokens": { + "type": "array", + "items": { + "type": "array", + "items": { + "$ref": "#/components/schemas/Token" + } + } } } }, @@ -486,8 +504,8 @@ "default": "null", "example": 1, "nullable": true, - "minimum": 0.0, - "exclusiveMinimum": 0.0 + "minimum": 0, + "exclusiveMinimum": 0 }, "decoder_input_details": { "type": "boolean", @@ -505,10 +523,10 @@ "max_new_tokens": { "type": "integer", "format": "int32", - "default": "20", - "minimum": 0.0, - "exclusiveMaximum": 512.0, - "exclusiveMinimum": 0.0 + "default": "null", + "example": "20", + "nullable": true, + "minimum": 0 }, "repetition_penalty": { "type": "number", @@ -516,7 +534,7 @@ "default": "null", "example": 1.03, "nullable": true, - "exclusiveMinimum": 0.0 + "exclusiveMinimum": 0 }, "return_full_text": { "type": "boolean", @@ -530,8 +548,8 @@ "default": "null", "example": "null", "nullable": true, - "minimum": 0.0, - "exclusiveMinimum": 0.0 + "minimum": 0, + "exclusiveMinimum": 0 }, "stop": { "type": "array", @@ -549,7 +567,7 @@ "default": "null", "example": 0.5, "nullable": true, - "exclusiveMinimum": 0.0 + "exclusiveMinimum": 0 }, "top_k": { "type": "integer", @@ -557,7 +575,16 @@ "default": "null", "example": 10, "nullable": true, - "exclusiveMinimum": 0.0 + "exclusiveMinimum": 0 + }, + "top_n_tokens": { + "type": "integer", + "format": "int32", + "default": "null", + "example": 5, + "nullable": true, + "minimum": 0, + "exclusiveMinimum": 0 }, "top_p": { "type": "number", @@ -565,15 +592,15 @@ "default": "null", "example": 0.95, "nullable": true, - "maximum": 1.0, - "exclusiveMinimum": 0.0 + "maximum": 1, + "exclusiveMinimum": 0 }, "truncate": { "type": "integer", "default": "null", "example": "null", "nullable": true, - "minimum": 0.0 + "minimum": 0 }, "typical_p": { "type": "number", @@ -581,8 +608,8 @@ "default": "null", "example": 0.95, "nullable": true, - "maximum": 1.0, - "exclusiveMinimum": 0.0 + "maximum": 1, + "exclusiveMinimum": 0 }, "watermark": { "type": "boolean", @@ -653,38 +680,38 @@ "type": "integer", "format": "int32", "example": "32000", - "minimum": 0.0 + "minimum": 0 }, "max_best_of": { "type": "integer", "example": "2", - "minimum": 0.0 + "minimum": 0 }, "max_concurrent_requests": { "type": "integer", "description": "Router Parameters", "example": "128", - "minimum": 0.0 + "minimum": 0 }, "max_input_length": { "type": "integer", "example": "1024", - "minimum": 0.0 + "minimum": 0 }, "max_stop_sequences": { "type": "integer", "example": "4", - "minimum": 0.0 + "minimum": 0 }, "max_total_tokens": { "type": "integer", "example": "2048", - "minimum": 0.0 + "minimum": 0 }, "max_waiting_tokens": { "type": "integer", "example": "20", - "minimum": 0.0 + "minimum": 0 }, "model_device_type": { "type": "string", @@ -717,7 +744,7 @@ "validation_workers": { "type": "integer", "example": "2", - "minimum": 0.0 + "minimum": 0 }, "version": { "type": "string", @@ -743,7 +770,7 @@ "type": "integer", "format": "int32", "example": 0, - "minimum": 0.0 + "minimum": 0 }, "logprob": { "type": "number", @@ -771,14 +798,14 @@ "type": "integer", "format": "int32", "example": 1, - "minimum": 0.0 + "minimum": 0 }, "seed": { "type": "integer", "format": "int64", "example": 42, "nullable": true, - "minimum": 0.0 + "minimum": 0 } } }, @@ -794,6 +821,7 @@ "$ref": "#/components/schemas/StreamDetails" } ], + "default": "null", "nullable": true }, "generated_text": { @@ -804,6 +832,12 @@ }, "token": { "$ref": "#/components/schemas/Token" + }, + "top_tokens": { + "type": "array", + "items": { + "$ref": "#/components/schemas/Token" + } } } }, @@ -820,7 +854,7 @@ "type": "integer", "format": "int32", "example": 0, - "minimum": 0.0 + "minimum": 0 }, "logprob": { "type": "number", diff --git a/router/src/lib.rs b/router/src/lib.rs index 76e70bb7..560b8f74 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -107,8 +107,8 @@ pub(crate) struct GenerateParameters { #[schema(default = "false", example = true)] pub do_sample: bool, #[serde(default = "default_max_new_tokens")] - #[schema(exclusive_minimum = 0, exclusive_maximum = 512, default = "20")] - pub max_new_tokens: u32, + #[schema(nullable = true, default = "null", example = "20")] + pub max_new_tokens: Option, #[serde(default)] #[schema(nullable = true, default = "null", example = false)] pub return_full_text: Option, @@ -140,8 +140,8 @@ pub(crate) struct GenerateParameters { pub top_n_tokens: Option, } -fn default_max_new_tokens() -> u32 { - 20 +fn default_max_new_tokens() -> Option { + None } fn default_parameters() -> GenerateParameters { diff --git a/router/src/validation.rs b/router/src/validation.rs index 36cbfb9b..9adedc5b 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -67,8 +67,8 @@ impl Validation { &self, inputs: String, truncate: Option, - max_new_tokens: u32, - ) -> Result<(String, usize), ValidationError> { + max_new_tokens: Option, + ) -> Result<(String, usize, u32), ValidationError> { // If we have a fast tokenizer if let Some(sender) = &self.sender { // Create response channel @@ -84,6 +84,11 @@ impl Validation { let (inputs, input_length) = response_receiver.await.unwrap()?; // Get total tokens + let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens { + max_new_tokens + } else { + self.max_total_tokens.saturating_sub(input_length) as u32 + }; let total_tokens = input_length + max_new_tokens as usize; // Validate MaxTotalTokens @@ -104,7 +109,7 @@ impl Validation { } metrics::histogram!("tgi_request_input_length", input_length as f64); - Ok((inputs, input_length)) + Ok((inputs, input_length, max_new_tokens)) } // Return inputs without validation else { @@ -112,6 +117,11 @@ impl Validation { // However, the inputs will be truncated by the python servers // We make sure that truncate + max_new_tokens <= self.max_total_tokens let input_length = truncate.unwrap_or(self.max_input_length); + let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens { + max_new_tokens + } else { + self.max_total_tokens.saturating_sub(input_length) as u32 + }; // Validate MaxNewTokens if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 { @@ -121,7 +131,7 @@ impl Validation { )); } - Ok((inputs, input_length)) + Ok((inputs, input_length, max_new_tokens)) } } @@ -200,7 +210,7 @@ impl Validation { }) .unwrap_or(Ok(0))?; - if max_new_tokens == 0 { + if max_new_tokens == Some(0) { return Err(ValidationError::NegativeMaxNewTokens); } @@ -247,7 +257,7 @@ impl Validation { .unwrap_or(Ok(None))?; // Validate inputs - let (inputs, input_length) = self + let (inputs, input_length, max_new_tokens) = self .validate_input(request.inputs, truncate, max_new_tokens) .await?; @@ -426,7 +436,7 @@ mod tests { let max_new_tokens = 10; match validation - .validate_input("Hello".to_string(), None, max_new_tokens) + .validate_input("Hello".to_string(), None, Some(max_new_tokens)) .await { Err(ValidationError::MaxNewTokens(1, 10)) => (), @@ -455,7 +465,7 @@ mod tests { let max_new_tokens = 10; match validation - .validate_input("Hello".to_string(), None, max_new_tokens) + .validate_input("Hello".to_string(), None, Some(max_new_tokens)) .await { Err(ValidationError::MaxTotalTokens(6, 1, 10)) => (), @@ -534,7 +544,6 @@ mod tests { inputs: "Hello".to_string(), parameters: GenerateParameters { top_p: Some(0.99), - max_new_tokens: 1, ..default_parameters() }, }) @@ -549,7 +558,6 @@ mod tests { inputs: "Hello".to_string(), parameters: GenerateParameters { top_p: None, - max_new_tokens: 1, ..default_parameters() }, }) @@ -596,7 +604,6 @@ mod tests { inputs: "Hello".to_string(), parameters: GenerateParameters { top_n_tokens: Some(4), - max_new_tokens: 1, ..default_parameters() }, }) @@ -608,7 +615,6 @@ mod tests { inputs: "Hello".to_string(), parameters: GenerateParameters { top_n_tokens: Some(0), - max_new_tokens: 1, ..default_parameters() }, }) @@ -620,7 +626,6 @@ mod tests { inputs: "Hello".to_string(), parameters: GenerateParameters { top_n_tokens: None, - max_new_tokens: 1, ..default_parameters() }, })