diff --git a/router/src/infer/chat_template.rs b/router/src/infer/chat_template.rs
index e660cc74..60f13d08 100644
--- a/router/src/infer/chat_template.rs
+++ b/router/src/infer/chat_template.rs
@@ -1189,7 +1189,7 @@ TOOL CALL ID: 0
let tool_prompt = "This default prompt will be used".to_string();
let tools_and_prompt = Some((tools, tool_prompt));
let result = ct.apply(msgs, tools_and_prompt);
- let expected = "[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today? [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"arguments\":{\"type\":\"object\",\"properties\":{\"location\":{\"type\":\"string\",\"description\":\"The city and state, e.g. San Francisco, CA\"},\"format\":{\"type\":\"string\",\"enum\":[\"celsius\",\"fahrenheit\"],\"description\":\"The temperature unit to use. Infer this from the users location.\"}},\"required\":[\"location\",\"format\"]}}}]\nThis default prompt will be used [/INST]".to_string();
+ let expected = "[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today? [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"parameters\":{\"type\":\"object\",\"properties\":{\"location\":{\"type\":\"string\",\"description\":\"The city and state, e.g. San Francisco, CA\"},\"format\":{\"type\":\"string\",\"enum\":[\"celsius\",\"fahrenheit\"],\"description\":\"The temperature unit to use. Infer this from the users location.\"}},\"required\":[\"location\",\"format\"]}}}]\nThis default prompt will be used [/INST]".to_string();
assert_eq!(result.unwrap(), expected);
}
@@ -1227,7 +1227,7 @@ TOOL CALL ID: 0
let tool_prompt = "This default prompt will be used".to_string();
let tools_and_prompt = Some((tools, tool_prompt));
let result = ct.apply(msgs, tools_and_prompt);
- let expected = "<|start_header_id|>system<|end_header_id|>\n\nEnvironment: ipython\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYoure a helpful assistant! Answer the users question best you can.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGiven the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n\nRespond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.Do not use variables.\n\n{\n \"function\": {\n \"arguments\": {\n \"properties\": {\n \"format\": {\n \"description\": \"The temperature unit to use. Infer this from the users location.\",\n \"enum\": [\n \"celsius\",\n \"fahrenheit\"\n ],\n \"type\": \"string\"\n },\n \"location\": {\n \"description\": \"The city and state, e.g. San Francisco, CA\",\n \"type\": \"string\"\n }\n },\n \"required\": [\n \"location\",\n \"format\"\n ],\n \"type\": \"object\"\n },\n \"description\": \"Get the current weather\",\n \"name\": \"get_current_weather\"\n },\n \"type\": \"function\"\n}\n\nWhat is the weather like in Brooklyn, New York?\n---\nThis default prompt will be used<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n".to_string();
+ let expected = "<|start_header_id|>system<|end_header_id|>\n\nEnvironment: ipython\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYoure a helpful assistant! Answer the users question best you can.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGiven the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n\nRespond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.Do not use variables.\n\n{\n \"function\": {\n \"description\": \"Get the current weather\",\n \"name\": \"get_current_weather\",\n \"parameters\": {\n \"properties\": {\n \"format\": {\n \"description\": \"The temperature unit to use. Infer this from the users location.\",\n \"enum\": [\n \"celsius\",\n \"fahrenheit\"\n ],\n \"type\": \"string\"\n },\n \"location\": {\n \"description\": \"The city and state, e.g. San Francisco, CA\",\n \"type\": \"string\"\n }\n },\n \"required\": [\n \"location\",\n \"format\"\n ],\n \"type\": \"object\"\n }\n },\n \"type\": \"function\"\n}\n\nWhat is the weather like in Brooklyn, New York?\n---\nThis default prompt will be used<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n".to_string();
assert_eq!(result.unwrap(), expected);
}
}
diff --git a/router/src/infer/tool_grammar.rs b/router/src/infer/tool_grammar.rs
index 7770cd9d..6b6099f3 100644
--- a/router/src/infer/tool_grammar.rs
+++ b/router/src/infer/tool_grammar.rs
@@ -38,7 +38,7 @@ impl ToolGrammar {
description: Some(
"Open ended response with no specific tool selected".to_string(),
),
- arguments: json!({
+ parameters: json!({
"type": "object",
"properties": {
"content": {
@@ -83,7 +83,7 @@ impl ToolGrammar {
}),
);
- if let Value::Object(args) = func.arguments {
+ if let Value::Object(args) = func.parameters {
if let Some(Value::Object(props)) = args.get("properties") {
properties.extend(props.clone());
}
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 089e30df..0d828843 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -730,7 +730,7 @@ pub(crate) struct ChatCompletionChoice {
pub struct ToolCallDelta {
#[schema(example = "assistant")]
role: String,
- tool_calls: DeltaToolCall,
+ tool_calls: Vec<DeltaToolCall>,
}
#[derive(Clone, Debug, Serialize, ToSchema)]
@@ -745,11 +745,11 @@ pub(crate) struct DeltaToolCall {
pub index: u32,
pub id: String,
pub r#type: String,
- pub function: Function,
+ pub function: FunctionCallChunk,
}
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)]
-pub(crate) struct Function {
+pub(crate) struct FunctionCallChunk {
pub name: Option<String>,
pub arguments: String,
}
@@ -760,7 +760,7 @@ impl ChatCompletionChunk {
model: String,
system_fingerprint: String,
delta: Option<String>,
- tool_calls: Option<Vec<String>>,
+ tool_calls: Option<FunctionCallChunk>,
created: u64,
logprobs: Option<ChatCompletionLogprobs>,
finish_reason: Option<String>,
@@ -774,15 +774,12 @@ impl ChatCompletionChunk {
}),
(None, Some(tool_calls)) => ChatCompletionDelta::Tool(ToolCallDelta {
role: "assistant".to_string(),
- tool_calls: DeltaToolCall {
+ tool_calls: vec![DeltaToolCall {
index: 0,
id: String::new(),
r#type: "function".to_string(),
- function: Function {
- name: None,
- arguments: tool_calls[0].to_string(),
- },
- },
+ function: tool_calls,
+ }],
}),
(None, None) => ChatCompletionDelta::Chat(TextMessage {
role: "assistant".to_string(),
@@ -1138,16 +1135,12 @@ pub struct FunctionDefinition {
#[serde(default)]
pub description: Option<String>,
pub name: String,
- #[serde(alias = "parameters")]
- pub arguments: serde_json::Value,
+ pub parameters: serde_json::Value,
}
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default, PartialEq)]
pub(crate) struct FunctionCall {
- #[serde(default)]
- pub description: Option<String>,
pub name: String,
- #[serde(alias = "parameters")]
pub arguments: String,
}
@@ -1728,7 +1721,6 @@ mod tests {
id: "0".to_string(),
r#type: "function".to_string(),
function: FunctionCall {
- description: None,
name: "myfn".to_string(),
arguments: json!({
"format": "csv"
@@ -1740,7 +1732,7 @@ mod tests {
let serialized = serde_json::to_string(&message).unwrap();
assert_eq!(
serialized,
- r#"{"role":"assistant","tool_calls":[{"id":"0","type":"function","function":{"description":null,"name":"myfn","arguments":"{\"format\":\"csv\"}"}}]}"#
+ r#"{"role":"assistant","tool_calls":[{"id":"0","type":"function","function":{"name":"myfn","arguments":"{\"format\":\"csv\"}"}}]}"#
);
}
diff --git a/router/src/server.rs b/router/src/server.rs
index 71e8c663..e267a951 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -13,8 +13,8 @@ use crate::sagemaker::{
use crate::validation::ValidationError;
use crate::vertex::vertex_compatibility;
use crate::{
- usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName,
- GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
+ usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionCallChunk,
+ FunctionName, GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
HubProcessorConfig, HubTokenizerConfig, Info, Message, MessageChunk, MessageContent,
OutputMessage, PrefillToken, SimpleToken, StreamDetails, StreamOptions, StreamResponse,
TextMessage, Token, TokenizeResponse, Tokenizer, ToolCallDelta, ToolCallMessage, Url, Usage,
@@ -24,7 +24,7 @@ use crate::{
ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
ChatCompletionDelta, ChatCompletionLogprob, ChatCompletionLogprobs, ChatCompletionTopLogprob,
ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal,
- CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool,
+ CompletionRequest, CompletionType, DeltaToolCall, Prompt, Tool,
};
use crate::{ChatTokenizeResponse, FunctionCall};
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice};
@@ -1117,6 +1117,7 @@ pub(crate) async fn completions(
enum StreamState {
Buffering,
BufferTrailing,
+ Arguments,
Content { skip_close_quote: bool },
}
@@ -1126,6 +1127,7 @@ fn create_event_from_stream_token(
logprobs: bool,
stream_options: Option<StreamOptions>,
inner_using_tools: bool,
+ partial_call: Option<FunctionCallChunk>,
system_fingerprint: String,
model_id: String,
) -> Event {
@@ -1141,7 +1143,16 @@ fn create_event_from_stream_token(
// replace the content with the tool calls if grammar is present
let (content, tool_calls) = if inner_using_tools {
- (None, Some(vec![stream_token.token.text.clone()]))
+ match partial_call {
+ Some(partial_call) => (None, Some(partial_call)),
+ None => (
+ None,
+ Some(FunctionCallChunk {
+ name: None,
+ arguments: stream_token.token.text.clone(),
+ }),
+ ),
+ }
} else {
let content = if !stream_token.token.special {
Some(stream_token.token.text.clone())
@@ -1258,7 +1269,7 @@ pub(crate) async fn chat_completions(
generate_stream_internal(infer, compute_type, Json(generate_request), span).await;
// regex to match any function name
- let function_regex = match Regex::new(r#"\{"function":\{"_name":"([^"]+)""#) {
+ let function_name_regex = match Regex::new(r#"\{"function":\{"_name":"([^"]+)","#) {
Ok(regex) => regex,
Err(e) => {
return Err((
@@ -1273,7 +1284,6 @@ pub(crate) async fn chat_completions(
let response_stream = async_stream::stream! {
let mut response_stream = Box::pin(response_stream);
- let mut buffer = Vec::new();
let mut json_buffer = String::new();
let mut state = if using_tools {
StreamState::Buffering
@@ -1290,30 +1300,27 @@ pub(crate) async fn chat_completions(
match state {
StreamState::Buffering => {
json_buffer.push_str(&token_text.replace(" ", ""));
- buffer.push(stream_token);
- if let Some(captures) = function_regex.captures(&json_buffer) {
+ if let Some(captures) = function_name_regex.captures(&json_buffer) {
let function_name = captures[1].to_string();
if function_name == "no_tool" {
state = StreamState::BufferTrailing;
response_as_tool = false;
- buffer.clear();
json_buffer.clear();
} else {
- state = StreamState::Content {
- skip_close_quote: false,
- };
- // send all the buffered messages
- for stream_token in &buffer {
- let event = create_event_from_stream_token(
- stream_token,
- logprobs,
- stream_options.clone(),
- response_as_tool,
- system_fingerprint.clone(),
- model_id.clone(),
- );
- yield Ok::<Event, Infallible>(event);
- }
+ state = StreamState::Arguments;
+ let event = create_event_from_stream_token(
+ &stream_token,
+ logprobs,
+ stream_options.clone(),
+ response_as_tool,
+ Some(FunctionCallChunk {
+ name: Some(function_name),
+ arguments: "{".to_string()
+ }),
+ system_fingerprint.clone(),
+ model_id.clone(),
+ );
+ yield Ok::<Event, Infallible>(event);
}
}
}
@@ -1354,12 +1361,32 @@ pub(crate) async fn chat_completions(
}));
}
// cleanup the buffers
- buffer.clear();
json_buffer.clear();
state = StreamState::Content {
skip_close_quote: true,
};
}
+ StreamState::Arguments => {
+ json_buffer.push_str(&token_text.replace(" ", ""));
+
+ // If we are at the end of the json we can stop
+ let function: Result<serde_json::Value, _> = serde_json::from_str(&json_buffer);
+ if let Ok(_) = function {
+ break;
+ }
+
+ // send the content
+ let event = create_event_from_stream_token(
+ &stream_token,
+ logprobs,
+ stream_options.clone(),
+ response_as_tool,
+ None,
+ system_fingerprint.clone(),
+ model_id.clone(),
+ );
+ yield Ok::<Event, Infallible>(event);
+ }
StreamState::Content { skip_close_quote } => {
if skip_close_quote && token_text.contains('"') {
break;
@@ -1371,6 +1398,7 @@ pub(crate) async fn chat_completions(
logprobs,
stream_options.clone(),
response_as_tool,
+ None,
system_fingerprint.clone(),
model_id.clone(),
);
@@ -1439,7 +1467,6 @@ pub(crate) async fn chat_completions(
id: "0".to_string(),
r#type: "function".to_string(),
function: FunctionCall {
- description: None,
name,
arguments: arguments.to_string(),
},
@@ -1572,7 +1599,6 @@ StreamOptions,
DeltaToolCall,
Tool,
ToolCall,
-Function,
FunctionDefinition,
FunctionCall,
ToolChoice,