Change deltas too.

This commit is contained in:
Nicolas Patry 2024-05-16 13:23:24 +02:00
parent 2a87dd7274
commit 961a873305

View File

@ -11,6 +11,7 @@ use queue::{Entry, Queue};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tokio::sync::OwnedSemaphorePermit; use tokio::sync::OwnedSemaphorePermit;
use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::warn;
use utoipa::ToSchema; use utoipa::ToSchema;
use validation::Validation; use validation::Validation;
@ -539,9 +540,25 @@ impl ChatCompletion {
content: vec![MessageChunk::Text(Text { text: output })], content: vec![MessageChunk::Text(Text { text: output })],
name: None, name: None,
}), }),
(None, Some(tool_calls)) => OutputMessage::ToolCall(ToolCallMessage { tool_calls }), (None, Some(tool_calls)) => OutputMessage::ToolCall(ToolCallMessage {
_ => { role: "assistant".to_string(),
todo!("Implement error for invalid tool vs chat"); tool_calls,
}),
(Some(output), Some(_)) => {
warn!("Received both chat and tool call");
OutputMessage::ChatMessage(Message {
role: "assistant".into(),
content: vec![MessageChunk::Text(Text { text: output })],
name: None,
})
}
(None, None) => {
warn!("Didn't receive an answer");
OutputMessage::ChatMessage(Message {
role: "assistant".into(),
content: vec![],
name: None,
})
} }
}; };
Self { Self {
@ -574,6 +591,7 @@ pub(crate) struct CompletionCompleteChunk {
pub model: String, pub model: String,
pub system_fingerprint: String, pub system_fingerprint: String,
} }
#[derive(Clone, Deserialize, Serialize, ToSchema)] #[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct ChatCompletionChunk { pub(crate) struct ChatCompletionChunk {
pub id: String, pub id: String,
@ -594,21 +612,20 @@ pub(crate) struct ChatCompletionChoice {
pub finish_reason: Option<String>, pub finish_reason: Option<String>,
} }
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)] #[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
pub(crate) struct ChatCompletionDelta { pub struct ToolCallDelta {
#[schema(example = "user")] #[schema(example = "assistant")]
// TODO Modify this to a true enum. role: String,
#[serde(default, skip_serializing_if = "Option::is_none")] tool_calls: DeltaToolCall,
pub role: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
#[schema(example = "What is Deep Learning?")]
pub content: Option<String>,
// default to None
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<DeltaToolCall>,
} }
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)] #[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
enum ChatCompletionDelta {
Chat(TextMessage),
Tool(ToolCallDelta),
}
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)]
pub(crate) struct DeltaToolCall { pub(crate) struct DeltaToolCall {
pub index: u32, pub index: u32,
pub id: String, pub id: String,
@ -616,7 +633,7 @@ pub(crate) struct DeltaToolCall {
pub function: Function, pub function: Function,
} }
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)] #[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)]
pub(crate) struct Function { pub(crate) struct Function {
pub name: Option<String>, pub name: Option<String>,
pub arguments: String, pub arguments: String,
@ -634,15 +651,13 @@ impl ChatCompletionChunk {
finish_reason: Option<String>, finish_reason: Option<String>,
) -> Self { ) -> Self {
let delta = match (delta, tool_calls) { let delta = match (delta, tool_calls) {
(Some(delta), _) => ChatCompletionDelta { (Some(delta), _) => ChatCompletionDelta::Chat(TextMessage {
role: Some("assistant".to_string()), role: "assistant".to_string(),
content: Some(delta), content: delta,
tool_calls: None, }),
}, (None, Some(tool_calls)) => ChatCompletionDelta::Tool(ToolCallDelta {
(None, Some(tool_calls)) => ChatCompletionDelta { role: "assistant".to_string(),
role: Some("assistant".to_string()), tool_calls: DeltaToolCall {
content: None,
tool_calls: Some(DeltaToolCall {
index: 0, index: 0,
id: String::new(), id: String::new(),
r#type: "function".to_string(), r#type: "function".to_string(),
@ -650,13 +665,12 @@ impl ChatCompletionChunk {
name: None, name: None,
arguments: tool_calls[0].to_string(), arguments: tool_calls[0].to_string(),
}, },
}), },
}, }),
(None, None) => ChatCompletionDelta { (None, None) => ChatCompletionDelta::Chat(TextMessage {
role: None, role: "assistant".to_string(),
content: None, content: "".to_string(),
tool_calls: None, }),
},
}; };
Self { Self {
id: String::new(), id: String::new(),
@ -982,6 +996,8 @@ impl From<Message> for TextMessage {
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)] #[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
pub struct ToolCallMessage { pub struct ToolCallMessage {
#[schema(example = "assistant")]
role: String,
tool_calls: Vec<ToolCall>, tool_calls: Vec<ToolCall>,
tool_call_id: String, tool_call_id: String,
} }
@ -1265,6 +1281,7 @@ mod tests {
}, },
}, },
] ]
} }
] ]
}); });