diff --git a/router/src/lib.rs b/router/src/lib.rs index e62d3a93..5120ad88 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -11,6 +11,7 @@ use queue::{Entry, Queue}; use serde::{Deserialize, Serialize}; use tokio::sync::OwnedSemaphorePermit; use tokio_stream::wrappers::UnboundedReceiverStream; +use tracing::warn; use utoipa::ToSchema; use validation::Validation; @@ -539,9 +540,25 @@ impl ChatCompletion { content: vec![MessageChunk::Text(Text { text: output })], name: None, }), - (None, Some(tool_calls)) => OutputMessage::ToolCall(ToolCallMessage { tool_calls }), - _ => { - todo!("Implement error for invalid tool vs chat"); + (None, Some(tool_calls)) => OutputMessage::ToolCall(ToolCallMessage { + role: "assistant".to_string(), + tool_calls, + }), + (Some(output), Some(_)) => { + warn!("Received both chat and tool call"); + OutputMessage::ChatMessage(Message { + role: "assistant".into(), + content: vec![MessageChunk::Text(Text { text: output })], + name: None, + }) + } + (None, None) => { + warn!("Didn't receive an answer"); + OutputMessage::ChatMessage(Message { + role: "assistant".into(), + content: vec![], + name: None, + }) } }; Self { @@ -574,6 +591,7 @@ pub(crate) struct CompletionCompleteChunk { pub model: String, pub system_fingerprint: String, } + #[derive(Clone, Deserialize, Serialize, ToSchema)] pub(crate) struct ChatCompletionChunk { pub id: String, @@ -594,21 +612,20 @@ pub(crate) struct ChatCompletionChoice { pub finish_reason: Option, } -#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)] -pub(crate) struct ChatCompletionDelta { - #[schema(example = "user")] - // TODO Modify this to a true enum. - #[serde(default, skip_serializing_if = "Option::is_none")] - pub role: Option, - #[serde(default, skip_serializing_if = "Option::is_none")] - #[schema(example = "What is Deep Learning?")] - pub content: Option, - // default to None - #[serde(default, skip_serializing_if = "Option::is_none")] - pub tool_calls: Option, +#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)] +pub struct ToolCallDelta { + #[schema(example = "assistant")] + role: String, + tool_calls: DeltaToolCall, } -#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)] +#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)] +enum ChatCompletionDelta { + Chat(TextMessage), + Tool(ToolCallDelta), +} + +#[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)] pub(crate) struct DeltaToolCall { pub index: u32, pub id: String, @@ -616,7 +633,7 @@ pub(crate) struct DeltaToolCall { pub function: Function, } -#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)] +#[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)] pub(crate) struct Function { pub name: Option, pub arguments: String, @@ -634,15 +651,13 @@ impl ChatCompletionChunk { finish_reason: Option, ) -> Self { let delta = match (delta, tool_calls) { - (Some(delta), _) => ChatCompletionDelta { - role: Some("assistant".to_string()), - content: Some(delta), - tool_calls: None, - }, - (None, Some(tool_calls)) => ChatCompletionDelta { - role: Some("assistant".to_string()), - content: None, - tool_calls: Some(DeltaToolCall { + (Some(delta), _) => ChatCompletionDelta::Chat(TextMessage { + role: "assistant".to_string(), + content: delta, + }), + (None, Some(tool_calls)) => ChatCompletionDelta::Tool(ToolCallDelta { + role: "assistant".to_string(), + tool_calls: DeltaToolCall { index: 0, id: String::new(), r#type: "function".to_string(), @@ -650,13 +665,12 @@ impl ChatCompletionChunk { name: None, arguments: tool_calls[0].to_string(), }, - }), - }, - (None, None) => ChatCompletionDelta { - role: None, - content: None, - tool_calls: None, - }, + }, + }), + (None, None) => ChatCompletionDelta::Chat(TextMessage { + role: "assistant".to_string(), + content: "".to_string(), + }), }; Self { id: String::new(), @@ -982,6 +996,8 @@ impl From for TextMessage { #[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)] pub struct ToolCallMessage { + #[schema(example = "assistant")] + role: String, tool_calls: Vec, tool_call_id: String, } @@ -1265,6 +1281,7 @@ mod tests { }, }, ] + } ] });