diff --git a/router/src/infer/chat_template.rs b/router/src/infer/chat_template.rs
index e660cc74..60f13d08 100644
--- a/router/src/infer/chat_template.rs
+++ b/router/src/infer/chat_template.rs
@@ -1189,7 +1189,7 @@ TOOL CALL ID: 0
         let tool_prompt = "This default prompt will be used".to_string();
         let tools_and_prompt = Some((tools, tool_prompt));
         let result = ct.apply(msgs, tools_and_prompt);
-        let expected = "[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today? [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"arguments\":{\"type\":\"object\",\"properties\":{\"location\":{\"type\":\"string\",\"description\":\"The city and state, e.g. San Francisco, CA\"},\"format\":{\"type\":\"string\",\"enum\":[\"celsius\",\"fahrenheit\"],\"description\":\"The temperature unit to use. Infer this from the users location.\"}},\"required\":[\"location\",\"format\"]}}}]\nThis default prompt will be used [/INST]".to_string();
+        let expected = "[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today? [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"parameters\":{\"type\":\"object\",\"properties\":{\"location\":{\"type\":\"string\",\"description\":\"The city and state, e.g. San Francisco, CA\"},\"format\":{\"type\":\"string\",\"enum\":[\"celsius\",\"fahrenheit\"],\"description\":\"The temperature unit to use. Infer this from the users location.\"}},\"required\":[\"location\",\"format\"]}}}]\nThis default prompt will be used [/INST]".to_string();
         assert_eq!(result.unwrap(), expected);
     }
 
@@ -1227,7 +1227,7 @@ TOOL CALL ID: 0
         let tool_prompt = "This default prompt will be used".to_string();
         let tools_and_prompt = Some((tools, tool_prompt));
         let result = ct.apply(msgs, tools_and_prompt);
-        let expected = "<|start_header_id|>system<|end_header_id|>\n\nEnvironment: ipython\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYoure a helpful assistant! Answer the users question best you can.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGiven the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n\nRespond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.Do not use variables.\n\n{\n \"function\": {\n \"arguments\": {\n \"properties\": {\n \"format\": {\n \"description\": \"The temperature unit to use. Infer this from the users location.\",\n \"enum\": [\n \"celsius\",\n \"fahrenheit\"\n ],\n \"type\": \"string\"\n },\n \"location\": {\n \"description\": \"The city and state, e.g. San Francisco, CA\",\n \"type\": \"string\"\n }\n },\n \"required\": [\n \"location\",\n \"format\"\n ],\n \"type\": \"object\"\n },\n \"description\": \"Get the current weather\",\n \"name\": \"get_current_weather\"\n },\n \"type\": \"function\"\n}\n\nWhat is the weather like in Brooklyn, New York?\n---\nThis default prompt will be used<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n".to_string();
+        let expected = "<|start_header_id|>system<|end_header_id|>\n\nEnvironment: ipython\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYoure a helpful assistant! Answer the users question best you can.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGiven the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n\nRespond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.Do not use variables.\n\n{\n \"function\": {\n \"description\": \"Get the current weather\",\n \"name\": \"get_current_weather\",\n \"parameters\": {\n \"properties\": {\n \"format\": {\n \"description\": \"The temperature unit to use. Infer this from the users location.\",\n \"enum\": [\n \"celsius\",\n \"fahrenheit\"\n ],\n \"type\": \"string\"\n },\n \"location\": {\n \"description\": \"The city and state, e.g. San Francisco, CA\",\n \"type\": \"string\"\n }\n },\n \"required\": [\n \"location\",\n \"format\"\n ],\n \"type\": \"object\"\n }\n },\n \"type\": \"function\"\n}\n\nWhat is the weather like in Brooklyn, New York?\n---\nThis default prompt will be used<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n".to_string();
         assert_eq!(result.unwrap(), expected);
     }
 }
diff --git a/router/src/infer/tool_grammar.rs b/router/src/infer/tool_grammar.rs
index 7770cd9d..6b6099f3 100644
--- a/router/src/infer/tool_grammar.rs
+++ b/router/src/infer/tool_grammar.rs
@@ -38,7 +38,7 @@ impl ToolGrammar {
                     description: Some(
                         "Open ended response with no specific tool selected".to_string(),
                     ),
-                    arguments: json!({
+                    parameters: json!({
                         "type": "object",
                         "properties": {
                             "content": {
@@ -83,7 +83,7 @@ impl ToolGrammar {
                 }),
             );
 
-            if let Value::Object(args) = func.arguments {
+            if let Value::Object(args) = func.parameters {
                 if let Some(Value::Object(props)) = args.get("properties") {
                     properties.extend(props.clone());
                 }
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 089e30df..0d828843 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -730,7 +730,7 @@ pub(crate) struct ChatCompletionChoice {
 pub struct ToolCallDelta {
     #[schema(example = "assistant")]
     role: String,
-    tool_calls: DeltaToolCall,
+    tool_calls: Vec<DeltaToolCall>,
 }
 
 #[derive(Clone, Debug, Serialize, ToSchema)]
@@ -745,11 +745,11 @@ pub(crate) struct DeltaToolCall {
     pub index: u32,
     pub id: String,
     pub r#type: String,
-    pub function: Function,
+    pub function: FunctionCallChunk,
 }
 
 #[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)]
-pub(crate) struct Function {
+pub(crate) struct FunctionCallChunk {
     pub name: Option<String>,
     pub arguments: String,
 }
@@ -760,7 +760,7 @@ impl ChatCompletionChunk {
         model: String,
         system_fingerprint: String,
         delta: Option<String>,
-        tool_calls: Option<Vec<String>>,
+        tool_calls: Option<FunctionCallChunk>,
        created: u64,
         logprobs: Option<ChatCompletionLogprobs>,
         finish_reason: Option<String>,
@@ -774,15 +774,12 @@ impl ChatCompletionChunk {
             }),
             (None, Some(tool_calls)) => ChatCompletionDelta::Tool(ToolCallDelta {
                 role: "assistant".to_string(),
-                tool_calls: DeltaToolCall {
+                tool_calls: vec![DeltaToolCall {
                     index: 0,
                     id: String::new(),
                     r#type: "function".to_string(),
-                    function: Function {
-                        name: None,
-                        arguments: tool_calls[0].to_string(),
-                    },
-                },
+                    function: tool_calls,
+                }],
             }),
             (None, None) => ChatCompletionDelta::Chat(TextMessage {
                 role: "assistant".to_string(),
@@ -1138,16 +1135,12 @@ pub struct FunctionDefinition {
     #[serde(default)]
     pub description: Option<String>,
     pub name: String,
-    #[serde(alias = "parameters")]
-    pub arguments: serde_json::Value,
+    pub parameters: serde_json::Value,
 }
 
 #[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default, PartialEq)]
 pub(crate) struct FunctionCall {
-    #[serde(default)]
-    pub description: Option<String>,
     pub name: String,
-    #[serde(alias = "parameters")]
     pub arguments: String,
 }
 
@@ -1728,7 +1721,6 @@ mod tests {
                 id: "0".to_string(),
                 r#type: "function".to_string(),
                 function: FunctionCall {
-                    description: None,
                     name: "myfn".to_string(),
                     arguments: json!({
                         "format": "csv"
@@ -1740,7 +1732,7 @@ mod tests {
         let serialized = serde_json::to_string(&message).unwrap();
         assert_eq!(
             serialized,
-            r#"{"role":"assistant","tool_calls":[{"id":"0","type":"function","function":{"description":null,"name":"myfn","arguments":"{\"format\":\"csv\"}"}}]}"#
+            r#"{"role":"assistant","tool_calls":[{"id":"0","type":"function","function":{"name":"myfn","arguments":"{\"format\":\"csv\"}"}}]}"#
         );
     }
 
diff --git a/router/src/server.rs b/router/src/server.rs
index 71e8c663..e267a951 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -13,8 +13,8 @@ use crate::sagemaker::{
 use crate::validation::ValidationError;
 use crate::vertex::vertex_compatibility;
 use crate::{
-    usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName,
-    GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
+    usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionCallChunk,
+    FunctionName, GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
     HubProcessorConfig, HubTokenizerConfig, Info, Message, MessageChunk, MessageContent,
     OutputMessage, PrefillToken, SimpleToken, StreamDetails, StreamOptions, StreamResponse,
     TextMessage, Token, TokenizeResponse, Tokenizer, ToolCallDelta, ToolCallMessage, Url, Usage,
@@ -24,7 +24,7 @@ use crate::{
     ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
     ChatCompletionDelta, ChatCompletionLogprob, ChatCompletionLogprobs, ChatCompletionTopLogprob,
     ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal,
-    CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool,
+    CompletionRequest, CompletionType, DeltaToolCall, Prompt, Tool,
 };
 use crate::{ChatTokenizeResponse, FunctionCall};
 use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice};
@@ -1117,6 +1117,7 @@ pub(crate) async fn completions(
 enum StreamState {
     Buffering,
     BufferTrailing,
+    Arguments,
     Content { skip_close_quote: bool },
 }
 
@@ -1126,6 +1127,7 @@ fn create_event_from_stream_token(
     logprobs: bool,
     stream_options: Option<StreamOptions>,
     inner_using_tools: bool,
+    partial_call: Option<FunctionCallChunk>,
     system_fingerprint: String,
     model_id: String,
 ) -> Event {
@@ -1141,7 +1143,16 @@ fn create_event_from_stream_token(
 
     // replace the content with the tool calls if grammar is present
     let (content, tool_calls) = if inner_using_tools {
-        (None, Some(vec![stream_token.token.text.clone()]))
+        match partial_call {
+            Some(partial_call) => (None, Some(partial_call)),
+            None => (
+                None,
+                Some(FunctionCallChunk {
+                    name: None,
+                    arguments: stream_token.token.text.clone(),
+                }),
+            ),
+        }
     } else {
         let content = if !stream_token.token.special {
             Some(stream_token.token.text.clone())
@@ -1258,7 +1269,7 @@ pub(crate) async fn chat_completions(
             generate_stream_internal(infer, compute_type, Json(generate_request), span).await;
 
         // regex to match any function name
-        let function_regex = match Regex::new(r#"\{"function":\{"_name":"([^"]+)""#) {
+        let function_name_regex = match Regex::new(r#"\{"function":\{"_name":"([^"]+)","#) {
             Ok(regex) => regex,
             Err(e) => {
                 return Err((
@@ -1273,7 +1284,6 @@ pub(crate) async fn chat_completions(
 
         let response_stream = async_stream::stream! {
             let mut response_stream = Box::pin(response_stream);
-            let mut buffer = Vec::new();
             let mut json_buffer = String::new();
             let mut state = if using_tools {
                 StreamState::Buffering
@@ -1290,30 +1300,27 @@ pub(crate) async fn chat_completions(
                         match state {
                             StreamState::Buffering => {
                                 json_buffer.push_str(&token_text.replace(" ", ""));
-                                buffer.push(stream_token);
-                                if let Some(captures) = function_regex.captures(&json_buffer) {
+                                if let Some(captures) = function_name_regex.captures(&json_buffer) {
                                     let function_name = captures[1].to_string();
                                     if function_name == "no_tool" {
                                         state = StreamState::BufferTrailing;
                                         response_as_tool = false;
-                                        buffer.clear();
                                         json_buffer.clear();
                                     } else {
-                                        state = StreamState::Content {
-                                            skip_close_quote: false,
-                                        };
-                                        // send all the buffered messages
-                                        for stream_token in &buffer {
-                                            let event = create_event_from_stream_token(
-                                                stream_token,
-                                                logprobs,
-                                                stream_options.clone(),
-                                                response_as_tool,
-                                                system_fingerprint.clone(),
-                                                model_id.clone(),
-                                            );
-                                            yield Ok::<Event, Infallible>(event);
-                                        }
+                                        state = StreamState::Arguments;
+                                        let event = create_event_from_stream_token(
+                                            &stream_token,
+                                            logprobs,
+                                            stream_options.clone(),
+                                            response_as_tool,
+                                            Some(FunctionCallChunk {
+                                                name: Some(function_name),
+                                                arguments: "{".to_string()
+                                            }),
+                                            system_fingerprint.clone(),
+                                            model_id.clone(),
+                                        );
+                                        yield Ok::<Event, Infallible>(event);
                                     }
                                 }
                             }
@@ -1354,12 +1361,32 @@ pub(crate) async fn chat_completions(
                                     }));
                                 }
                                 // cleanup the buffers
-                                buffer.clear();
                                 json_buffer.clear();
                                 state = StreamState::Content {
                                     skip_close_quote: true,
                                 };
                             }
+                            StreamState::Arguments => {
+                                json_buffer.push_str(&token_text.replace(" ", ""));
+
+                                // If we are at the end of the json we can stop
+                                let function: Result<serde_json::Value, _> = serde_json::from_str(&json_buffer);
+                                if let Ok(_) = function {
+                                    break;
+                                }
+
+                                // send the content
+                                let event = create_event_from_stream_token(
+                                    &stream_token,
+                                    logprobs,
+                                    stream_options.clone(),
+                                    response_as_tool,
+                                    None,
+                                    system_fingerprint.clone(),
+                                    model_id.clone(),
+                                );
+                                yield Ok::<Event, Infallible>(event);
+                            }
                             StreamState::Content { skip_close_quote } => {
                                 if skip_close_quote && token_text.contains('"') {
                                     break;
                                 }
@@ -1371,6 +1398,7 @@ pub(crate) async fn chat_completions(
                                     logprobs,
                                     stream_options.clone(),
                                     response_as_tool,
+                                    None,
                                     system_fingerprint.clone(),
                                     model_id.clone(),
                                 );
@@ -1439,7 +1467,6 @@ pub(crate) async fn chat_completions(
                 id: "0".to_string(),
                 r#type: "function".to_string(),
                 function: FunctionCall {
-                    description: None,
                     name,
                     arguments: arguments.to_string(),
                 },
@@ -1572,7 +1599,6 @@ StreamOptions,
 DeltaToolCall,
 Tool,
 ToolCall,
-Function,
 FunctionDefinition,
 FunctionCall,
 ToolChoice,
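Not part of the patch, for illustration only: a minimal, self-contained sketch of the delta shape the renamed types serialize to, using simplified local copies of ToolCallDelta, DeltaToolCall, and FunctionCallChunk from router/src/lib.rs above (the real structs are pub(crate) and carry utoipa/schema annotations). The function name comes from the test fixtures in this diff; the argument fragment in the second chunk is a made-up example of a raw token.

use serde::Serialize;

// Simplified copies of the router structs touched by this diff.
#[derive(Serialize)]
struct ToolCallDelta {
    role: String,
    tool_calls: Vec<DeltaToolCall>,
}

#[derive(Serialize)]
struct DeltaToolCall {
    index: u32,
    id: String,
    r#type: String,
    function: FunctionCallChunk,
}

#[derive(Serialize)]
struct FunctionCallChunk {
    name: Option<String>,
    arguments: String,
}

fn main() {
    // First streamed chunk: carries the detected function name and the opening brace,
    // mirroring the Some(FunctionCallChunk { name, arguments: "{" }) branch above.
    let first = ToolCallDelta {
        role: "assistant".to_string(),
        tool_calls: vec![DeltaToolCall {
            index: 0,
            id: String::new(),
            r#type: "function".to_string(),
            function: FunctionCallChunk {
                name: Some("get_current_weather".to_string()),
                arguments: "{".to_string(),
            },
        }],
    };
    // Later chunks: name is None and `arguments` holds the raw token text
    // (a hypothetical fragment is shown here).
    let later = ToolCallDelta {
        role: "assistant".to_string(),
        tool_calls: vec![DeltaToolCall {
            index: 0,
            id: String::new(),
            r#type: "function".to_string(),
            function: FunctionCallChunk {
                name: None,
                arguments: "\"location\":".to_string(),
            },
        }],
    };
    println!("{}", serde_json::to_string(&first).unwrap());
    println!("{}", serde_json::to_string(&later).unwrap());
}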