feat: serialize function definition with serialize_as_string

This commit is contained in:
drbh 2025-02-07 22:27:24 +00:00
parent 7d852cde78
commit aad1901aa5
6 changed files with 58 additions and 161 deletions

View File

@ -1508,26 +1508,7 @@
} }
} }
}, },
"FunctionCall": { "FunctionDefinition": {
"type": "object",
"required": [
"name",
"arguments"
],
"properties": {
"arguments": {
"type": "string"
},
"description": {
"type": "string",
"nullable": true
},
"name": {
"type": "string"
}
}
},
"FunctionDefinitionDeprecated": {
"type": "object", "type": "object",
"required": [ "required": [
"name", "name",
@ -1544,23 +1525,6 @@
} }
} }
}, },
"FunctionDefinition": {
"type": "object",
"required": [
"name",
"parameters"
],
"properties": {
"parameters": {},
"description": {
"type": "string",
"nullable": true
},
"name": {
"type": "string"
}
}
},
"FunctionName": { "FunctionName": {
"type": "object", "type": "object",
"required": [ "required": [
@ -2263,14 +2227,7 @@
], ],
"properties": { "properties": {
"function": { "function": {
"oneOf": [ "$ref": "#/components/schemas/FunctionDefinition"
{
"$ref": "#/components/schemas/FunctionDefinition"
},
{
"$ref": "#/components/schemas/FunctionDefinitionDeprecated"
}
]
}, },
"type": { "type": {
"type": "string", "type": "string",
@ -2287,7 +2244,7 @@
], ],
"properties": { "properties": {
"function": { "function": {
"$ref": "#/components/schemas/FunctionCall" "$ref": "#/components/schemas/FunctionDefinition"
}, },
"id": { "id": {
"type": "string" "type": "string"
@ -2413,4 +2370,4 @@
"description": "Hugging Face Text Generation Inference API" "description": "Hugging Face Text Generation Inference API"
} }
] ]
} }

View File

@ -305,7 +305,7 @@ chat = client.chat_completion(
) )
print(chat.choices[0].message.tool_calls) print(chat.choices[0].message.tool_calls)
# [ChatCompletionOutputToolCall(function=ChatCompletionOutputFunctionCall(arguments="{\"format\": \"fahrenheit\", \"location\": \"Brooklyn, New York\", \"num_days\": 7}", name='get_n_day_weather_forecast', description=None), id=0, type='function')] # [ChatCompletionOutputToolCall(function=ChatCompletionOutputFunctionDefinition(arguments={'format': 'fahrenheit', 'location': 'Brooklyn, New York', 'num_days': 7}, name='get_n_day_weather_forecast', description=None), id=0, type='function')]
``` ```

View File

@ -903,44 +903,7 @@ mod tests {
let tool_prompt = "This default prompt will be used".to_string(); let tool_prompt = "This default prompt will be used".to_string();
let tools_and_prompt = Some((tools, tool_prompt)); let tools_and_prompt = Some((tools, tool_prompt));
let result = ct.apply(msgs, tools_and_prompt); let result = ct.apply(msgs, tools_and_prompt);
let expected = "<s>[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today?</s> [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"parameters\":{\"type\":\"object\",\"properties\":{\"location\":{\"type\":\"string\",\"description\":\"The city and state, e.g. San Francisco, CA\"},\"format\":{\"type\":\"string\",\"enum\":[\"celsius\",\"fahrenheit\"],\"description\":\"The temperature unit to use. Infer this from the users location.\"}},\"required\":[\"location\",\"format\"]}}}]\nThis default prompt will be used [/INST]".to_string(); let expected = "<s>[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today?</s> [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"arguments\":\"{\\\"type\\\":\\\"object\\\",\\\"properties\\\":{\\\"location\\\":{\\\"type\\\":\\\"string\\\",\\\"description\\\":\\\"The city and state, e.g. San Francisco, CA\\\"},\\\"format\\\":{\\\"type\\\":\\\"string\\\",\\\"enum\\\":[\\\"celsius\\\",\\\"fahrenheit\\\"],\\\"description\\\":\\\"The temperature unit to use. Infer this from the users location.\\\"}},\\\"required\\\":[\\\"location\\\",\\\"format\\\"]}\"}}]\nThis default prompt will be used [/INST]".to_string();
assert_eq!(result.unwrap(), expected);
}
#[test]
fn test_chat_template_with_default_tool_template_arguments_deprecated() {
let ct = ChatTemplate::new(
"{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}".to_string(),
Some(TokenizerConfigToken::String("<s>".to_string())),
Some(TokenizerConfigToken::String("</s>".to_string())),
);
// convert TextMessage to Message
let msgs: Vec<Message> = vec![
Message {
name: None,
role: "user".to_string(),
content: MessageContent::SingleText(
"I'd like to show off how chat templating works!".to_string(),
),
},
Message {
name: None,
role: "assistant".to_string(),
content: MessageContent::SingleText("Great! How can I help you today?".to_string()),
},
Message {
name: None,
role: "user".to_string(),
content: MessageContent::SingleText("Just testing".to_string()),
},
];
let tools_string = r#"[{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather","arguments": {"type": "object","properties": {"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"format": {"type": "string","enum": ["celsius", "fahrenheit"],"description": "The temperature unit to use. Infer this from the users location."}},"required": ["location", "format"]}}}]"#.to_string();
let tools: Vec<Tool> = serde_json::from_str(&tools_string).unwrap();
let tool_prompt = "This default prompt will be used".to_string();
let tools_and_prompt = Some((tools, tool_prompt));
let result = ct.apply(msgs, tools_and_prompt);
let expected = "<s>[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today?</s> [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"parameters\":{\"type\":\"object\",\"properties\":{\"location\":{\"type\":\"string\",\"description\":\"The city and state, e.g. San Francisco, CA\"},\"format\":{\"type\":\"string\",\"enum\":[\"celsius\",\"fahrenheit\"],\"description\":\"The temperature unit to use. Infer this from the users location.\"}},\"required\":[\"location\",\"format\"]}}}]\nThis default prompt will be used [/INST]".to_string();
assert_eq!(result.unwrap(), expected); assert_eq!(result.unwrap(), expected);
} }
@ -974,7 +937,7 @@ mod tests {
let tool_prompt = "This default prompt will be used".to_string(); let tool_prompt = "This default prompt will be used".to_string();
let tools_and_prompt = Some((tools, tool_prompt)); let tools_and_prompt = Some((tools, tool_prompt));
let result = ct.apply(msgs, tools_and_prompt); let result = ct.apply(msgs, tools_and_prompt);
let expected = "<s><|start_header_id|>system<|end_header_id|>\n\nEnvironment: ipython\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYoure a helpful assistant! Answer the users question best you can.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGiven the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n\nRespond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.Do not use variables.\n\n{\n \"function\": {\n \"description\": \"Get the current weather\",\n \"name\": \"get_current_weather\",\n \"parameters\": {\n \"properties\": {\n \"format\": {\n \"description\": \"The temperature unit to use. Infer this from the users location.\",\n \"enum\": [\n \"celsius\",\n \"fahrenheit\"\n ],\n \"type\": \"string\"\n },\n \"location\": {\n \"description\": \"The city and state, e.g. San Francisco, CA\",\n \"type\": \"string\"\n }\n },\n \"required\": [\n \"location\",\n \"format\"\n ],\n \"type\": \"object\"\n }\n },\n \"type\": \"function\"\n}\n\nWhat is the weather like in Brooklyn, New York?\n---\nThis default prompt will be used<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n".to_string(); let expected = "<s><|start_header_id|>system<|end_header_id|>\n\nEnvironment: ipython\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYoure a helpful assistant! Answer the users question best you can.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGiven the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n\nRespond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.Do not use variables.\n\n{\n \"function\": {\n \"arguments\": \"{\\\"type\\\":\\\"object\\\",\\\"properties\\\":{\\\"location\\\":{\\\"type\\\":\\\"string\\\",\\\"description\\\":\\\"The city and state, e.g. San Francisco, CA\\\"},\\\"format\\\":{\\\"type\\\":\\\"string\\\",\\\"enum\\\":[\\\"celsius\\\",\\\"fahrenheit\\\"],\\\"description\\\":\\\"The temperature unit to use. Infer this from the users location.\\\"}},\\\"required\\\":[\\\"location\\\",\\\"format\\\"]}\",\n \"description\": \"Get the current weather\",\n \"name\": \"get_current_weather\"\n },\n \"type\": \"function\"\n}\n\nWhat is the weather like in Brooklyn, New York?\n---\nThis default prompt will be used<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n".to_string();
assert_eq!(result.unwrap(), expected); assert_eq!(result.unwrap(), expected);
} }
} }

View File

@ -38,7 +38,7 @@ impl ToolGrammar {
description: Some( description: Some(
"Open ended response with no specific tool selected".to_string(), "Open ended response with no specific tool selected".to_string(),
), ),
parameters: json!({ arguments: json!({
"type": "object", "type": "object",
"properties": { "properties": {
"content": { "content": {
@ -83,7 +83,7 @@ impl ToolGrammar {
}), }),
); );
if let Value::Object(args) = func.parameters { if let Value::Object(args) = func.arguments {
if let Some(Value::Object(props)) = args.get("properties") { if let Some(Value::Object(props)) = args.get("properties") {
properties.extend(props.clone()); properties.extend(props.clone());
} }

View File

@ -742,11 +742,11 @@ pub(crate) struct DeltaToolCall {
pub index: u32, pub index: u32,
pub id: String, pub id: String,
pub r#type: String, pub r#type: String,
pub function: FunctionCallChunk, pub function: Function,
} }
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)] #[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)]
pub(crate) struct FunctionCallChunk { pub(crate) struct Function {
pub name: Option<String>, pub name: Option<String>,
pub arguments: String, pub arguments: String,
} }
@ -757,7 +757,7 @@ impl ChatCompletionChunk {
model: String, model: String,
system_fingerprint: String, system_fingerprint: String,
delta: Option<String>, delta: Option<String>,
tool_calls: Option<FunctionCallChunk>, tool_calls: Option<Vec<String>>,
created: u64, created: u64,
logprobs: Option<ChatCompletionLogprobs>, logprobs: Option<ChatCompletionLogprobs>,
finish_reason: Option<String>, finish_reason: Option<String>,
@ -774,7 +774,10 @@ impl ChatCompletionChunk {
index: 0, index: 0,
id: String::new(), id: String::new(),
r#type: "function".to_string(), r#type: "function".to_string(),
function: tool_calls, function: Function {
name: None,
arguments: tool_calls[0].to_string(),
},
}], }],
}), }),
(None, None) => ChatCompletionDelta::Chat(TextMessage { (None, None) => ChatCompletionDelta::Chat(TextMessage {
@ -1130,14 +1133,15 @@ pub(crate) struct FunctionDefinition {
#[serde(default)] #[serde(default)]
pub description: Option<String>, pub description: Option<String>,
pub name: String, pub name: String,
#[serde(alias = "arguments")] #[serde(alias = "parameters", serialize_with = "serialize_as_string")]
pub parameters: serde_json::Value, pub arguments: serde_json::Value,
} }
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default, PartialEq)] fn serialize_as_string<S>(value: &serde_json::Value, serializer: S) -> Result<S::Ok, S::Error>
pub(crate) struct FunctionCall { where
pub name: String, S: serde::Serializer,
pub arguments: String, {
serializer.serialize_str(&value.to_string())
} }
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)] #[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
@ -1163,7 +1167,7 @@ pub(crate) struct ChatTemplateInputs<'a> {
pub(crate) struct ToolCall { pub(crate) struct ToolCall {
pub id: String, pub id: String,
pub r#type: String, pub r#type: String,
pub function: FunctionCall, pub function: FunctionDefinition,
} }
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)] #[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
@ -1682,19 +1686,19 @@ mod tests {
tool_calls: vec![ToolCall { tool_calls: vec![ToolCall {
id: "0".to_string(), id: "0".to_string(),
r#type: "function".to_string(), r#type: "function".to_string(),
function: FunctionCall { function: FunctionDefinition {
description: None,
name: "myfn".to_string(), name: "myfn".to_string(),
arguments: json!({ arguments: json!({
"format": "csv" "format": "csv"
}) }),
.to_string(),
}, },
}], }],
}); });
let serialized = serde_json::to_string(&message).unwrap(); let serialized = serde_json::to_string(&message).unwrap();
assert_eq!( assert_eq!(
serialized, serialized,
r#"{"role":"assistant","tool_calls":[{"id":"0","type":"function","function":{"name":"myfn","arguments":"{\"format\":\"csv\"}"}}]}"# r#"{"role":"assistant","tool_calls":[{"id":"0","type":"function","function":{"description":null,"name":"myfn","arguments":"{\"format\":\"csv\"}"}}]}"#
); );
} }

View File

@ -12,9 +12,10 @@ use crate::sagemaker::{
}; };
use crate::validation::ValidationError; use crate::validation::ValidationError;
use crate::vertex::vertex_compatibility; use crate::vertex::vertex_compatibility;
use crate::ChatTokenizeResponse;
use crate::{ use crate::{
usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionCallChunk, usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName,
FunctionName, GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo, GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
HubProcessorConfig, HubTokenizerConfig, Info, Message, MessageChunk, MessageContent, HubProcessorConfig, HubTokenizerConfig, Info, Message, MessageChunk, MessageContent,
OutputMessage, PrefillToken, SimpleToken, StreamDetails, StreamOptions, StreamResponse, OutputMessage, PrefillToken, SimpleToken, StreamDetails, StreamOptions, StreamResponse,
TextMessage, Token, TokenizeResponse, Tokenizer, ToolCallDelta, ToolCallMessage, Url, Usage, TextMessage, Token, TokenizeResponse, Tokenizer, ToolCallDelta, ToolCallMessage, Url, Usage,
@ -24,9 +25,8 @@ use crate::{
ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete, ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
ChatCompletionDelta, ChatCompletionLogprob, ChatCompletionLogprobs, ChatCompletionTopLogprob, ChatCompletionDelta, ChatCompletionLogprob, ChatCompletionLogprobs, ChatCompletionTopLogprob,
ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal, ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal,
CompletionRequest, CompletionType, DeltaToolCall, Prompt, Tool, CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool,
}; };
use crate::{ChatTokenizeResponse, FunctionCall};
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice}; use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice};
use crate::{ModelInfo, ModelsInfo}; use crate::{ModelInfo, ModelsInfo};
use async_stream::__private::AsyncStream; use async_stream::__private::AsyncStream;
@ -1117,7 +1117,6 @@ pub(crate) async fn completions(
enum StreamState { enum StreamState {
Buffering, Buffering,
BufferTrailing, BufferTrailing,
Arguments,
Content { skip_close_quote: bool }, Content { skip_close_quote: bool },
} }
@ -1127,7 +1126,6 @@ fn create_event_from_stream_token(
logprobs: bool, logprobs: bool,
stream_options: Option<StreamOptions>, stream_options: Option<StreamOptions>,
inner_using_tools: bool, inner_using_tools: bool,
partial_call: Option<FunctionCallChunk>,
system_fingerprint: String, system_fingerprint: String,
model_id: String, model_id: String,
) -> Event { ) -> Event {
@ -1143,16 +1141,7 @@ fn create_event_from_stream_token(
// replace the content with the tool calls if grammar is present // replace the content with the tool calls if grammar is present
let (content, tool_calls) = if inner_using_tools { let (content, tool_calls) = if inner_using_tools {
match partial_call { (None, Some(vec![stream_token.token.text.clone()]))
Some(partial_call) => (None, Some(partial_call)),
None => (
None,
Some(FunctionCallChunk {
name: None,
arguments: stream_token.token.text.clone(),
}),
),
}
} else { } else {
let content = if !stream_token.token.special { let content = if !stream_token.token.special {
Some(stream_token.token.text.clone()) Some(stream_token.token.text.clone())
@ -1269,7 +1258,7 @@ pub(crate) async fn chat_completions(
generate_stream_internal(infer, compute_type, Json(generate_request), span).await; generate_stream_internal(infer, compute_type, Json(generate_request), span).await;
// regex to match any function name // regex to match any function name
let function_name_regex = match Regex::new(r#"\{"function":\{"_name":"([^"]+)","#) { let function_regex = match Regex::new(r#"\{"function":\{"_name":"([^"]+)""#) {
Ok(regex) => regex, Ok(regex) => regex,
Err(e) => { Err(e) => {
return Err(( return Err((
@ -1284,6 +1273,7 @@ pub(crate) async fn chat_completions(
let response_stream = async_stream::stream! { let response_stream = async_stream::stream! {
let mut response_stream = Box::pin(response_stream); let mut response_stream = Box::pin(response_stream);
let mut buffer = Vec::new();
let mut json_buffer = String::new(); let mut json_buffer = String::new();
let mut state = if using_tools { let mut state = if using_tools {
StreamState::Buffering StreamState::Buffering
@ -1300,27 +1290,30 @@ pub(crate) async fn chat_completions(
match state { match state {
StreamState::Buffering => { StreamState::Buffering => {
json_buffer.push_str(&token_text.replace(" ", "")); json_buffer.push_str(&token_text.replace(" ", ""));
if let Some(captures) = function_name_regex.captures(&json_buffer) { buffer.push(stream_token);
if let Some(captures) = function_regex.captures(&json_buffer) {
let function_name = captures[1].to_string(); let function_name = captures[1].to_string();
if function_name == "no_tool" { if function_name == "no_tool" {
state = StreamState::BufferTrailing; state = StreamState::BufferTrailing;
response_as_tool = false; response_as_tool = false;
buffer.clear();
json_buffer.clear(); json_buffer.clear();
} else { } else {
state = StreamState::Arguments; state = StreamState::Content {
let event = create_event_from_stream_token( skip_close_quote: false,
&stream_token, };
logprobs, // send all the buffered messages
stream_options.clone(), for stream_token in &buffer {
response_as_tool, let event = create_event_from_stream_token(
Some(FunctionCallChunk { stream_token,
name: Some(function_name), logprobs,
arguments: "{".to_string() stream_options.clone(),
}), response_as_tool,
system_fingerprint.clone(), system_fingerprint.clone(),
model_id.clone(), model_id.clone(),
); );
yield Ok::<Event, Infallible>(event); yield Ok::<Event, Infallible>(event);
}
} }
} }
} }
@ -1361,32 +1354,12 @@ pub(crate) async fn chat_completions(
})); }));
} }
// cleanup the buffers // cleanup the buffers
buffer.clear();
json_buffer.clear(); json_buffer.clear();
state = StreamState::Content { state = StreamState::Content {
skip_close_quote: true, skip_close_quote: true,
}; };
} }
StreamState::Arguments => {
json_buffer.push_str(&token_text.replace(" ", ""));
// If we are at the end of the json we can stop
let function: Result<serde::de::IgnoredAny, _> = serde_json::from_str(&json_buffer);
if let Ok(_) = function {
break;
}
// send the content
let event = create_event_from_stream_token(
&stream_token,
logprobs,
stream_options.clone(),
response_as_tool,
None,
system_fingerprint.clone(),
model_id.clone(),
);
yield Ok::<Event, Infallible>(event);
}
StreamState::Content { skip_close_quote } => { StreamState::Content { skip_close_quote } => {
if skip_close_quote && token_text.contains('"') { if skip_close_quote && token_text.contains('"') {
break; break;
@ -1398,7 +1371,6 @@ pub(crate) async fn chat_completions(
logprobs, logprobs,
stream_options.clone(), stream_options.clone(),
response_as_tool, response_as_tool,
None,
system_fingerprint.clone(), system_fingerprint.clone(),
model_id.clone(), model_id.clone(),
); );
@ -1466,9 +1438,10 @@ pub(crate) async fn chat_completions(
let tool_calls = vec![ToolCall { let tool_calls = vec![ToolCall {
id: "0".to_string(), id: "0".to_string(),
r#type: "function".to_string(), r#type: "function".to_string(),
function: FunctionCall { function: FunctionDefinition {
description: None,
name, name,
arguments: arguments.to_string(), arguments,
}, },
}]; }];
(Some(tool_calls), None) (Some(tool_calls), None)
@ -1599,8 +1572,8 @@ StreamOptions,
DeltaToolCall, DeltaToolCall,
Tool, Tool,
ToolCall, ToolCall,
Function,
FunctionDefinition, FunctionDefinition,
FunctionCall,
ToolChoice, ToolChoice,
ModelInfo, ModelInfo,
ChatTokenizeResponse, ChatTokenizeResponse,