feat: use model name as adapter id in chat endpoints (#2128)

This commit is contained in:
drbh 2024-07-08 10:06:49 -04:00 committed by GitHub
parent 58effe78b5
commit 87ebb6477b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 6 additions and 4 deletions

View File

@ -384,7 +384,7 @@ pub struct CompletionRequest {
/// Forwarded as the adapter id for the request unless it is "tgi" or omitted.
#[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
/// ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
pub model: String,
pub model: Option<String>,
/// The prompt to generate completions for.
#[schema(example = "What is Deep Learning?")]
@ -731,7 +731,7 @@ impl ChatCompletionChunk {
pub(crate) struct ChatRequest {
#[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
/// ID of the model to use; forwarded as the adapter id unless it is "tgi" or omitted. See the model endpoint compatibility table for details on which models work with the Chat API.
pub model: String,
pub model: Option<String>,
/// A list of messages comprising the conversation so far.
#[schema(example = "[{\"role\": \"user\", \"content\": \"What is Deep Learning?\"}]")]

View File

@ -597,6 +597,7 @@ async fn completions(
metrics::counter!("tgi_request_count").increment(1);
let CompletionRequest {
model,
max_tokens,
seed,
stop,
@ -665,7 +666,7 @@ async fn completions(
seed,
top_n_tokens: None,
grammar: None,
..Default::default()
adapter_id: model.as_ref().filter(|m| *m != "tgi").map(String::from),
},
})
.collect();
@ -1001,6 +1002,7 @@ async fn chat_completions(
let span = tracing::Span::current();
metrics::counter!("tgi_request_count").increment(1);
let ChatRequest {
model,
logprobs,
max_tokens,
messages,
@ -1106,7 +1108,7 @@ async fn chat_completions(
seed,
top_n_tokens: req.top_logprobs,
grammar,
..Default::default()
adapter_id: model.filter(|m| *m != "tgi").map(String::from),
},
};