fix: Adapt function call response to return a json string for arguments

Necessary to keep compatibility with openai. The usage of tgi with openai compatible libraries for function calling was broken.
This commit is contained in:
Nicolas Casademont 2025-01-24 11:47:01 +01:00 committed by drbh
parent 5eec3a8bb6
commit 9a9a763eee
4 changed files with 40 additions and 10 deletions

View File

@ -1508,6 +1508,25 @@
} }
} }
}, },
"FunctionCall": {
"type": "object",
"required": [
"name",
"arguments"
],
"properties": {
"arguments": {
"type": "string"
},
"description": {
"type": "string",
"nullable": true
},
"name": {
"type": "string"
}
}
},
"FunctionDefinition": { "FunctionDefinition": {
"type": "object", "type": "object",
"required": [ "required": [
@ -2280,7 +2299,7 @@
], ],
"properties": { "properties": {
"function": { "function": {
"$ref": "#/components/schemas/FunctionDefinition" "$ref": "#/components/schemas/FunctionCall"
}, },
"id": { "id": {
"type": "string" "type": "string"
@ -2406,4 +2425,4 @@
"description": "Hugging Face Text Generation Inference API" "description": "Hugging Face Text Generation Inference API"
} }
] ]
} }

View File

@ -305,7 +305,7 @@ chat = client.chat_completion(
) )
print(chat.choices[0].message.tool_calls) print(chat.choices[0].message.tool_calls)
# [ChatCompletionOutputToolCall(function=ChatCompletionOutputFunctionDefinition(arguments={'format': 'fahrenheit', 'location': 'Brooklyn, New York', 'num_days': 7}, name='get_n_day_weather_forecast', description=None), id=0, type='function')] # [ChatCompletionOutputToolCall(function=ChatCompletionOutputFunctionCall(arguments="{\"format\": \"fahrenheit\", \"location\": \"Brooklyn, New York\", \"num_days\": 7}", name='get_n_day_weather_forecast', description=None), id=0, type='function')]
``` ```

View File

@ -1142,6 +1142,15 @@ pub struct FunctionDefinition {
pub arguments: serde_json::Value, pub arguments: serde_json::Value,
} }
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default, PartialEq)]
pub(crate) struct FunctionCall {
#[serde(default)]
pub description: Option<String>,
pub name: String,
#[serde(alias = "parameters")]
pub arguments: String,
}
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)] #[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
#[cfg_attr(test, derive(PartialEq))] #[cfg_attr(test, derive(PartialEq))]
pub(crate) struct Tool { pub(crate) struct Tool {
@ -1165,7 +1174,7 @@ pub(crate) struct ChatTemplateInputs<'a> {
pub struct ToolCall { pub struct ToolCall {
pub id: String, pub id: String,
pub r#type: String, pub r#type: String,
pub function: FunctionDefinition, pub function: FunctionCall,
} }
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)] #[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
@ -1718,19 +1727,20 @@ mod tests {
tool_calls: vec![ToolCall { tool_calls: vec![ToolCall {
id: "0".to_string(), id: "0".to_string(),
r#type: "function".to_string(), r#type: "function".to_string(),
function: FunctionDefinition { function: FunctionCall {
description: None, description: None,
name: "myfn".to_string(), name: "myfn".to_string(),
arguments: json!({ arguments: json!({
"format": "csv" "format": "csv"
}), })
.to_string(),
}, },
}], }],
}); });
let serialized = serde_json::to_string(&message).unwrap(); let serialized = serde_json::to_string(&message).unwrap();
assert_eq!( assert_eq!(
serialized, serialized,
r#"{"role":"assistant","tool_calls":[{"id":"0","type":"function","function":{"description":null,"name":"myfn","arguments":{"format":"csv"}}}]}"# r#"{"role":"assistant","tool_calls":[{"id":"0","type":"function","function":{"description":null,"name":"myfn","arguments":"{\"format\":\"csv\"}"}}]}"#
); );
} }

View File

@ -12,7 +12,6 @@ use crate::sagemaker::{
}; };
use crate::validation::ValidationError; use crate::validation::ValidationError;
use crate::vertex::vertex_compatibility; use crate::vertex::vertex_compatibility;
use crate::ChatTokenizeResponse;
use crate::{ use crate::{
usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName, usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName,
GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo, GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
@ -27,6 +26,7 @@ use crate::{
ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal, ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal,
CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool,
}; };
use crate::{ChatTokenizeResponse, FunctionCall};
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice}; use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice};
use crate::{MessageBody, ModelInfo, ModelsInfo}; use crate::{MessageBody, ModelInfo, ModelsInfo};
use async_stream::__private::AsyncStream; use async_stream::__private::AsyncStream;
@ -1438,10 +1438,10 @@ pub(crate) async fn chat_completions(
let tool_calls = vec![ToolCall { let tool_calls = vec![ToolCall {
id: "0".to_string(), id: "0".to_string(),
r#type: "function".to_string(), r#type: "function".to_string(),
function: FunctionDefinition { function: FunctionCall {
description: None, description: None,
name, name,
arguments, arguments: arguments.to_string(),
}, },
}]; }];
(Some(tool_calls), None) (Some(tool_calls), None)
@ -1574,6 +1574,7 @@ Tool,
ToolCall, ToolCall,
Function, Function,
FunctionDefinition, FunctionDefinition,
FunctionCall,
ToolChoice, ToolChoice,
ModelInfo, ModelInfo,
ChatTokenizeResponse, ChatTokenizeResponse,