Change deltas too.

This commit is contained in:
Nicolas Patry 2024-05-16 13:23:24 +02:00
parent 2a87dd7274
commit 961a873305

View File

@ -11,6 +11,7 @@ use queue::{Entry, Queue};
use serde::{Deserialize, Serialize};
use tokio::sync::OwnedSemaphorePermit;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::warn;
use utoipa::ToSchema;
use validation::Validation;
@ -539,9 +540,25 @@ impl ChatCompletion {
content: vec![MessageChunk::Text(Text { text: output })],
name: None,
}),
(None, Some(tool_calls)) => OutputMessage::ToolCall(ToolCallMessage { tool_calls }),
_ => {
todo!("Implement error for invalid tool vs chat");
(None, Some(tool_calls)) => OutputMessage::ToolCall(ToolCallMessage {
role: "assistant".to_string(),
tool_calls,
}),
(Some(output), Some(_)) => {
warn!("Received both chat and tool call");
OutputMessage::ChatMessage(Message {
role: "assistant".into(),
content: vec![MessageChunk::Text(Text { text: output })],
name: None,
})
}
(None, None) => {
warn!("Didn't receive an answer");
OutputMessage::ChatMessage(Message {
role: "assistant".into(),
content: vec![],
name: None,
})
}
};
Self {
@ -574,6 +591,7 @@ pub(crate) struct CompletionCompleteChunk {
pub model: String,
pub system_fingerprint: String,
}
#[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct ChatCompletionChunk {
pub id: String,
@ -594,21 +612,20 @@ pub(crate) struct ChatCompletionChoice {
pub finish_reason: Option<String>,
}
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
pub(crate) struct ChatCompletionDelta {
#[schema(example = "user")]
// TODO Modify this to a true enum.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub role: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
#[schema(example = "What is Deep Learning?")]
pub content: Option<String>,
// default to None
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<DeltaToolCall>,
// Delta payload used when a streamed chunk carries a tool call instead of text.
// It is wrapped by the `Tool` variant of `ChatCompletionDelta`.
// Plain `//` comments are used here on purpose, in line with the other
// comments in this file.
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
pub struct ToolCallDelta {
// Role of the message author; the construction site visible in
// `impl ChatCompletionChunk` sets it to "assistant".
#[schema(example = "assistant")]
role: String,
// The tool call carried by this chunk. Note this is a single
// `DeltaToolCall`, not a list.
tool_calls: DeltaToolCall,
}
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)]
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
enum ChatCompletionDelta {
Chat(TextMessage),
Tool(ToolCallDelta),
}
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)]
pub(crate) struct DeltaToolCall {
pub index: u32,
pub id: String,
@ -616,7 +633,7 @@ pub(crate) struct DeltaToolCall {
pub function: Function,
}
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)]
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)]
pub(crate) struct Function {
pub name: Option<String>,
pub arguments: String,
@ -634,15 +651,13 @@ impl ChatCompletionChunk {
finish_reason: Option<String>,
) -> Self {
let delta = match (delta, tool_calls) {
(Some(delta), _) => ChatCompletionDelta {
role: Some("assistant".to_string()),
content: Some(delta),
tool_calls: None,
},
(None, Some(tool_calls)) => ChatCompletionDelta {
role: Some("assistant".to_string()),
content: None,
tool_calls: Some(DeltaToolCall {
(Some(delta), _) => ChatCompletionDelta::Chat(TextMessage {
role: "assistant".to_string(),
content: delta,
}),
(None, Some(tool_calls)) => ChatCompletionDelta::Tool(ToolCallDelta {
role: "assistant".to_string(),
tool_calls: DeltaToolCall {
index: 0,
id: String::new(),
r#type: "function".to_string(),
@ -650,13 +665,12 @@ impl ChatCompletionChunk {
name: None,
arguments: tool_calls[0].to_string(),
},
},
}),
(None, None) => ChatCompletionDelta::Chat(TextMessage {
role: "assistant".to_string(),
content: "".to_string(),
}),
},
(None, None) => ChatCompletionDelta {
role: None,
content: None,
tool_calls: None,
},
};
Self {
id: String::new(),
@ -982,6 +996,8 @@ impl From<Message> for TextMessage {
// Message whose payload is a list of tool calls rather than text content.
// The `ChatCompletion` construction visible in this change wraps it in
// `OutputMessage::ToolCall`.
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
pub struct ToolCallMessage {
// Role of the message author; the visible construction site sets "assistant".
#[schema(example = "assistant")]
role: String,
// Tool calls attached to this message.
tool_calls: Vec<ToolCall>,
// NOTE(review): presumably identifies the tool call this message relates to.
// The construction site visible in this change sets only `role` and
// `tool_calls`, so confirm this field is still intended and populated.
tool_call_id: String,
}
@ -1265,6 +1281,7 @@ mod tests {
},
},
]
}
]
});