feat: use model name as adapter id in chat endpoints

This commit is contained in:
drbh 2024-06-26 23:21:39 +00:00
parent be2d38032a
commit 29a1137409
2 changed files with 6 additions and 4 deletions

View File

@@ -370,7 +370,7 @@ pub struct CompletionRequest {
     /// UNUSED
     #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
     /// ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
-    pub model: String,
+    pub model: Option<String>,
     /// The prompt to generate completions for.
     #[schema(example = "What is Deep Learning?")]
@@ -706,7 +706,7 @@ impl ChatCompletionChunk {
 pub(crate) struct ChatRequest {
     #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
     /// [UNUSED] ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
-    pub model: String,
+    pub model: Option<String>,
     /// A list of messages comprising the conversation so far.
     #[schema(example = "[{\"role\": \"user\", \"content\": \"What is Deep Learning?\"}]")]

View File

@@ -606,6 +606,7 @@ async fn completions(
     metrics::increment_counter!("tgi_request_count");
     let CompletionRequest {
+        model,
         max_tokens,
         seed,
         stop,
@@ -673,7 +674,7 @@ async fn completions(
                 seed,
                 top_n_tokens: None,
                 grammar: None,
-                ..Default::default()
+                adapter_id: model.as_ref().filter(|m| *m != "tgi").map(String::from),
             },
         })
         .collect();
@@ -1011,6 +1012,7 @@ async fn chat_completions(
     let span = tracing::Span::current();
     metrics::increment_counter!("tgi_request_count");
     let ChatRequest {
+        model,
         logprobs,
         max_tokens,
         messages,
@@ -1116,7 +1118,7 @@ async fn chat_completions(
             seed,
             top_n_tokens: req.top_logprobs,
             grammar,
-            ..Default::default()
+            adapter_id: model.filter(|m| *m != "tgi").map(String::from),
         },
     };