From af78f46c3df8cc9da71f1f788a88aaa0a91ef82a Mon Sep 17 00:00:00 2001 From: drbh Date: Thu, 13 Mar 2025 19:16:47 +0000 Subject: [PATCH] feat: align function id with tool call response --- router/src/chat.rs | 1 + router/src/infer/chat_template.rs | 20 +++++++++++------ router/src/infer/tool_grammar.rs | 1 + router/src/lib.rs | 37 ++++++++++++++++++++++++------- router/src/server.rs | 29 +++++++++++++++--------- router/src/vertex.rs | 3 +-- 6 files changed, 63 insertions(+), 28 deletions(-) diff --git a/router/src/chat.rs b/router/src/chat.rs index d5824fea0..0aeb868d3 100644 --- a/router/src/chat.rs +++ b/router/src/chat.rs @@ -49,6 +49,7 @@ pub(crate) fn parse_output(generated_text: &str) -> Result = messages.into_iter().map(|c| c.into()).collect(); let final_message = messages.last().cloned(); + let template_inputs = ChatTemplateInputs { + messages, + bos_token: self.bos_token.as_deref(), + eos_token: self.eos_token.as_deref(), + add_generation_prompt: true, + tools, + }; + + // NOTE: initalizing `template_inputs` is helpful when JSON dumping the + // `ChatTemplateInputs` struct for debugging + // let template_inputs_as_json = serde_json::to_string(&template_inputs).unwrap(); + let mut rendered_template = self .template - .render(ChatTemplateInputs { - messages, - bos_token: self.bos_token.as_deref(), - eos_token: self.eos_token.as_deref(), - add_generation_prompt: true, - tools, - }) + .render(template_inputs) .map_err(InferError::TemplateError)?; // if the last message is from the assistant, continue the generation prompt diff --git a/router/src/infer/tool_grammar.rs b/router/src/infer/tool_grammar.rs index e4e208598..168df02cc 100644 --- a/router/src/infer/tool_grammar.rs +++ b/router/src/infer/tool_grammar.rs @@ -34,6 +34,7 @@ impl ToolGrammar { .chain(std::iter::once(Tool { r#type: "function".to_string(), function: FunctionDefinition { + id: None, name: "no_tool".to_string(), description: Some( "Open ended response with no specific tool selected".to_string(), diff --git a/router/src/lib.rs b/router/src/lib.rs index e8b8f6632..8c37392c6 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -18,6 +18,7 @@ use crate::infer::{Infer, InferError}; use pyo3::prelude::*; use pyo3::types::IntoPyDict; use serde::{Deserialize, Serialize}; +use std::collections::HashMap; use tokenizers::Encoding; use tracing::warn; use utoipa::ToSchema; @@ -912,7 +913,10 @@ pub(crate) struct ChatRequest { } impl ChatRequest { - fn try_into_generate(self, infer: &Infer) -> Result<(GenerateRequest, bool), InferError> { + fn try_into_generate( + self, + infer: &Infer, + ) -> Result<(GenerateRequest, Option>), InferError> { let ChatRequest { model, max_tokens, @@ -952,7 +956,7 @@ impl ChatRequest { let (inputs, grammar, using_tools) = match response_format { Some(format) => { let inputs = infer.apply_chat_template(messages, None)?; - (inputs, Some(format), false) + (inputs, Some(format), None) } None => { if let Some(tools) = tools { @@ -961,20 +965,31 @@ impl ChatRequest { let grammar = GrammarType::Json(serde_json::json!(tool_schema)); let inputs: String = infer.apply_chat_template( messages, - Some((updated_tools, tool_prompt)), + Some((updated_tools.clone(), tool_prompt)), )?; - (inputs, Some(grammar), true) + let tool_name_to_id: HashMap = updated_tools + .into_iter() + .map(|tool| { + ( + tool.function.name, + tool.function + .id + .map_or_else(|| "0".to_string(), |id| id.to_string()), + ) + }) + .collect(); + (inputs, Some(grammar), Some(tool_name_to_id)) } None => { // same as if no response_format or tools are set let inputs = infer.apply_chat_template(messages, None)?; - (inputs, None, false) + (inputs, None, None) } } } else { // if no response_format or tools are set simply apply the chat template to generate inputs let inputs = infer.apply_chat_template(messages, None)?; - (inputs, None, false) + (inputs, None, None) } } }; @@ -1154,6 +1169,8 @@ pub struct FunctionDefinition { #[serde(default)] pub description: Option, pub name: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub id: Option, #[serde(alias = "parameters", serialize_with = "serialize_as_string")] pub arguments: serde_json::Value, } @@ -1175,7 +1192,7 @@ pub(crate) struct Tool { pub function: FunctionDefinition, } -#[derive(Clone, Serialize, Deserialize, Default)] +#[derive(Debug, Clone, Serialize, Deserialize, Default)] pub(crate) struct ChatTemplateInputs<'a> { messages: Vec, bos_token: Option<&'a str>, @@ -1208,6 +1225,9 @@ pub enum MessageChunk { pub struct Message { #[schema(example = "user")] pub role: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + #[schema(example = "10")] + pub tool_call_id: Option, #[serde(flatten)] #[schema(example = "My name is David and I")] pub body: MessageBody, @@ -1287,7 +1307,7 @@ impl From for TextMessage { .collect::>() .join(""), }, - ..Default::default() + tool_call_id: value.tool_call_id, } } } @@ -1758,6 +1778,7 @@ mod tests { id: "0".to_string(), r#type: "function".to_string(), function: FunctionDefinition { + id: None, description: None, name: "myfn".to_string(), arguments: json!({ diff --git a/router/src/server.rs b/router/src/server.rs index 45d2b9f3c..e04450b6b 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1165,8 +1165,7 @@ pub(crate) async fn chat_completions( tracing::debug!("Got chat_template {:?}", infer.chat_template); let id = chat.next_tool_call_id(); - let (generate_request, using_tools): (GenerateRequest, bool) = - chat.clone().try_into_generate(&infer)?; + let (generate_request, using_tools) = chat.clone().try_into_generate(&infer)?; span.record("parameters", format!("{:?}", generate_request.parameters)); let logprobs = logprobs.unwrap_or_default(); @@ -1188,7 +1187,7 @@ pub(crate) async fn chat_completions( let response_stream = async_stream::stream! { let mut response_stream = Box::pin(response_stream); - let mut state = ChatState::new(using_tools, stream_options.clone(), system_fingerprint.clone(), model_id.clone(), logprobs, id.clone()); + let mut state = ChatState::new(using_tools.is_some(), stream_options.clone(), system_fingerprint.clone(), model_id.clone(), logprobs, id.clone()); while let Some(result) = response_stream.next().await { match result{ Ok(stream_token) => { @@ -1197,12 +1196,12 @@ pub(crate) async fn chat_completions( ChatEvent::NoTool => { chat.tools = None; chat.response_format = None; - let (generate_request, using_tools): (GenerateRequest, bool) = + let (generate_request, using_tools) = chat.clone().try_into_generate(&infer).unwrap(); - assert!(!using_tools); + assert!(using_tools.is_none()); let (_headers, response_stream2) = generate_stream_internal(infer.clone(), compute_type.clone(), Json(generate_request), span.clone()).await; - state = ChatState::new(using_tools, stream_options.clone(), system_fingerprint.clone(), model_id.clone(), logprobs, id.clone()); + state = ChatState::new(using_tools.is_some(), stream_options.clone(), system_fingerprint.clone(), model_id.clone(), logprobs, id.clone()); response_stream = Box::pin(response_stream2); } ChatEvent::Events(events) => { @@ -1237,14 +1236,13 @@ pub(crate) async fn chat_completions( .unwrap_or_else(|_| std::time::Duration::from_secs(0)) .as_secs(); - let (tool_calls, output) = if using_tools { + let (tool_calls, output) = if using_tools.is_some() { match crate::chat::parse_output(&generation.generated_text)? { ChatChoice::NoTool => { chat.tools = None; chat.response_format = None; - let (generate_request, using_tools): (GenerateRequest, bool) = - chat.clone().try_into_generate(&infer)?; - assert!(!using_tools); + let (generate_request, using_tools) = chat.clone().try_into_generate(&infer)?; + assert!(using_tools.is_none()); let (headers_final, input_length_final, Json(generation)) = generate_internal( Extension(infer), compute_type, @@ -1256,7 +1254,16 @@ pub(crate) async fn chat_completions( input_length = input_length_final; (None, Some(generation.generated_text)) } - ChatChoice::ToolCalls(tool_calls) => (Some(tool_calls), None), + ChatChoice::ToolCalls(mut tool_calls) => { + // assign the tool ids based on the tool names + tool_calls.iter_mut().for_each(|tool_call| { + tool_call.id = using_tools + .as_ref() + .and_then(|tools| tools.get(&tool_call.function.name)) + .map_or("0".to_string(), |id| id.clone()); + }); + (Some(tool_calls), None) + } } } else { (None, Some(generation.generated_text)) diff --git a/router/src/vertex.rs b/router/src/vertex.rs index 38695532c..e2ae9721b 100644 --- a/router/src/vertex.rs +++ b/router/src/vertex.rs @@ -104,8 +104,7 @@ pub(crate) async fn vertex_compatibility( }, }, VertexInstance::Chat(instance) => { - let (generate_request, _using_tools): (GenerateRequest, bool) = - instance.try_into_generate(&infer)?; + let (generate_request, _using_tools) = instance.try_into_generate(&infer)?; generate_request } };