diff --git a/router/src/infer/chat_template.rs b/router/src/infer/chat_template.rs
index 63bc8c1b..4a8141b4 100644
--- a/router/src/infer/chat_template.rs
+++ b/router/src/infer/chat_template.rs
@@ -65,11 +65,14 @@ impl ChatTemplate {
         let (tools, tool_prompt) = tools_and_prompt.unwrap_or_default();
-        if tools.is_some() {
+        if let Some(ref tools) = tools {
             // check if the `tools` variable is used in the template
             // if not, we need to append the tools to the last message
             let text = if self.use_default_tool_template {
-                format!("\n---\n{:?}\n{}", tools, tool_prompt)
+                match serde_json::to_string(tools) {
+                    Ok(tools_str) => format!("\n---\n{}\n{}", tools_str, tool_prompt),
+                    Err(e) => return Err(InferError::ToolError(e.to_string())),
+                }
             } else {
                 // if the `tools` variable is used in the template, we just append the tool_prompt
                 format!("\n---\n{}", tool_prompt)
@@ -81,18 +84,17 @@ impl ChatTemplate {
         let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
 
-        return self
-            .template
+        self.template
             .render(ChatTemplateInputs {
                 guideline,
                 messages,
                 bos_token: self.bos_token.as_deref(),
                 eos_token: self.eos_token.as_deref(),
                 add_generation_prompt: true,
-                tools: tools,
+                tools,
                 tools_prompt: None,
             })
-            .map_err(InferError::TemplateError);
+            .map_err(InferError::TemplateError)
     }
 }
@@ -102,7 +104,8 @@ mod tests {
     use crate::infer::chat_template::raise_exception;
     use crate::infer::ChatTemplate;
     use crate::{
-        ChatTemplateInputs, GrammarType, Message, MessageContent, TextMessage, TokenizerConfigToken,
+        ChatTemplateInputs, GrammarType, Message, MessageContent, TextMessage,
+        TokenizerConfigToken, Tool,
     };
     use minijinja::Environment;
 
@@ -861,11 +864,12 @@ mod tests {
                 content: MessageContent::SingleText("Just testing".to_string()),
             },
         ];
-        let tools = serde_json::json!("[]");
+        let tools_string = r#"[{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather","parameters": {"type": "object","properties": {"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"format": {"type": "string","enum": ["celsius", "fahrenheit"],"description": "The temperature unit to use. Infer this from the users location."}},"required": ["location", "format"]}}}]"#.to_string();
+        let tools: Vec<Tool> = serde_json::from_str(&tools_string).unwrap();
         let tool_prompt = "This default prompt will be used".to_string();
-        let grammer_with_prompt = (GrammarType::Json(tools), tool_prompt);
-        let result = ct.apply(None, msgs, Some(grammer_with_prompt));
-        let expected = "[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today? [INST] Just testing\n---\nThis default prompt will be used\n\"[]\" [/INST]".to_string();
+        let tools_and_prompt = Some((Some(tools), tool_prompt));
+        let result = ct.apply(None, msgs, tools_and_prompt);
+        let expected = "[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today? [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"arguments\":{\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}}}]\nThis default prompt will be used [/INST]".to_string();
         assert_eq!(result.unwrap(), expected);
     }
 }
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 48aaf682..fbf23631 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -1207,6 +1207,7 @@ pub(crate) struct GenerateResponse {
 pub(crate) struct ChatTokenizeResponse {
     pub(crate) tokenize_response: TokenizeResponse,
     pub(crate) templated_text: String,
+    pub(crate) using_tools: bool,
 }
 
 #[derive(Serialize, ToSchema)]
diff --git a/router/src/server.rs b/router/src/server.rs
index 27f0287a..791165eb 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -146,7 +146,7 @@ async fn get_chat_tokenize(
     } = req;
 
     let tool_prompt = tool_prompt.unwrap_or_default();
-    let (inputs, _grammar, _tool_grammar) = prepare_chat_input(
+    let (inputs, _grammar, using_tools) = prepare_chat_input(
         &infer,
         response_format,
         tools,
@@ -206,6 +206,7 @@ async fn get_chat_tokenize(
         let resp = ChatTokenizeResponse {
             tokenize_response: TokenizeResponse(tokens),
             templated_text: input,
+            using_tools,
         };
         Ok((HeaderMap::new(), Json(resp)))
     } else {
@@ -1165,7 +1166,7 @@ async fn chat_completions(
         Some(temperature) if temperature == 0.0 => (false, None),
         other => (true, other),
     };
-    let (inputs, grammar, tool_grammar) = prepare_chat_input(
+    let (inputs, grammar, using_tools) = prepare_chat_input(
         &infer,
         response_format,
         tools,
@@ -1221,7 +1222,7 @@ async fn chat_completions(
                 });
 
                 // replace the content with the tool calls if grammar is present
-                let (content, tool_calls) = if tool_grammar.is_some() {
+                let (content, tool_calls) = if using_tools {
                     (None, Some(vec![stream_token.token.text]))
                 } else {
                     let content = if !stream_token.token.special {
@@ -1275,7 +1276,7 @@ async fn chat_completions(
             .unwrap_or_else(|_| std::time::Duration::from_secs(0))
             .as_secs();
 
-        let (tool_calls, output) = if tool_grammar.is_some() {
+        let (tool_calls, output) = if using_tools {
            let gen_text_value: Value =
                serde_json::from_str(&generation.generated_text).map_err(|e| {
                    InferError::ToolError(format!(
@@ -2539,7 +2540,7 @@ fn create_post_processor(
     Ok(post_processor)
 }
 
-type PreparedInput = (String, Option<GrammarType>, Option<Tools>);
+type PreparedInput = (String, Option<GrammarType>, bool);
 
 fn prepare_chat_input(
     infer: &Infer,
@@ -2558,7 +2559,7 @@ fn prepare_chat_input(
 
     if let Some(format) = response_format {
         let inputs = infer.apply_chat_template(guideline, messages, None)?;
-        return Ok((inputs, Some(format), None));
+        return Ok((inputs, Some(format), false));
     }
 
     // if tools are set, apply the tool grammar and then the chat template
@@ -2568,5 +2569,119 @@ fn prepare_chat_input(
         .map(|t| GrammarType::Json(serde_json::json!(t)));
     let tools_and_prompt: (Option<Vec<Tool>>, String) = (tools, tool_prompt.into());
     let inputs = infer.apply_chat_template(guideline, messages, Some(tools_and_prompt))?;
-    Ok((inputs, grammar, tool_grammar))
+    Ok((inputs, grammar, tool_grammar.is_some()))
+}
+
+#[cfg(test)]
+mod tests {
+    use std::collections::HashMap;
+
+    use super::*;
+    use crate::ChatTemplateVersions;
+    use crate::FunctionsMap;
+    use crate::HubTokenizerConfig;
+    use crate::Properties;
+    use crate::TokenizerConfigToken;
+    use crate::Tool;
+    use crate::Tools;
+
+    use serde_json::json;
+
+    #[test]
+    fn test_prepare_chat_input() {
+        // Mock Backend to avoid network requests
+        struct MockBackend;
+
+        impl Backend for MockBackend {
+            fn schedule(
+                &self,
+                request: crate::validation::ValidGenerateRequest,
+            ) -> Result<
+                tokio_stream::wrappers::UnboundedReceiverStream<
+                    Result<InferStreamResponse, InferError>,
+                >,
+                InferError,
+            > {
+                unimplemented!()
+            }
+            fn health<'life0, 'async_trait>(
+                &'life0 self,
+                current_health: bool,
+            ) -> core::pin::Pin<
+                Box<dyn core::future::Future<Output = bool> + core::marker::Send + 'async_trait>,
+            >
+            where
+                'life0: 'async_trait,
+                Self: 'async_trait,
+            {
+                unimplemented!()
+            }
+        }
+
+        let backend = MockBackend {};
+
+        let mut tokenizer_config = HubTokenizerConfig::default();
+
+        // mock tokenizer config values
+        tokenizer_config.bos_token = Some(TokenizerConfigToken::String("".to_string()));
+        tokenizer_config.eos_token = Some(TokenizerConfigToken::String("".to_string()));
+        tokenizer_config.chat_template = Some(
+            ChatTemplateVersions::Single("{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- \"[AVAILABLE_TOOLS] [\" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- '{\"type\": \"function\", \"function\": {' }}\n {%- for key, val in tool.items() if key != \"return\" %}\n {%- if val is string %}\n {{- '\"' + key + '\": \"' + val + '\"' }}\n {%- else %}\n {{- '\"' + key + '\": ' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- \"[/AVAILABLE_TOOLS]\" }}\n {%- endif %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST] \" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n {%- else %}\n {{- \"[INST] \" + message[\"content\"] + \"[/INST]\" }}\n {%- endif %}\n {%- elif message.tool_calls is defined and message.tool_calls is not none %}\n {{- \"[TOOL_CALLS] [\" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- \" \" + message[\"content\"]|trim + eos_token}}\n {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- '[TOOL_RESULTS] {\"content\": ' + content|string + \", \" }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}\n".to_string())
+        );
+
+        let infer = Infer::new(
+            backend,
+            Validation::new(1, None, None, None, 1, 1, 1, 1, 1, false),
+            1,
+            tokenizer_config,
+            HubProcessorConfig::default(),
+        );
+        let response_format = None;
+        let tools = Some(vec![Tool {
+            r#type: "function".to_string(),
+            function: FunctionDefinition {
+                name: "get_current_weather".to_string(),
+                description: Some("Get the current weather".to_string()),
+                arguments: json!({
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "The city and state, e.g. San Francisco, CA"
+                        },
+                        "format": {
+                            "type": "string",
+                            "enum": ["celsius", "fahrenheit"],
+                            "description": "The temperature unit to use. Infer this from the users location."
+                        }
+                    },
+                    "required": ["location", "format"]
+                }),
+            },
+        }]);
+        let tool_prompt = "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.";
+        let guideline = None;
+        let messages = vec![Message {
+            name: None,
+            role: "user".to_string(),
+            content: MessageContent::SingleText(
+                "What is the weather like in New York?".to_string(),
+            ),
+        }];
+
+        let result = prepare_chat_input(
+            &infer,
+            response_format,
+            tools,
+            ToolChoice(None),
+            tool_prompt,
+            guideline,
+            messages,
+        );
+
+        assert!(result.is_ok());
+        let (inputs, grammar, using_tools) = result.unwrap();
+        assert_eq!(using_tools, true);
+        assert_eq!(inputs, "[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string());
+    }
+}