fix: enum CompletionType not ObjectType

This commit is contained in:
drbh 2024-06-27 13:07:05 +00:00
parent 39c6d10b5a
commit ae14f8931e
2 changed files with 22 additions and 26 deletions

View File

@ -445,7 +445,6 @@ pub struct CompletionRequest {
#[derive(Clone, Deserialize, Serialize, ToSchema)] #[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct Completion { pub(crate) struct Completion {
pub id: String, pub id: String,
pub object: ObjectType,
#[schema(example = "1706270835")] #[schema(example = "1706270835")]
pub created: u64, pub created: u64,
#[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")] #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
@ -466,7 +465,6 @@ pub(crate) struct CompletionComplete {
#[derive(Clone, Deserialize, Serialize, ToSchema)] #[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct ChatCompletion { pub(crate) struct ChatCompletion {
pub id: String, pub id: String,
pub object: ObjectType,
#[schema(example = "1706270835")] #[schema(example = "1706270835")]
pub created: u64, pub created: u64,
#[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")] #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
@ -562,12 +560,13 @@ pub(crate) struct Usage {
pub total_tokens: u32, pub total_tokens: u32,
} }
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[derive(Clone, Serialize, ToSchema)]
pub enum ObjectType { #[serde(tag = "object")]
#[serde(rename = "chat.completion")] enum CompletionType {
ChatCompletion,
#[serde(rename = "chat.completion.chunk")] #[serde(rename = "chat.completion.chunk")]
ChatCompletionChunk, ChatCompletionChunk(ChatCompletionChunk),
#[serde(rename = "chat.completion")]
ChatCompletion(ChatCompletion),
} }
impl ChatCompletion { impl ChatCompletion {
@ -606,7 +605,6 @@ impl ChatCompletion {
}; };
Self { Self {
id: String::new(), id: String::new(),
object: ObjectType::ChatCompletion,
created, created,
model, model,
system_fingerprint, system_fingerprint,
@ -628,7 +626,6 @@ impl ChatCompletion {
#[derive(Clone, Deserialize, Serialize, ToSchema)] #[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct CompletionCompleteChunk { pub(crate) struct CompletionCompleteChunk {
pub id: String, pub id: String,
pub object: ObjectType,
pub created: u64, pub created: u64,
pub choices: Vec<CompletionComplete>, pub choices: Vec<CompletionComplete>,
pub model: String, pub model: String,
@ -638,7 +635,6 @@ pub(crate) struct CompletionCompleteChunk {
#[derive(Clone, Serialize, ToSchema)] #[derive(Clone, Serialize, ToSchema)]
pub(crate) struct ChatCompletionChunk { pub(crate) struct ChatCompletionChunk {
pub id: String, pub id: String,
pub object: ObjectType,
#[schema(example = "1706270978")] #[schema(example = "1706270978")]
pub created: u64, pub created: u64,
#[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")] #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
@ -718,7 +714,6 @@ impl ChatCompletionChunk {
}; };
Self { Self {
id: String::new(), id: String::new(),
object: ObjectType::ChatCompletionChunk,
created, created,
model, model,
system_fingerprint, system_fingerprint,

View File

@ -13,14 +13,15 @@ use crate::validation::ValidationError;
use crate::{ use crate::{
BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest,
GenerateResponse, GrammarType, HubModelInfo, HubProcessorConfig, HubTokenizerConfig, Info, GenerateResponse, GrammarType, HubModelInfo, HubProcessorConfig, HubTokenizerConfig, Info,
Message, ObjectType, PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, Message, PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse,
TokenizeResponse, Usage, Validation, Usage, Validation,
}; };
use crate::{ use crate::{
ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete, ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
ChatCompletionDelta, ChatCompletionLogprob, ChatCompletionLogprobs, ChatCompletionTopLogprob, ChatCompletionDelta, ChatCompletionLogprob, ChatCompletionLogprobs, ChatCompletionTopLogprob,
ChatRequest, CompatGenerateRequest, Completion, CompletionComplete, CompletionCompleteChunk, ChatRequest, CompatGenerateRequest, Completion, CompletionComplete, CompletionCompleteChunk,
CompletionRequest, DeltaToolCall, Function, Tool, VertexRequest, VertexResponse, CompletionRequest, CompletionType, DeltaToolCall, Function, Tool, VertexRequest,
VertexResponse,
}; };
use crate::{FunctionDefinition, ToolCall, ToolType}; use crate::{FunctionDefinition, ToolCall, ToolType};
use async_stream::__private::AsyncStream; use async_stream::__private::AsyncStream;
@ -705,7 +706,6 @@ async fn completions(
event event
.json_data(CompletionCompleteChunk { .json_data(CompletionCompleteChunk {
id: "".to_string(), id: "".to_string(),
object: ObjectType::ChatCompletionChunk,
created: current_time, created: current_time,
choices: vec![CompletionComplete { choices: vec![CompletionComplete {
@ -932,7 +932,6 @@ async fn completions(
let response = Completion { let response = Completion {
id: "".to_string(), id: "".to_string(),
object: ObjectType::ChatCompletion,
created: current_time, created: current_time,
model: info.model_id.clone(), model: info.model_id.clone(),
system_fingerprint: format!( system_fingerprint: format!(
@ -1153,14 +1152,16 @@ async fn chat_completions(
}; };
event event
.json_data(ChatCompletionChunk::new( .json_data(CompletionType::ChatCompletionChunk(
model_id.clone(), ChatCompletionChunk::new(
system_fingerprint.clone(), model_id.clone(),
content, system_fingerprint.clone(),
tool_calls, content,
current_time, tool_calls,
logprobs, current_time,
stream_token.details.map(|d| d.finish_reason.to_string()), logprobs,
stream_token.details.map(|d| d.finish_reason.to_string()),
),
)) ))
.unwrap_or_else(|e| { .unwrap_or_else(|e| {
println!("Failed to serialize ChatCompletionChunk: {:?}", e); println!("Failed to serialize ChatCompletionChunk: {:?}", e);
@ -1228,7 +1229,7 @@ async fn chat_completions(
(None, Some(generation.generated_text)) (None, Some(generation.generated_text))
}; };
// build the complete response object with the full text // build the complete response object with the full text
let response = ChatCompletion::new( let response = CompletionType::ChatCompletion(ChatCompletion::new(
model_id, model_id,
system_fingerprint, system_fingerprint,
output, output,
@ -1236,7 +1237,7 @@ async fn chat_completions(
generation.details.unwrap(), generation.details.unwrap(),
logprobs, logprobs,
tool_calls, tool_calls,
); ));
// wrap generation inside a Vec to match api-inference // wrap generation inside a Vec to match api-inference
Ok((headers, Json(response)).into_response()) Ok((headers, Json(response)).into_response())