diff --git a/router/src/lib.rs b/router/src/lib.rs index 534fa477..c870f522 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -445,7 +445,6 @@ pub struct CompletionRequest { #[derive(Clone, Deserialize, Serialize, ToSchema)] pub(crate) struct Completion { pub id: String, - pub object: ObjectType, #[schema(example = "1706270835")] pub created: u64, #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")] @@ -466,7 +465,6 @@ pub(crate) struct CompletionComplete { #[derive(Clone, Deserialize, Serialize, ToSchema)] pub(crate) struct ChatCompletion { pub id: String, - pub object: ObjectType, #[schema(example = "1706270835")] pub created: u64, #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")] @@ -562,12 +560,13 @@ pub(crate) struct Usage { pub total_tokens: u32, } -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -pub enum ObjectType { - #[serde(rename = "chat.completion")] - ChatCompletion, +#[derive(Clone, Serialize, ToSchema)] +#[serde(tag = "object")] +enum CompletionType { #[serde(rename = "chat.completion.chunk")] - ChatCompletionChunk, + ChatCompletionChunk(ChatCompletionChunk), + #[serde(rename = "chat.completion")] + ChatCompletion(ChatCompletion), } impl ChatCompletion { @@ -606,7 +605,6 @@ impl ChatCompletion { }; Self { id: String::new(), - object: ObjectType::ChatCompletion, created, model, system_fingerprint, @@ -628,7 +626,6 @@ impl ChatCompletion { #[derive(Clone, Deserialize, Serialize, ToSchema)] pub(crate) struct CompletionCompleteChunk { pub id: String, - pub object: ObjectType, pub created: u64, pub choices: Vec, pub model: String, @@ -638,7 +635,6 @@ pub(crate) struct CompletionCompleteChunk { #[derive(Clone, Serialize, ToSchema)] pub(crate) struct ChatCompletionChunk { pub id: String, - pub object: ObjectType, #[schema(example = "1706270978")] pub created: u64, #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")] @@ -718,7 +714,6 @@ impl ChatCompletionChunk { }; Self { id: String::new(), - object: ObjectType::ChatCompletionChunk, created, model, system_fingerprint, diff --git a/router/src/server.rs b/router/src/server.rs index 57197847..ce954c11 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -13,14 +13,15 @@ use crate::validation::ValidationError; use crate::{ BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo, HubProcessorConfig, HubTokenizerConfig, Info, - Message, ObjectType, PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, - TokenizeResponse, Usage, Validation, + Message, PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse, + Usage, Validation, }; use crate::{ ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete, ChatCompletionDelta, ChatCompletionLogprob, ChatCompletionLogprobs, ChatCompletionTopLogprob, ChatRequest, CompatGenerateRequest, Completion, CompletionComplete, CompletionCompleteChunk, - CompletionRequest, DeltaToolCall, Function, Tool, VertexRequest, VertexResponse, + CompletionRequest, CompletionType, DeltaToolCall, Function, Tool, VertexRequest, + VertexResponse, }; use crate::{FunctionDefinition, ToolCall, ToolType}; use async_stream::__private::AsyncStream; @@ -705,7 +706,6 @@ async fn completions( event .json_data(CompletionCompleteChunk { id: "".to_string(), - object: ObjectType::ChatCompletionChunk, created: current_time, choices: vec![CompletionComplete { @@ -932,7 +932,6 @@ async fn completions( let response = Completion { id: "".to_string(), - object: ObjectType::ChatCompletion, created: current_time, model: info.model_id.clone(), system_fingerprint: format!( @@ -1153,14 +1152,16 @@ async fn chat_completions( }; event - .json_data(ChatCompletionChunk::new( - model_id.clone(), - system_fingerprint.clone(), - content, - tool_calls, - current_time, - logprobs, - stream_token.details.map(|d| d.finish_reason.to_string()), + .json_data(CompletionType::ChatCompletionChunk( + ChatCompletionChunk::new( + model_id.clone(), + system_fingerprint.clone(), + content, + tool_calls, + current_time, + logprobs, + stream_token.details.map(|d| d.finish_reason.to_string()), + ), )) .unwrap_or_else(|e| { println!("Failed to serialize ChatCompletionChunk: {:?}", e); @@ -1228,7 +1229,7 @@ async fn chat_completions( (None, Some(generation.generated_text)) }; // build the complete response object with the full text - let response = ChatCompletion::new( + let response = CompletionType::ChatCompletion(ChatCompletion::new( model_id, system_fingerprint, output, @@ -1236,7 +1237,7 @@ async fn chat_completions( generation.details.unwrap(), logprobs, tool_calls, - ); + )); // wrap generation inside a Vec to match api-inference Ok((headers, Json(response)).into_response())