diff --git a/docs/openapi.json b/docs/openapi.json
index e5c139a9..a1df080b 100644
--- a/docs/openapi.json
+++ b/docs/openapi.json
@@ -1508,26 +1508,7 @@
}
}
},
- "FunctionCall": {
- "type": "object",
- "required": [
- "name",
- "arguments"
- ],
- "properties": {
- "arguments": {
- "type": "string"
- },
- "description": {
- "type": "string",
- "nullable": true
- },
- "name": {
- "type": "string"
- }
- }
- },
- "FunctionDefinitionDeprecated": {
+ "FunctionDefinition": {
"type": "object",
"required": [
"name",
@@ -1544,23 +1525,6 @@
}
}
},
- "FunctionDefinition": {
- "type": "object",
- "required": [
- "name",
- "parameters"
- ],
- "properties": {
- "parameters": {},
- "description": {
- "type": "string",
- "nullable": true
- },
- "name": {
- "type": "string"
- }
- }
- },
"FunctionName": {
"type": "object",
"required": [
@@ -2263,14 +2227,7 @@
],
"properties": {
"function": {
- "oneOf": [
- {
- "$ref": "#/components/schemas/FunctionDefinition"
- },
- {
- "$ref": "#/components/schemas/FunctionDefinitionDeprecated"
- }
- ]
+ "$ref": "#/components/schemas/FunctionDefinition"
},
"type": {
"type": "string",
@@ -2287,7 +2244,7 @@
],
"properties": {
"function": {
- "$ref": "#/components/schemas/FunctionCall"
+ "$ref": "#/components/schemas/FunctionDefinition"
},
"id": {
"type": "string"
@@ -2413,4 +2370,4 @@
"description": "Hugging Face Text Generation Inference API"
}
]
-}
\ No newline at end of file
+}
diff --git a/docs/source/basic_tutorials/using_guidance.md b/docs/source/basic_tutorials/using_guidance.md
index 6540cb6d..e389fbbc 100644
--- a/docs/source/basic_tutorials/using_guidance.md
+++ b/docs/source/basic_tutorials/using_guidance.md
@@ -305,7 +305,7 @@ chat = client.chat_completion(
)
print(chat.choices[0].message.tool_calls)
-# [ChatCompletionOutputToolCall(function=ChatCompletionOutputFunctionCall(arguments="{\"format\": \"fahrenheit\", \"location\": \"Brooklyn, New York\", \"num_days\": 7}", name='get_n_day_weather_forecast', description=None), id=0, type='function')]
+# [ChatCompletionOutputToolCall(function=ChatCompletionOutputFunctionDefinition(arguments={'format': 'fahrenheit', 'location': 'Brooklyn, New York', 'num_days': 7}, name='get_n_day_weather_forecast', description=None), id=0, type='function')]
```
diff --git a/router/src/infer/chat_template.rs b/router/src/infer/chat_template.rs
index 36fde519..6a9289c0 100644
--- a/router/src/infer/chat_template.rs
+++ b/router/src/infer/chat_template.rs
@@ -903,44 +903,7 @@ mod tests {
let tool_prompt = "This default prompt will be used".to_string();
let tools_and_prompt = Some((tools, tool_prompt));
let result = ct.apply(msgs, tools_and_prompt);
- let expected = "[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today? [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"parameters\":{\"type\":\"object\",\"properties\":{\"location\":{\"type\":\"string\",\"description\":\"The city and state, e.g. San Francisco, CA\"},\"format\":{\"type\":\"string\",\"enum\":[\"celsius\",\"fahrenheit\"],\"description\":\"The temperature unit to use. Infer this from the users location.\"}},\"required\":[\"location\",\"format\"]}}}]\nThis default prompt will be used [/INST]".to_string();
- assert_eq!(result.unwrap(), expected);
- }
-
- #[test]
- fn test_chat_template_with_default_tool_template_arguments_deprecated() {
- let ct = ChatTemplate::new(
- "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}".to_string(),
- Some(TokenizerConfigToken::String("".to_string())),
- Some(TokenizerConfigToken::String("".to_string())),
- );
-
- // convert TextMessage to Message
- let msgs: Vec<Message> = vec![
- Message {
- name: None,
- role: "user".to_string(),
- content: MessageContent::SingleText(
- "I'd like to show off how chat templating works!".to_string(),
- ),
- },
- Message {
- name: None,
- role: "assistant".to_string(),
- content: MessageContent::SingleText("Great! How can I help you today?".to_string()),
- },
- Message {
- name: None,
- role: "user".to_string(),
- content: MessageContent::SingleText("Just testing".to_string()),
- },
- ];
- let tools_string = r#"[{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather","arguments": {"type": "object","properties": {"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"format": {"type": "string","enum": ["celsius", "fahrenheit"],"description": "The temperature unit to use. Infer this from the users location."}},"required": ["location", "format"]}}}]"#.to_string();
- let tools: Vec<Tool> = serde_json::from_str(&tools_string).unwrap();
- let tool_prompt = "This default prompt will be used".to_string();
- let tools_and_prompt = Some((tools, tool_prompt));
- let result = ct.apply(msgs, tools_and_prompt);
- let expected = "[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today? [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"parameters\":{\"type\":\"object\",\"properties\":{\"location\":{\"type\":\"string\",\"description\":\"The city and state, e.g. San Francisco, CA\"},\"format\":{\"type\":\"string\",\"enum\":[\"celsius\",\"fahrenheit\"],\"description\":\"The temperature unit to use. Infer this from the users location.\"}},\"required\":[\"location\",\"format\"]}}}]\nThis default prompt will be used [/INST]".to_string();
+ let expected = "[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today? [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"arguments\":\"{\\\"type\\\":\\\"object\\\",\\\"properties\\\":{\\\"location\\\":{\\\"type\\\":\\\"string\\\",\\\"description\\\":\\\"The city and state, e.g. San Francisco, CA\\\"},\\\"format\\\":{\\\"type\\\":\\\"string\\\",\\\"enum\\\":[\\\"celsius\\\",\\\"fahrenheit\\\"],\\\"description\\\":\\\"The temperature unit to use. Infer this from the users location.\\\"}},\\\"required\\\":[\\\"location\\\",\\\"format\\\"]}\"}}]\nThis default prompt will be used [/INST]".to_string();
assert_eq!(result.unwrap(), expected);
}
@@ -974,7 +937,7 @@ mod tests {
let tool_prompt = "This default prompt will be used".to_string();
let tools_and_prompt = Some((tools, tool_prompt));
let result = ct.apply(msgs, tools_and_prompt);
- let expected = "<|start_header_id|>system<|end_header_id|>\n\nEnvironment: ipython\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYoure a helpful assistant! Answer the users question best you can.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGiven the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n\nRespond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.Do not use variables.\n\n{\n \"function\": {\n \"description\": \"Get the current weather\",\n \"name\": \"get_current_weather\",\n \"parameters\": {\n \"properties\": {\n \"format\": {\n \"description\": \"The temperature unit to use. Infer this from the users location.\",\n \"enum\": [\n \"celsius\",\n \"fahrenheit\"\n ],\n \"type\": \"string\"\n },\n \"location\": {\n \"description\": \"The city and state, e.g. San Francisco, CA\",\n \"type\": \"string\"\n }\n },\n \"required\": [\n \"location\",\n \"format\"\n ],\n \"type\": \"object\"\n }\n },\n \"type\": \"function\"\n}\n\nWhat is the weather like in Brooklyn, New York?\n---\nThis default prompt will be used<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n".to_string();
+ let expected = "<|start_header_id|>system<|end_header_id|>\n\nEnvironment: ipython\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYoure a helpful assistant! Answer the users question best you can.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGiven the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n\nRespond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.Do not use variables.\n\n{\n \"function\": {\n \"arguments\": \"{\\\"type\\\":\\\"object\\\",\\\"properties\\\":{\\\"location\\\":{\\\"type\\\":\\\"string\\\",\\\"description\\\":\\\"The city and state, e.g. San Francisco, CA\\\"},\\\"format\\\":{\\\"type\\\":\\\"string\\\",\\\"enum\\\":[\\\"celsius\\\",\\\"fahrenheit\\\"],\\\"description\\\":\\\"The temperature unit to use. Infer this from the users location.\\\"}},\\\"required\\\":[\\\"location\\\",\\\"format\\\"]}\",\n \"description\": \"Get the current weather\",\n \"name\": \"get_current_weather\"\n },\n \"type\": \"function\"\n}\n\nWhat is the weather like in Brooklyn, New York?\n---\nThis default prompt will be used<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n".to_string();
assert_eq!(result.unwrap(), expected);
}
}
diff --git a/router/src/infer/tool_grammar.rs b/router/src/infer/tool_grammar.rs
index 6b6099f3..7770cd9d 100644
--- a/router/src/infer/tool_grammar.rs
+++ b/router/src/infer/tool_grammar.rs
@@ -38,7 +38,7 @@ impl ToolGrammar {
description: Some(
"Open ended response with no specific tool selected".to_string(),
),
- parameters: json!({
+ arguments: json!({
"type": "object",
"properties": {
"content": {
@@ -83,7 +83,7 @@ impl ToolGrammar {
}),
);
- if let Value::Object(args) = func.parameters {
+ if let Value::Object(args) = func.arguments {
if let Some(Value::Object(props)) = args.get("properties") {
properties.extend(props.clone());
}
diff --git a/router/src/lib.rs b/router/src/lib.rs
index e3e17d2c..89d7e8f7 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -742,11 +742,11 @@ pub(crate) struct DeltaToolCall {
pub index: u32,
pub id: String,
pub r#type: String,
- pub function: FunctionCallChunk,
+ pub function: Function,
}
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)]
-pub(crate) struct FunctionCallChunk {
+pub(crate) struct Function {
pub name: Option<String>,
pub arguments: String,
}
@@ -757,7 +757,7 @@ impl ChatCompletionChunk {
model: String,
system_fingerprint: String,
delta: Option<String>,
- tool_calls: Option<FunctionCallChunk>,
+ tool_calls: Option<Vec<String>>,
created: u64,
logprobs: Option<ChatCompletionLogprobs>,
finish_reason: Option<String>,
@@ -774,7 +774,10 @@ impl ChatCompletionChunk {
index: 0,
id: String::new(),
r#type: "function".to_string(),
- function: tool_calls,
+ function: Function {
+ name: None,
+ arguments: tool_calls[0].to_string(),
+ },
}],
}),
(None, None) => ChatCompletionDelta::Chat(TextMessage {
@@ -1130,14 +1133,15 @@ pub(crate) struct FunctionDefinition {
#[serde(default)]
pub description: Option<String>,
pub name: String,
- #[serde(alias = "arguments")]
- pub parameters: serde_json::Value,
+ #[serde(alias = "parameters", serialize_with = "serialize_as_string")]
+ pub arguments: serde_json::Value,
}
-#[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default, PartialEq)]
-pub(crate) struct FunctionCall {
- pub name: String,
- pub arguments: String,
+fn serialize_as_string<S>(value: &serde_json::Value, serializer: S) -> Result<S::Ok, S::Error>
+where
+ S: serde::Serializer,
+{
+ serializer.serialize_str(&value.to_string())
}
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
@@ -1163,7 +1167,7 @@ pub(crate) struct ChatTemplateInputs<'a> {
pub(crate) struct ToolCall {
pub id: String,
pub r#type: String,
- pub function: FunctionCall,
+ pub function: FunctionDefinition,
}
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
@@ -1682,19 +1686,19 @@ mod tests {
tool_calls: vec![ToolCall {
id: "0".to_string(),
r#type: "function".to_string(),
- function: FunctionCall {
+ function: FunctionDefinition {
+ description: None,
name: "myfn".to_string(),
arguments: json!({
"format": "csv"
- })
- .to_string(),
+ }),
},
}],
});
let serialized = serde_json::to_string(&message).unwrap();
assert_eq!(
serialized,
- r#"{"role":"assistant","tool_calls":[{"id":"0","type":"function","function":{"name":"myfn","arguments":"{\"format\":\"csv\"}"}}]}"#
+ r#"{"role":"assistant","tool_calls":[{"id":"0","type":"function","function":{"description":null,"name":"myfn","arguments":"{\"format\":\"csv\"}"}}]}"#
);
}
diff --git a/router/src/server.rs b/router/src/server.rs
index 73a84c3f..9e57af27 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -12,9 +12,10 @@ use crate::sagemaker::{
};
use crate::validation::ValidationError;
use crate::vertex::vertex_compatibility;
+use crate::ChatTokenizeResponse;
use crate::{
- usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionCallChunk,
- FunctionName, GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
+ usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName,
+ GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
HubProcessorConfig, HubTokenizerConfig, Info, Message, MessageChunk, MessageContent,
OutputMessage, PrefillToken, SimpleToken, StreamDetails, StreamOptions, StreamResponse,
TextMessage, Token, TokenizeResponse, Tokenizer, ToolCallDelta, ToolCallMessage, Url, Usage,
@@ -24,9 +25,8 @@ use crate::{
ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
ChatCompletionDelta, ChatCompletionLogprob, ChatCompletionLogprobs, ChatCompletionTopLogprob,
ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal,
- CompletionRequest, CompletionType, DeltaToolCall, Prompt, Tool,
+ CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool,
};
-use crate::{ChatTokenizeResponse, FunctionCall};
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice};
use crate::{ModelInfo, ModelsInfo};
use async_stream::__private::AsyncStream;
@@ -1117,7 +1117,6 @@ pub(crate) async fn completions(
enum StreamState {
Buffering,
BufferTrailing,
- Arguments,
Content { skip_close_quote: bool },
}
@@ -1127,7 +1126,6 @@ fn create_event_from_stream_token(
logprobs: bool,
stream_options: Option<StreamOptions>,
inner_using_tools: bool,
- partial_call: Option<FunctionCallChunk>,
system_fingerprint: String,
model_id: String,
) -> Event {
@@ -1143,16 +1141,7 @@ fn create_event_from_stream_token(
// replace the content with the tool calls if grammar is present
let (content, tool_calls) = if inner_using_tools {
- match partial_call {
- Some(partial_call) => (None, Some(partial_call)),
- None => (
- None,
- Some(FunctionCallChunk {
- name: None,
- arguments: stream_token.token.text.clone(),
- }),
- ),
- }
+ (None, Some(vec![stream_token.token.text.clone()]))
} else {
let content = if !stream_token.token.special {
Some(stream_token.token.text.clone())
@@ -1269,7 +1258,7 @@ pub(crate) async fn chat_completions(
generate_stream_internal(infer, compute_type, Json(generate_request), span).await;
// regex to match any function name
- let function_name_regex = match Regex::new(r#"\{"function":\{"_name":"([^"]+)","#) {
+ let function_regex = match Regex::new(r#"\{"function":\{"_name":"([^"]+)""#) {
Ok(regex) => regex,
Err(e) => {
return Err((
@@ -1284,6 +1273,7 @@ pub(crate) async fn chat_completions(
let response_stream = async_stream::stream! {
let mut response_stream = Box::pin(response_stream);
+ let mut buffer = Vec::new();
let mut json_buffer = String::new();
let mut state = if using_tools {
StreamState::Buffering
@@ -1300,27 +1290,30 @@ pub(crate) async fn chat_completions(
match state {
StreamState::Buffering => {
json_buffer.push_str(&token_text.replace(" ", ""));
- if let Some(captures) = function_name_regex.captures(&json_buffer) {
+ buffer.push(stream_token);
+ if let Some(captures) = function_regex.captures(&json_buffer) {
let function_name = captures[1].to_string();
if function_name == "no_tool" {
state = StreamState::BufferTrailing;
response_as_tool = false;
+ buffer.clear();
json_buffer.clear();
} else {
- state = StreamState::Arguments;
- let event = create_event_from_stream_token(
- &stream_token,
- logprobs,
- stream_options.clone(),
- response_as_tool,
- Some(FunctionCallChunk {
- name: Some(function_name),
- arguments: "{".to_string()
- }),
- system_fingerprint.clone(),
- model_id.clone(),
- );
- yield Ok::<Event, Infallible>(event);
+ state = StreamState::Content {
+ skip_close_quote: false,
+ };
+ // send all the buffered messages
+ for stream_token in &buffer {
+ let event = create_event_from_stream_token(
+ stream_token,
+ logprobs,
+ stream_options.clone(),
+ response_as_tool,
+ system_fingerprint.clone(),
+ model_id.clone(),
+ );
+ yield Ok::<Event, Infallible>(event);
+ }
}
}
}
@@ -1361,32 +1354,12 @@ pub(crate) async fn chat_completions(
}));
}
// cleanup the buffers
+ buffer.clear();
json_buffer.clear();
state = StreamState::Content {
skip_close_quote: true,
};
}
- StreamState::Arguments => {
- json_buffer.push_str(&token_text.replace(" ", ""));
-
- // If we are at the end of the json we can stop
- let function: Result<serde_json::Value, serde_json::Error> = serde_json::from_str(&json_buffer);
- if let Ok(_) = function {
- break;
- }
-
- // send the content
- let event = create_event_from_stream_token(
- &stream_token,
- logprobs,
- stream_options.clone(),
- response_as_tool,
- None,
- system_fingerprint.clone(),
- model_id.clone(),
- );
- yield Ok::<Event, Infallible>(event);
- }
StreamState::Content { skip_close_quote } => {
if skip_close_quote && token_text.contains('"') {
break;
@@ -1398,7 +1371,6 @@ pub(crate) async fn chat_completions(
logprobs,
stream_options.clone(),
response_as_tool,
- None,
system_fingerprint.clone(),
model_id.clone(),
);
@@ -1466,9 +1438,10 @@ pub(crate) async fn chat_completions(
let tool_calls = vec![ToolCall {
id: "0".to_string(),
r#type: "function".to_string(),
- function: FunctionCall {
+ function: FunctionDefinition {
+ description: None,
name,
- arguments: arguments.to_string(),
+ arguments,
},
}];
(Some(tool_calls), None)
@@ -1599,8 +1572,8 @@ StreamOptions,
DeltaToolCall,
Tool,
ToolCall,
+Function,
FunctionDefinition,
-FunctionCall,
ToolChoice,
ModelInfo,
ChatTokenizeResponse,