From 87ebb6477bfc2a573f5ca7fa196fa87454dc6dc4 Mon Sep 17 00:00:00 2001 From: drbh Date: Mon, 8 Jul 2024 10:06:49 -0400 Subject: [PATCH] feat: use model name as adapter id in chat endpoints (#2128) --- router/src/lib.rs | 4 ++-- router/src/server.rs | 6 ++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/router/src/lib.rs b/router/src/lib.rs index 165b2ad2..080c029a 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -384,7 +384,7 @@ pub struct CompletionRequest { /// UNUSED #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")] /// ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API. - pub model: String, + pub model: Option, /// The prompt to generate completions for. #[schema(example = "What is Deep Learning?")] @@ -731,7 +731,7 @@ impl ChatCompletionChunk { pub(crate) struct ChatRequest { #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")] /// [UNUSED] ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API. - pub model: String, + pub model: Option, /// A list of messages comprising the conversation so far. #[schema(example = "[{\"role\": \"user\", \"content\": \"What is Deep Learning?\"}]")] diff --git a/router/src/server.rs b/router/src/server.rs index 4af8962e..4b52710d 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -597,6 +597,7 @@ async fn completions( metrics::counter!("tgi_request_count").increment(1); let CompletionRequest { + model, max_tokens, seed, stop, @@ -665,7 +666,7 @@ async fn completions( seed, top_n_tokens: None, grammar: None, - ..Default::default() + adapter_id: model.as_ref().filter(|m| *m != "tgi").map(String::from), }, }) .collect(); @@ -1001,6 +1002,7 @@ async fn chat_completions( let span = tracing::Span::current(); metrics::counter!("tgi_request_count").increment(1); let ChatRequest { + model, logprobs, max_tokens, messages, @@ -1106,7 +1108,7 @@ async fn chat_completions( seed, top_n_tokens: req.top_logprobs, grammar, - ..Default::default() + adapter_id: model.filter(|m| *m != "tgi").map(String::from), }, };