From 82f87ada6f08114ae198abb0829d087f311cf5bc Mon Sep 17 00:00:00 2001
From: Jacob Keisling
Date: Tue, 23 Jan 2024 08:55:05 -0600
Subject: [PATCH] Disable `decoder_input_details` on OpenAI-compatible chat
 streaming, pass temperature and top-p from API (#1470)

This PR makes two minor tweaks to how the new OpenAI-compatible chat endpoint (#1427) builds its `GenerateParameters`:

- Disables `decoder_input_details` when streaming is enabled. Previously this caused every streaming chat request to fail, because [`decoder_input_details == true` is not supported when streaming tokens](https://github.com/huggingface/text-generation-inference/blob/98e5faff9daec6170cc2b0f963f2d73cf846b341/router/src/validation.rs#L406).
- Passes the `temperature` and `top_p` hyperparameters through from the API request to `GenerateParameters`.

## Testing

```bash
curl localhost:8080/v1/chat/completions \
    -X POST \
    -d '{
  "model": "",
  "messages": [
    {
      "role": "system",
      "content": "You are a helpful assistant."
    },
    {
      "role": "user",
      "content": "What is deep learning?"
    }
  ],
  "stream": true,
  "max_tokens": 20
}' \
    -H 'Content-Type: application/json'
```

With this patch, the request above streams tokens correctly. On the most recent release from `main` it returns an error instead:

```
data:{"error":"Input validation error: `decoder_input_details` == true is not supported when streaming tokens","error_type":"validation"}
```

A second request that exercises the new `temperature` and `top_p` fields is sketched after the diff below.

This is my first time contributing to this project, so I could be missing something. I would especially appreciate @drbh's eyes on this one.
---
 router/src/lib.rs    | 12 ++++++++++++
 router/src/server.rs |  6 +++---
 2 files changed, 15 insertions(+), 3 deletions(-)

diff --git a/router/src/lib.rs b/router/src/lib.rs
index 983079d6..894ab466 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -365,6 +365,18 @@ pub(crate) struct ChatRequest {
     #[schema(nullable = true, example = 42)]
     pub seed: Option<u64>,
+
+    /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while
+    /// lower values like 0.2 will make it more focused and deterministic.
+    ///
+    /// We generally recommend altering this or `top_p` but not both.
+    #[serde(default)]
+    pub temperature: Option<f32>,
+
+    /// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the
+    /// tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.
+    #[serde(default)]
+    pub top_p: Option<f32>,
 }
 
 #[derive(Clone, Serialize, Deserialize)]
diff --git a/router/src/server.rs b/router/src/server.rs
index cf1d94a6..aa1ad202 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -592,10 +592,10 @@ async fn chat_completions(
         inputs: inputs.to_string(),
         parameters: GenerateParameters {
             best_of: None,
-            temperature: None,
+            temperature: req.temperature,
             repetition_penalty,
             top_k: None,
-            top_p: None,
+            top_p: req.top_p,
             typical_p: None,
             do_sample: true,
             max_new_tokens,
@@ -604,7 +604,7 @@
             truncate: None,
             watermark: false,
             details: true,
-            decoder_input_details: true,
+            decoder_input_details: !stream,
             seed,
             top_n_tokens: None,
         },
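
For the `temperature`/`top_p` change, a request along the following lines should exercise the new fields end to end. This is only a sketch: the sampling values are illustrative, and everything else mirrors the test request above.

```bash
curl localhost:8080/v1/chat/completions \
    -X POST \
    -d '{
  "model": "",
  "messages": [
    {
      "role": "user",
      "content": "What is deep learning?"
    }
  ],
  "stream": true,
  "max_tokens": 20,
  "temperature": 0.7,
  "top_p": 0.95
}' \
    -H 'Content-Type: application/json'
```

If either field is omitted from the request body, it deserializes to `None`, matching the values that were previously hard-coded into `GenerateParameters`, so existing clients see no change in behavior.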