fix: prefer enum for chat object

This commit is contained in:
drbh 2024-06-26 22:54:00 +00:00
parent fb98ab273f
commit f98f498473
2 changed files with 21 additions and 12 deletions

View File

@ -442,10 +442,11 @@ pub struct CompletionRequest {
pub stop: Option<Vec<String>>, pub stop: Option<Vec<String>>,
} }
#[derive(Clone, Deserialize, Serialize, ToSchema, Default)] #[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct Completion { pub(crate) struct Completion {
pub id: String, pub id: String,
pub object: String, #[schema(default = "ObjectType::ChatCompletion")]
pub object: ObjectType,
#[schema(example = "1706270835")] #[schema(example = "1706270835")]
pub created: u64, pub created: u64,
#[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")] #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
@ -466,7 +467,7 @@ pub(crate) struct CompletionComplete {
#[derive(Clone, Deserialize, Serialize, ToSchema)] #[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct ChatCompletion { pub(crate) struct ChatCompletion {
pub id: String, pub id: String,
pub object: String, pub object: ObjectType,
#[schema(example = "1706270835")] #[schema(example = "1706270835")]
pub created: u64, pub created: u64,
#[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")] #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
@ -562,6 +563,14 @@ pub(crate) struct Usage {
pub total_tokens: u32, pub total_tokens: u32,
} }
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ObjectType {
#[serde(rename = "chat.completion")]
ChatCompletion,
#[serde(rename = "chat.completion.chunk")]
ChatCompletionChunk,
}
impl ChatCompletion { impl ChatCompletion {
pub(crate) fn new( pub(crate) fn new(
model: String, model: String,
@ -598,7 +607,7 @@ impl ChatCompletion {
}; };
Self { Self {
id: String::new(), id: String::new(),
object: "chat.completion".into(), object: ObjectType::ChatCompletion,
created, created,
model, model,
system_fingerprint, system_fingerprint,
@ -620,7 +629,7 @@ impl ChatCompletion {
#[derive(Clone, Deserialize, Serialize, ToSchema)] #[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct CompletionCompleteChunk { pub(crate) struct CompletionCompleteChunk {
pub id: String, pub id: String,
pub object: String, pub object: ObjectType,
pub created: u64, pub created: u64,
pub choices: Vec<CompletionComplete>, pub choices: Vec<CompletionComplete>,
pub model: String, pub model: String,
@ -630,7 +639,7 @@ pub(crate) struct CompletionCompleteChunk {
#[derive(Clone, Serialize, ToSchema)] #[derive(Clone, Serialize, ToSchema)]
pub(crate) struct ChatCompletionChunk { pub(crate) struct ChatCompletionChunk {
pub id: String, pub id: String,
pub object: String, pub object: ObjectType,
#[schema(example = "1706270978")] #[schema(example = "1706270978")]
pub created: u64, pub created: u64,
#[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")] #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
@ -710,7 +719,7 @@ impl ChatCompletionChunk {
}; };
Self { Self {
id: String::new(), id: String::new(),
object: "chat.completion.chunk".to_string(), object: ObjectType::ChatCompletionChunk,
created, created,
model, model,
system_fingerprint, system_fingerprint,

View File

@ -12,9 +12,9 @@ use crate::kserve::{
use crate::validation::ValidationError; use crate::validation::ValidationError;
use crate::{ use crate::{
BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest,
GenerateResponse, GrammarType, HubModelInfo, HubPreprocessorConfig, HubProcessorConfig, GenerateResponse, GrammarType, HubModelInfo, HubProcessorConfig, HubTokenizerConfig, Info,
HubTokenizerConfig, Info, Message, PrefillToken, SimpleToken, StreamDetails, StreamResponse, Message, ObjectType, PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token,
Token, TokenizeResponse, Usage, Validation, TokenizeResponse, Usage, Validation,
}; };
use crate::{ use crate::{
ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete, ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
@ -705,7 +705,7 @@ async fn completions(
event event
.json_data(CompletionCompleteChunk { .json_data(CompletionCompleteChunk {
id: "".to_string(), id: "".to_string(),
object: "text_completion".to_string(), object: ObjectType::ChatCompletionChunk,
created: current_time, created: current_time,
choices: vec![CompletionComplete { choices: vec![CompletionComplete {
@ -932,7 +932,7 @@ async fn completions(
let response = Completion { let response = Completion {
id: "".to_string(), id: "".to_string(),
object: "text_completion".to_string(), object: ObjectType::ChatCompletion,
created: current_time, created: current_time,
model: info.model_id.clone(), model: info.model_id.clone(),
system_fingerprint: format!( system_fingerprint: format!(