From c1c4dfb521f762c9d923d316604acf51e0178ccf Mon Sep 17 00:00:00 2001 From: Nicolas Casademont Date: Tue, 4 Feb 2025 11:07:42 +0100 Subject: [PATCH] fix: Allow back arguments in function definition and the corresponding test --- docs/openapi.json | 13 ++++++++--- router/src/infer/chat_template.rs | 37 +++++++++++++++++++++++++++++++ router/src/lib.rs | 1 + 3 files changed, 48 insertions(+), 3 deletions(-) diff --git a/docs/openapi.json b/docs/openapi.json index 0a40e8bc..46556865 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -1531,10 +1531,10 @@ "type": "object", "required": [ "name", - "arguments" + "parameters" ], "properties": { - "arguments": {}, + "parameters": {}, "description": { "type": "string", "nullable": true @@ -2282,7 +2282,14 @@ ], "properties": { "function": { - "$ref": "#/components/schemas/FunctionDefinition" + "oneOf": [ + { + "$ref": "#/components/schemas/FunctionDefinition" + }, + { + "$ref": "#/components/schemas/FunctionCall" + } + ] }, "type": { "type": "string", diff --git a/router/src/infer/chat_template.rs b/router/src/infer/chat_template.rs index 60f13d08..56a8616c 100644 --- a/router/src/infer/chat_template.rs +++ b/router/src/infer/chat_template.rs @@ -1193,6 +1193,43 @@ TOOL CALL ID: 0 assert_eq!(result.unwrap(), expected); } + #[test] + fn test_chat_template_with_default_tool_template_arguments_deprecated() { + let ct = ChatTemplate::new( + "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}".to_string(), + Some(TokenizerConfigToken::String("".to_string())), + Some(TokenizerConfigToken::String("".to_string())), + ); + + // convert TextMessage to Message + let msgs: Vec = vec![ + Message { + name: None, + role: "user".to_string(), + content: MessageContent::SingleText( + "I'd like to show off how chat templating works!".to_string(), + ), + }, + Message { + name: None, + role: "assistant".to_string(), + content: MessageContent::SingleText("Great! How can I help you today?".to_string()), + }, + Message { + name: None, + role: "user".to_string(), + content: MessageContent::SingleText("Just testing".to_string()), + }, + ]; + let tools_string = r#"[{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather","arguments": {"type": "object","properties": {"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"format": {"type": "string","enum": ["celsius", "fahrenheit"],"description": "The temperature unit to use. Infer this from the users location."}},"required": ["location", "format"]}}}]"#.to_string(); + let tools: Vec = serde_json::from_str(&tools_string).unwrap(); + let tool_prompt = "This default prompt will be used".to_string(); + let tools_and_prompt = Some((tools, tool_prompt)); + let result = ct.apply(msgs, tools_and_prompt); + let expected = "[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today? [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"parameters\":{\"type\":\"object\",\"properties\":{\"location\":{\"type\":\"string\",\"description\":\"The city and state, e.g. San Francisco, CA\"},\"format\":{\"type\":\"string\",\"enum\":[\"celsius\",\"fahrenheit\"],\"description\":\"The temperature unit to use. Infer this from the users location.\"}},\"required\":[\"location\",\"format\"]}}}]\nThis default prompt will be used [/INST]".to_string(); + assert_eq!(result.unwrap(), expected); + } + #[test] fn test_chat_template_with_custom_tool_template() { // chat template from meta-llama/Meta-Llama-3.1-8B-Instruct diff --git a/router/src/lib.rs b/router/src/lib.rs index 0d828843..1932b06b 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -1135,6 +1135,7 @@ pub struct FunctionDefinition { #[serde(default)] pub description: Option, pub name: String, + #[serde(alias = "arguments")] pub parameters: serde_json::Value, }