From 8542e2b7465b7cc83ac2269014340a409617a226 Mon Sep 17 00:00:00 2001 From: Nicolas Casademont Date: Fri, 24 Jan 2025 14:42:25 +0100 Subject: [PATCH] feat: Make streaming for tool calling behave the same as the OpenAI API The streaming API for tool calling now starts when the name is parsed, then sends arguments as tokens are generated, and stops properly. --- router/src/infer/chat_template.rs | 4 +- router/src/infer/tool_grammar.rs | 4 +- router/src/lib.rs | 26 ++++------ router/src/server.rs | 80 ++++++++++++++++++++----------- 4 files changed, 66 insertions(+), 48 deletions(-) diff --git a/router/src/infer/chat_template.rs b/router/src/infer/chat_template.rs index e660cc74..60f13d08 100644 --- a/router/src/infer/chat_template.rs +++ b/router/src/infer/chat_template.rs @@ -1189,7 +1189,7 @@ TOOL CALL ID: 0 let tool_prompt = "This default prompt will be used".to_string(); let tools_and_prompt = Some((tools, tool_prompt)); let result = ct.apply(msgs, tools_and_prompt); - let expected = "[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today? [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"arguments\":{\"type\":\"object\",\"properties\":{\"location\":{\"type\":\"string\",\"description\":\"The city and state, e.g. San Francisco, CA\"},\"format\":{\"type\":\"string\",\"enum\":[\"celsius\",\"fahrenheit\"],\"description\":\"The temperature unit to use. Infer this from the users location.\"}},\"required\":[\"location\",\"format\"]}}}]\nThis default prompt will be used [/INST]".to_string(); + let expected = "[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today? 
[INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"parameters\":{\"type\":\"object\",\"properties\":{\"location\":{\"type\":\"string\",\"description\":\"The city and state, e.g. San Francisco, CA\"},\"format\":{\"type\":\"string\",\"enum\":[\"celsius\",\"fahrenheit\"],\"description\":\"The temperature unit to use. Infer this from the users location.\"}},\"required\":[\"location\",\"format\"]}}}]\nThis default prompt will be used [/INST]".to_string(); assert_eq!(result.unwrap(), expected); } @@ -1227,7 +1227,7 @@ TOOL CALL ID: 0 let tool_prompt = "This default prompt will be used".to_string(); let tools_and_prompt = Some((tools, tool_prompt)); let result = ct.apply(msgs, tools_and_prompt); - let expected = "<|start_header_id|>system<|end_header_id|>\n\nEnvironment: ipython\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYoure a helpful assistant! Answer the users question best you can.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGiven the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n\nRespond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.Do not use variables.\n\n{\n \"function\": {\n \"arguments\": {\n \"properties\": {\n \"format\": {\n \"description\": \"The temperature unit to use. Infer this from the users location.\",\n \"enum\": [\n \"celsius\",\n \"fahrenheit\"\n ],\n \"type\": \"string\"\n },\n \"location\": {\n \"description\": \"The city and state, e.g. 
San Francisco, CA\",\n \"type\": \"string\"\n }\n },\n \"required\": [\n \"location\",\n \"format\"\n ],\n \"type\": \"object\"\n },\n \"description\": \"Get the current weather\",\n \"name\": \"get_current_weather\"\n },\n \"type\": \"function\"\n}\n\nWhat is the weather like in Brooklyn, New York?\n---\nThis default prompt will be used<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n".to_string(); + let expected = "<|start_header_id|>system<|end_header_id|>\n\nEnvironment: ipython\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYoure a helpful assistant! Answer the users question best you can.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGiven the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n\nRespond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.Do not use variables.\n\n{\n \"function\": {\n \"description\": \"Get the current weather\",\n \"name\": \"get_current_weather\",\n \"parameters\": {\n \"properties\": {\n \"format\": {\n \"description\": \"The temperature unit to use. Infer this from the users location.\",\n \"enum\": [\n \"celsius\",\n \"fahrenheit\"\n ],\n \"type\": \"string\"\n },\n \"location\": {\n \"description\": \"The city and state, e.g. 
San Francisco, CA\",\n \"type\": \"string\"\n }\n },\n \"required\": [\n \"location\",\n \"format\"\n ],\n \"type\": \"object\"\n }\n },\n \"type\": \"function\"\n}\n\nWhat is the weather like in Brooklyn, New York?\n---\nThis default prompt will be used<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n".to_string(); assert_eq!(result.unwrap(), expected); } } diff --git a/router/src/infer/tool_grammar.rs b/router/src/infer/tool_grammar.rs index 7770cd9d..6b6099f3 100644 --- a/router/src/infer/tool_grammar.rs +++ b/router/src/infer/tool_grammar.rs @@ -38,7 +38,7 @@ impl ToolGrammar { description: Some( "Open ended response with no specific tool selected".to_string(), ), - arguments: json!({ + parameters: json!({ "type": "object", "properties": { "content": { @@ -83,7 +83,7 @@ impl ToolGrammar { }), ); - if let Value::Object(args) = func.arguments { + if let Value::Object(args) = func.parameters { if let Some(Value::Object(props)) = args.get("properties") { properties.extend(props.clone()); } diff --git a/router/src/lib.rs b/router/src/lib.rs index 089e30df..0d828843 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -730,7 +730,7 @@ pub(crate) struct ChatCompletionChoice { pub struct ToolCallDelta { #[schema(example = "assistant")] role: String, - tool_calls: DeltaToolCall, + tool_calls: Vec, } #[derive(Clone, Debug, Serialize, ToSchema)] @@ -745,11 +745,11 @@ pub(crate) struct DeltaToolCall { pub index: u32, pub id: String, pub r#type: String, - pub function: Function, + pub function: FunctionCallChunk, } #[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)] -pub(crate) struct Function { +pub(crate) struct FunctionCallChunk { pub name: Option, pub arguments: String, } @@ -760,7 +760,7 @@ impl ChatCompletionChunk { model: String, system_fingerprint: String, delta: Option, - tool_calls: Option>, + tool_calls: Option, created: u64, logprobs: Option, finish_reason: Option, @@ -774,15 +774,12 @@ impl ChatCompletionChunk { }), (None, 
Some(tool_calls)) => ChatCompletionDelta::Tool(ToolCallDelta { role: "assistant".to_string(), - tool_calls: DeltaToolCall { + tool_calls: vec![DeltaToolCall { index: 0, id: String::new(), r#type: "function".to_string(), - function: Function { - name: None, - arguments: tool_calls[0].to_string(), - }, - }, + function: tool_calls, + }], }), (None, None) => ChatCompletionDelta::Chat(TextMessage { role: "assistant".to_string(), @@ -1138,16 +1135,12 @@ pub struct FunctionDefinition { #[serde(default)] pub description: Option, pub name: String, - #[serde(alias = "parameters")] - pub arguments: serde_json::Value, + pub parameters: serde_json::Value, } #[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default, PartialEq)] pub(crate) struct FunctionCall { - #[serde(default)] - pub description: Option, pub name: String, - #[serde(alias = "parameters")] pub arguments: String, } @@ -1728,7 +1721,6 @@ mod tests { id: "0".to_string(), r#type: "function".to_string(), function: FunctionCall { - description: None, name: "myfn".to_string(), arguments: json!({ "format": "csv" @@ -1740,7 +1732,7 @@ mod tests { let serialized = serde_json::to_string(&message).unwrap(); assert_eq!( serialized, - r#"{"role":"assistant","tool_calls":[{"id":"0","type":"function","function":{"description":null,"name":"myfn","arguments":"{\"format\":\"csv\"}"}}]}"# + r#"{"role":"assistant","tool_calls":[{"id":"0","type":"function","function":{"name":"myfn","arguments":"{\"format\":\"csv\"}"}}]}"# ); } diff --git a/router/src/server.rs b/router/src/server.rs index 71e8c663..e267a951 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -13,8 +13,8 @@ use crate::sagemaker::{ use crate::validation::ValidationError; use crate::vertex::vertex_compatibility; use crate::{ - usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName, - GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo, + usage_stats, BestOfSequence, Details, ErrorResponse, 
FinishReason, FunctionCallChunk, + FunctionName, GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo, HubProcessorConfig, HubTokenizerConfig, Info, Message, MessageChunk, MessageContent, OutputMessage, PrefillToken, SimpleToken, StreamDetails, StreamOptions, StreamResponse, TextMessage, Token, TokenizeResponse, Tokenizer, ToolCallDelta, ToolCallMessage, Url, Usage, @@ -24,7 +24,7 @@ use crate::{ ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete, ChatCompletionDelta, ChatCompletionLogprob, ChatCompletionLogprobs, ChatCompletionTopLogprob, ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal, - CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, + CompletionRequest, CompletionType, DeltaToolCall, Prompt, Tool, }; use crate::{ChatTokenizeResponse, FunctionCall}; use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice}; @@ -1117,6 +1117,7 @@ pub(crate) async fn completions( enum StreamState { Buffering, BufferTrailing, + Arguments, Content { skip_close_quote: bool }, } @@ -1126,6 +1127,7 @@ fn create_event_from_stream_token( logprobs: bool, stream_options: Option, inner_using_tools: bool, + partial_call: Option, system_fingerprint: String, model_id: String, ) -> Event { @@ -1141,7 +1143,16 @@ fn create_event_from_stream_token( // replace the content with the tool calls if grammar is present let (content, tool_calls) = if inner_using_tools { - (None, Some(vec![stream_token.token.text.clone()])) + match partial_call { + Some(partial_call) => (None, Some(partial_call)), + None => ( + None, + Some(FunctionCallChunk { + name: None, + arguments: stream_token.token.text.clone(), + }), + ), + } } else { let content = if !stream_token.token.special { Some(stream_token.token.text.clone()) @@ -1258,7 +1269,7 @@ pub(crate) async fn chat_completions( generate_stream_internal(infer, compute_type, Json(generate_request), span).await; // regex to 
match any function name - let function_regex = match Regex::new(r#"\{"function":\{"_name":"([^"]+)""#) { + let function_name_regex = match Regex::new(r#"\{"function":\{"_name":"([^"]+)","#) { Ok(regex) => regex, Err(e) => { return Err(( @@ -1273,7 +1284,6 @@ pub(crate) async fn chat_completions( let response_stream = async_stream::stream! { let mut response_stream = Box::pin(response_stream); - let mut buffer = Vec::new(); let mut json_buffer = String::new(); let mut state = if using_tools { StreamState::Buffering @@ -1290,30 +1300,27 @@ pub(crate) async fn chat_completions( match state { StreamState::Buffering => { json_buffer.push_str(&token_text.replace(" ", "")); - buffer.push(stream_token); - if let Some(captures) = function_regex.captures(&json_buffer) { + if let Some(captures) = function_name_regex.captures(&json_buffer) { let function_name = captures[1].to_string(); if function_name == "no_tool" { state = StreamState::BufferTrailing; response_as_tool = false; - buffer.clear(); json_buffer.clear(); } else { - state = StreamState::Content { - skip_close_quote: false, - }; - // send all the buffered messages - for stream_token in &buffer { - let event = create_event_from_stream_token( - stream_token, - logprobs, - stream_options.clone(), - response_as_tool, - system_fingerprint.clone(), - model_id.clone(), - ); - yield Ok::(event); - } + state = StreamState::Arguments; + let event = create_event_from_stream_token( + &stream_token, + logprobs, + stream_options.clone(), + response_as_tool, + Some(FunctionCallChunk { + name: Some(function_name), + arguments: "{".to_string() + }), + system_fingerprint.clone(), + model_id.clone(), + ); + yield Ok::(event); } } } @@ -1354,12 +1361,32 @@ pub(crate) async fn chat_completions( })); } // cleanup the buffers - buffer.clear(); json_buffer.clear(); state = StreamState::Content { skip_close_quote: true, }; } + StreamState::Arguments => { + json_buffer.push_str(&token_text.replace(" ", "")); + + // If we are at the end of 
the json we can stop + let function: Result = serde_json::from_str(&json_buffer); + if let Ok(_) = function { + break; + } + + // send the content + let event = create_event_from_stream_token( + &stream_token, + logprobs, + stream_options.clone(), + response_as_tool, + None, + system_fingerprint.clone(), + model_id.clone(), + ); + yield Ok::(event); + } StreamState::Content { skip_close_quote } => { if skip_close_quote && token_text.contains('"') { break; @@ -1371,6 +1398,7 @@ pub(crate) async fn chat_completions( logprobs, stream_options.clone(), response_as_tool, + None, system_fingerprint.clone(), model_id.clone(), ); @@ -1439,7 +1467,6 @@ pub(crate) async fn chat_completions( id: "0".to_string(), r#type: "function".to_string(), function: FunctionCall { - description: None, name, arguments: arguments.to_string(), }, @@ -1572,7 +1599,6 @@ StreamOptions, DeltaToolCall, Tool, ToolCall, -Function, FunctionDefinition, FunctionCall, ToolChoice,