From f98f4984734a5bb0d8c2537504f1aeb751c9c8d8 Mon Sep 17 00:00:00 2001 From: drbh Date: Wed, 26 Jun 2024 22:54:00 +0000 Subject: [PATCH] fix: prefer enum for chat object --- router/src/lib.rs | 23 ++++++++++++++++------- router/src/server.rs | 10 +++++----- 2 files changed, 21 insertions(+), 12 deletions(-) diff --git a/router/src/lib.rs b/router/src/lib.rs index a5b97af3..759b9639 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -442,10 +442,11 @@ pub struct CompletionRequest { pub stop: Option>, } -#[derive(Clone, Deserialize, Serialize, ToSchema, Default)] +#[derive(Clone, Deserialize, Serialize, ToSchema)] pub(crate) struct Completion { pub id: String, - pub object: String, + #[schema(default = "ObjectType::ChatCompletion")] + pub object: ObjectType, #[schema(example = "1706270835")] pub created: u64, #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")] @@ -466,7 +467,7 @@ pub(crate) struct CompletionComplete { #[derive(Clone, Deserialize, Serialize, ToSchema)] pub(crate) struct ChatCompletion { pub id: String, - pub object: String, + pub object: ObjectType, #[schema(example = "1706270835")] pub created: u64, #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")] @@ -562,6 +563,14 @@ pub(crate) struct Usage { pub total_tokens: u32, } +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum ObjectType { + #[serde(rename = "chat.completion")] + ChatCompletion, + #[serde(rename = "chat.completion.chunk")] + ChatCompletionChunk, +} + impl ChatCompletion { pub(crate) fn new( model: String, @@ -598,7 +607,7 @@ impl ChatCompletion { }; Self { id: String::new(), - object: "chat.completion".into(), + object: ObjectType::ChatCompletion, created, model, system_fingerprint, @@ -620,7 +629,7 @@ impl ChatCompletion { #[derive(Clone, Deserialize, Serialize, ToSchema)] pub(crate) struct CompletionCompleteChunk { pub id: String, - pub object: String, + pub object: ObjectType, pub created: u64, pub choices: Vec, pub model: String, @@ -630,7 +639,7 @@ pub(crate) struct CompletionCompleteChunk { #[derive(Clone, Serialize, ToSchema)] pub(crate) struct ChatCompletionChunk { pub id: String, - pub object: String, + pub object: ObjectType, #[schema(example = "1706270978")] pub created: u64, #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")] @@ -710,7 +719,7 @@ impl ChatCompletionChunk { }; Self { id: String::new(), - object: "chat.completion.chunk".to_string(), + object: ObjectType::ChatCompletionChunk, created, model, system_fingerprint, diff --git a/router/src/server.rs b/router/src/server.rs index 0cb08d4e..57197847 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -12,9 +12,9 @@ use crate::kserve::{ use crate::validation::ValidationError; use crate::{ BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, - GenerateResponse, GrammarType, HubModelInfo, HubPreprocessorConfig, HubProcessorConfig, - HubTokenizerConfig, Info, Message, PrefillToken, SimpleToken, StreamDetails, StreamResponse, - Token, TokenizeResponse, Usage, Validation, + GenerateResponse, GrammarType, HubModelInfo, HubProcessorConfig, HubTokenizerConfig, Info, + Message, ObjectType, PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, + TokenizeResponse, Usage, Validation, }; use crate::{ ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete, @@ -705,7 +705,7 @@ async fn completions( event .json_data(CompletionCompleteChunk { id: "".to_string(), - object: "text_completion".to_string(), + object: ObjectType::ChatCompletionChunk, created: current_time, choices: vec![CompletionComplete { @@ -932,7 +932,7 @@ async fn completions( let response = Completion { id: "".to_string(), - object: "text_completion".to_string(), + object: ObjectType::ChatCompletion, created: current_time, model: info.model_id.clone(), system_fingerprint: format!(