fix: Adapt function call response to return a json string for arguments

Necessary to keep compatibility with openai. The usage of tgi with openai compatible libraries for function calling was broken.
This commit is contained in:
Nicolas Casademont 2025-01-24 11:47:01 +01:00 committed by drbh
parent 5eec3a8bb6
commit 9a9a763eee
4 changed files with 40 additions and 10 deletions

View File

@ -1508,6 +1508,25 @@
}
}
},
"FunctionCall": {
"type": "object",
"required": [
"name",
"arguments"
],
"properties": {
"arguments": {
"type": "string"
},
"description": {
"type": "string",
"nullable": true
},
"name": {
"type": "string"
}
}
},
"FunctionDefinition": {
"type": "object",
"required": [
@ -2280,7 +2299,7 @@
],
"properties": {
"function": {
"$ref": "#/components/schemas/FunctionDefinition"
"$ref": "#/components/schemas/FunctionCall"
},
"id": {
"type": "string"
@ -2406,4 +2425,4 @@
"description": "Hugging Face Text Generation Inference API"
}
]
}
}

View File

@ -305,7 +305,7 @@ chat = client.chat_completion(
)
print(chat.choices[0].message.tool_calls)
# [ChatCompletionOutputToolCall(function=ChatCompletionOutputFunctionDefinition(arguments={'format': 'fahrenheit', 'location': 'Brooklyn, New York', 'num_days': 7}, name='get_n_day_weather_forecast', description=None), id=0, type='function')]
# [ChatCompletionOutputToolCall(function=ChatCompletionOutputFunctionCall(arguments="{\"format\": \"fahrenheit\", \"location\": \"Brooklyn, New York\", \"num_days\": 7}", name='get_n_day_weather_forecast', description=None), id=0, type='function')]
```

View File

@ -1142,6 +1142,15 @@ pub struct FunctionDefinition {
pub arguments: serde_json::Value,
}
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default, PartialEq)]
pub(crate) struct FunctionCall {
#[serde(default)]
pub description: Option<String>,
pub name: String,
#[serde(alias = "parameters")]
pub arguments: String,
}
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
#[cfg_attr(test, derive(PartialEq))]
pub(crate) struct Tool {
@ -1165,7 +1174,7 @@ pub(crate) struct ChatTemplateInputs<'a> {
pub struct ToolCall {
pub id: String,
pub r#type: String,
pub function: FunctionDefinition,
pub function: FunctionCall,
}
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
@ -1718,19 +1727,20 @@ mod tests {
tool_calls: vec![ToolCall {
id: "0".to_string(),
r#type: "function".to_string(),
function: FunctionDefinition {
function: FunctionCall {
description: None,
name: "myfn".to_string(),
arguments: json!({
"format": "csv"
}),
})
.to_string(),
},
}],
});
let serialized = serde_json::to_string(&message).unwrap();
assert_eq!(
serialized,
r#"{"role":"assistant","tool_calls":[{"id":"0","type":"function","function":{"description":null,"name":"myfn","arguments":{"format":"csv"}}}]}"#
r#"{"role":"assistant","tool_calls":[{"id":"0","type":"function","function":{"description":null,"name":"myfn","arguments":"{\"format\":\"csv\"}"}}]}"#
);
}

View File

@ -12,7 +12,6 @@ use crate::sagemaker::{
};
use crate::validation::ValidationError;
use crate::vertex::vertex_compatibility;
use crate::ChatTokenizeResponse;
use crate::{
usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName,
GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
@ -27,6 +26,7 @@ use crate::{
ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal,
CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool,
};
use crate::{ChatTokenizeResponse, FunctionCall};
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice};
use crate::{MessageBody, ModelInfo, ModelsInfo};
use async_stream::__private::AsyncStream;
@ -1438,10 +1438,10 @@ pub(crate) async fn chat_completions(
let tool_calls = vec![ToolCall {
id: "0".to_string(),
r#type: "function".to_string(),
function: FunctionDefinition {
function: FunctionCall {
description: None,
name,
arguments,
arguments: arguments.to_string(),
},
}];
(Some(tool_calls), None)
@ -1574,6 +1574,7 @@ Tool,
ToolCall,
Function,
FunctionDefinition,
FunctionCall,
ToolChoice,
ModelInfo,
ChatTokenizeResponse,