diff --git a/docs/openapi.json b/docs/openapi.json index 9de76e47..0a40e8bc 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -1508,6 +1508,25 @@ } } }, + "FunctionCall": { + "type": "object", + "required": [ + "name", + "arguments" + ], + "properties": { + "arguments": { + "type": "string" + }, + "description": { + "type": "string", + "nullable": true + }, + "name": { + "type": "string" + } + } + }, "FunctionDefinition": { "type": "object", "required": [ @@ -2280,7 +2299,7 @@ ], "properties": { "function": { - "$ref": "#/components/schemas/FunctionDefinition" + "$ref": "#/components/schemas/FunctionCall" }, "id": { "type": "string" @@ -2406,4 +2425,4 @@ "description": "Hugging Face Text Generation Inference API" } ] -} +} \ No newline at end of file diff --git a/docs/source/basic_tutorials/using_guidance.md b/docs/source/basic_tutorials/using_guidance.md index e389fbbc..6540cb6d 100644 --- a/docs/source/basic_tutorials/using_guidance.md +++ b/docs/source/basic_tutorials/using_guidance.md @@ -305,7 +305,7 @@ chat = client.chat_completion( ) print(chat.choices[0].message.tool_calls) -# [ChatCompletionOutputToolCall(function=ChatCompletionOutputFunctionDefinition(arguments={'format': 'fahrenheit', 'location': 'Brooklyn, New York', 'num_days': 7}, name='get_n_day_weather_forecast', description=None), id=0, type='function')] +# [ChatCompletionOutputToolCall(function=ChatCompletionOutputFunctionCall(arguments="{\"format\": \"fahrenheit\", \"location\": \"Brooklyn, New York\", \"num_days\": 7}", name='get_n_day_weather_forecast', description=None), id=0, type='function')] ``` diff --git a/router/src/lib.rs b/router/src/lib.rs index e8c875a8..089e30df 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -1142,6 +1142,15 @@ pub struct FunctionDefinition { pub arguments: serde_json::Value, } +#[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default, PartialEq)] +pub(crate) struct FunctionCall { + #[serde(default)] + pub description: Option, + pub name: String, + #[serde(alias = "parameters")] + pub arguments: String, +} + #[derive(Clone, Debug, Deserialize, Serialize, ToSchema)] #[cfg_attr(test, derive(PartialEq))] pub(crate) struct Tool { @@ -1165,7 +1174,7 @@ pub(crate) struct ChatTemplateInputs<'a> { pub struct ToolCall { pub id: String, pub r#type: String, - pub function: FunctionDefinition, + pub function: FunctionCall, } #[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)] @@ -1718,19 +1727,20 @@ mod tests { tool_calls: vec![ToolCall { id: "0".to_string(), r#type: "function".to_string(), - function: FunctionDefinition { + function: FunctionCall { description: None, name: "myfn".to_string(), arguments: json!({ "format": "csv" - }), + }) + .to_string(), }, }], }); let serialized = serde_json::to_string(&message).unwrap(); assert_eq!( serialized, - r#"{"role":"assistant","tool_calls":[{"id":"0","type":"function","function":{"description":null,"name":"myfn","arguments":{"format":"csv"}}}]}"# + r#"{"role":"assistant","tool_calls":[{"id":"0","type":"function","function":{"description":null,"name":"myfn","arguments":"{\"format\":\"csv\"}"}}]}"# ); } diff --git a/router/src/server.rs b/router/src/server.rs index e9aa4612..71e8c663 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -12,7 +12,6 @@ use crate::sagemaker::{ }; use crate::validation::ValidationError; use crate::vertex::vertex_compatibility; -use crate::ChatTokenizeResponse; use crate::{ usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName, GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo, @@ -27,6 +26,7 @@ use crate::{ ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal, CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, }; +use crate::{ChatTokenizeResponse, FunctionCall}; use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice}; use crate::{MessageBody, ModelInfo, ModelsInfo}; use async_stream::__private::AsyncStream; @@ -1438,10 +1438,10 @@ pub(crate) async fn chat_completions( let tool_calls = vec![ToolCall { id: "0".to_string(), r#type: "function".to_string(), - function: FunctionDefinition { + function: FunctionCall { description: None, name, - arguments, + arguments: arguments.to_string(), }, }]; (Some(tool_calls), None) @@ -1574,6 +1574,7 @@ Tool, ToolCall, Function, FunctionDefinition, +FunctionCall, ToolChoice, ModelInfo, ChatTokenizeResponse,