fix: prefer tool call vector over object

This commit is contained in:
drbh 2024-03-11 13:26:06 +01:00
parent 0d72af5ab0
commit 7ad4a62458
2 changed files with 7 additions and 7 deletions

View File

@ -433,7 +433,7 @@ impl ChatCompletion {
created: u64, created: u64,
details: Details, details: Details,
return_logprobs: bool, return_logprobs: bool,
tool_calls: Option<ToolCall>, tool_calls: Option<Vec<ChatCompletionMessageToolCall>>,
) -> Self { ) -> Self {
Self { Self {
id: String::new(), id: String::new(),
@ -764,7 +764,7 @@ pub(crate) struct ChatTemplateInputs<'a> {
} }
#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)] #[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)]
pub(crate) struct ToolCall { pub(crate) struct ChatCompletionMessageToolCall {
pub id: u32, pub id: u32,
pub r#type: String, pub r#type: String,
pub function: FunctionDefinition, pub function: FunctionDefinition,
@ -781,7 +781,7 @@ pub(crate) struct Message {
#[schema(example = "\"David\"")] #[schema(example = "\"David\"")]
pub name: Option<String>, pub name: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<ToolCall>, pub tool_calls: Option<Vec<ChatCompletionMessageToolCall>>,
} }
#[derive(Clone, Debug, Deserialize, ToSchema)] #[derive(Clone, Debug, Deserialize, ToSchema)]

View File

@ -14,7 +14,7 @@ use crate::{
ChatRequest, CompatGenerateRequest, Completion, CompletionComplete, CompletionCompleteChunk, ChatRequest, CompatGenerateRequest, Completion, CompletionComplete, CompletionCompleteChunk,
CompletionRequest, VertexRequest, VertexResponse, CompletionRequest, VertexRequest, VertexResponse,
}; };
use crate::{FunctionDefinition, FunctionRef, FunctionsMap, Properties, ToolCall, ToolType, Tools}; use crate::{FunctionDefinition, FunctionRef, FunctionsMap, Properties, ChatCompletionMessageToolCall, ToolType, Tools};
use axum::extract::Extension; use axum::extract::Extension;
use axum::http::{HeaderMap, Method, StatusCode}; use axum::http::{HeaderMap, Method, StatusCode};
use axum::response::sse::{Event, KeepAlive, Sse}; use axum::response::sse::{Event, KeepAlive, Sse};
@ -942,7 +942,7 @@ async fn chat_completions(
) )
})?; })?;
let tool_call = Some(ToolCall { let tool_calls = vec![ChatCompletionMessageToolCall {
id: 0, id: 0,
r#type: "function".to_string(), r#type: "function".to_string(),
function: FunctionDefinition { function: FunctionDefinition {
@ -963,8 +963,8 @@ async fn chat_completions(
|f| Ok(f.clone()), |f| Ok(f.clone()),
)?, )?,
}, },
}); }];
(tool_call, None) (Some(tool_calls), None)
} else { } else {
(None, Some(generation.generated_text)) (None, Some(generation.generated_text))
}; };