From 544f848bde9b23004456610be90e78b8da753b28 Mon Sep 17 00:00:00 2001
From: drbh
Date: Wed, 21 Feb 2024 11:31:11 -0500
Subject: [PATCH] fix: improve completion request params and comments

---
 router/src/lib.rs    | 37 +++++++++++++++++++++++++++++++------
 router/src/server.rs |  3 ++-
 2 files changed, 33 insertions(+), 7 deletions(-)

diff --git a/router/src/lib.rs b/router/src/lib.rs
index f03cb53e..576b9236 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -268,25 +268,50 @@ fn default_parameters() -> GenerateParameters {
 #[derive(Clone, Deserialize, Serialize, ToSchema, Debug)]
 pub struct CompletionRequest {
+    /// UNUSED
+    #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
+    /// ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
     pub model: String,
+
+    /// The prompt to generate completions for.
+    #[schema(example = "What is Deep Learning?")]
     pub prompt: String,
+
+    /// The maximum number of tokens that can be generated in the chat completion.
+    #[serde(default)]
+    #[schema(default = "32")]
     pub max_tokens: Option<u32>,
-    /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will
-    /// make the output more random, while lower values like 0.2 will make it more
-    /// focused and deterministic.
-    ///
-    /// We generally recommend altering this or top_p but not both.
+
+    /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while
+    /// lower values like 0.2 will make it more focused and deterministic. We generally recommend altering this or `top_p` but not both.
     #[serde(default)]
     #[schema(nullable = true, example = 1.0)]
     pub temperature: Option<f32>,
+
+    /// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the
+    /// tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.
     #[serde(default)]
+    #[schema(nullable = true, example = 0.95)]
     pub top_p: Option<f32>,
-    pub stream: Option<bool>,
+
+    #[serde(default = "bool::default")]
+    pub stream: bool,
+
+    #[schema(nullable = true, example = 42)]
     pub seed: Option<u64>,
+
+    /// The text to append to the prompt. This is useful for completing sentences or generating a paragraph of text.
+    /// Please see the completion_template field in the model's tokenizer_config.json file for the completion template.
+    #[serde(default)]
     pub suffix: Option<String>,
+
+    #[serde(default)]
     pub repetition_penalty: Option<f32>,
+
+    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far,
+    /// decreasing the model's likelihood to repeat the same line verbatim.
+    #[serde(default)]
+    #[schema(example = "1.0")]
    pub frequency_penalty: Option<f32>,
 }

diff --git a/router/src/server.rs b/router/src/server.rs
index 4dd00745..020a976a 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -574,8 +574,9 @@ async fn completions(
     Json(req): Json<CompletionRequest>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
     metrics::increment_counter!("tgi_request_count");
+
+    let stream = req.stream;
     let max_new_tokens = req.max_tokens.or(Some(100));
-    let stream = req.stream.unwrap_or_default();
     let seed = req.seed;

     let inputs = match infer.apply_completion_template(req.prompt, req.suffix) {
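Note (not part of the patch): the sketch below illustrates how the reworked CompletionRequest deserializes after this change. It is a trimmed, standalone copy of the struct with the field types assumed from the surrounding router code, not the exact upstream definition; it only shows that omitted optional fields fall back to None via #[serde(default)] and that stream now defaults to false through bool::default, which is why the handler can drop the earlier req.stream.unwrap_or_default() call. It requires serde (with the derive feature) and serde_json.

// Sketch only: a trimmed copy of CompletionRequest showing the serde behavior
// introduced by this patch. Field types are assumed; the real struct lives in
// router/src/lib.rs.
use serde::Deserialize;

#[derive(Deserialize, Debug)]
struct CompletionRequest {
    model: String,
    prompt: String,
    #[serde(default)]
    max_tokens: Option<u32>,
    #[serde(default)]
    temperature: Option<f32>,
    #[serde(default)]
    top_p: Option<f32>,
    // Plain bool now: a missing "stream" key becomes false via bool::default,
    // so the handler no longer needs `req.stream.unwrap_or_default()`.
    #[serde(default = "bool::default")]
    stream: bool,
    seed: Option<u64>,
    #[serde(default)]
    suffix: Option<String>,
    #[serde(default)]
    repetition_penalty: Option<f32>,
    #[serde(default)]
    frequency_penalty: Option<f32>,
}

fn main() {
    // A minimal request body: only the two required fields are present.
    let body = r#"{"model": "mistralai/Mistral-7B-Instruct-v0.2", "prompt": "What is Deep Learning?"}"#;
    let req: CompletionRequest = serde_json::from_str(body).unwrap();
    assert!(!req.stream);             // defaults to false, no Option to unwrap
    assert_eq!(req.max_tokens, None); // the handler later falls back via `.or(Some(100))`
    println!("{req:?}");
}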