feat: serialize function definition with serialize_as_string

This commit is contained in:
drbh 2025-02-07 22:27:24 +00:00
parent 983b9675d6
commit 0ca7af8830
6 changed files with 58 additions and 161 deletions

View File

@ -1508,26 +1508,7 @@
}
}
},
"FunctionCall": {
"type": "object",
"required": [
"name",
"arguments"
],
"properties": {
"arguments": {
"type": "string"
},
"description": {
"type": "string",
"nullable": true
},
"name": {
"type": "string"
}
}
},
"FunctionDefinitionDeprecated": {
"FunctionDefinition": {
"type": "object",
"required": [
"name",
@ -1544,23 +1525,6 @@
}
}
},
"FunctionDefinition": {
"type": "object",
"required": [
"name",
"parameters"
],
"properties": {
"parameters": {},
"description": {
"type": "string",
"nullable": true
},
"name": {
"type": "string"
}
}
},
"FunctionName": {
"type": "object",
"required": [
@ -2299,14 +2263,7 @@
],
"properties": {
"function": {
"oneOf": [
{
"$ref": "#/components/schemas/FunctionDefinition"
},
{
"$ref": "#/components/schemas/FunctionDefinitionDeprecated"
}
]
"$ref": "#/components/schemas/FunctionDefinition"
},
"type": {
"type": "string",
@ -2323,7 +2280,7 @@
],
"properties": {
"function": {
"$ref": "#/components/schemas/FunctionCall"
"$ref": "#/components/schemas/FunctionDefinition"
},
"id": {
"type": "string"
@ -2449,4 +2406,4 @@
"description": "Hugging Face Text Generation Inference API"
}
]
}
}

View File

@ -305,7 +305,7 @@ chat = client.chat_completion(
)
print(chat.choices[0].message.tool_calls)
# [ChatCompletionOutputToolCall(function=ChatCompletionOutputFunctionCall(arguments="{\"format\": \"fahrenheit\", \"location\": \"Brooklyn, New York\", \"num_days\": 7}", name='get_n_day_weather_forecast', description=None), id=0, type='function')]
# [ChatCompletionOutputToolCall(function=ChatCompletionOutputFunctionDefinition(arguments={'format': 'fahrenheit', 'location': 'Brooklyn, New York', 'num_days': 7}, name='get_n_day_weather_forecast', description=None), id=0, type='function')]
```

View File

@ -1189,44 +1189,7 @@ TOOL CALL ID: 0
let tool_prompt = "This default prompt will be used".to_string();
let tools_and_prompt = Some((tools, tool_prompt));
let result = ct.apply(msgs, tools_and_prompt);
let expected = "<s>[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today?</s> [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"parameters\":{\"type\":\"object\",\"properties\":{\"location\":{\"type\":\"string\",\"description\":\"The city and state, e.g. San Francisco, CA\"},\"format\":{\"type\":\"string\",\"enum\":[\"celsius\",\"fahrenheit\"],\"description\":\"The temperature unit to use. Infer this from the users location.\"}},\"required\":[\"location\",\"format\"]}}}]\nThis default prompt will be used [/INST]".to_string();
assert_eq!(result.unwrap(), expected);
}
#[test]
fn test_chat_template_with_default_tool_template_arguments_deprecated() {
let ct = ChatTemplate::new(
"{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}".to_string(),
Some(TokenizerConfigToken::String("<s>".to_string())),
Some(TokenizerConfigToken::String("</s>".to_string())),
);
// convert TextMessage to Message
let msgs: Vec<Message> = vec![
Message {
name: None,
role: "user".to_string(),
content: MessageContent::SingleText(
"I'd like to show off how chat templating works!".to_string(),
),
},
Message {
name: None,
role: "assistant".to_string(),
content: MessageContent::SingleText("Great! How can I help you today?".to_string()),
},
Message {
name: None,
role: "user".to_string(),
content: MessageContent::SingleText("Just testing".to_string()),
},
];
let tools_string = r#"[{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather","arguments": {"type": "object","properties": {"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"format": {"type": "string","enum": ["celsius", "fahrenheit"],"description": "The temperature unit to use. Infer this from the users location."}},"required": ["location", "format"]}}}]"#.to_string();
let tools: Vec<Tool> = serde_json::from_str(&tools_string).unwrap();
let tool_prompt = "This default prompt will be used".to_string();
let tools_and_prompt = Some((tools, tool_prompt));
let result = ct.apply(msgs, tools_and_prompt);
let expected = "<s>[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today?</s> [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"parameters\":{\"type\":\"object\",\"properties\":{\"location\":{\"type\":\"string\",\"description\":\"The city and state, e.g. San Francisco, CA\"},\"format\":{\"type\":\"string\",\"enum\":[\"celsius\",\"fahrenheit\"],\"description\":\"The temperature unit to use. Infer this from the users location.\"}},\"required\":[\"location\",\"format\"]}}}]\nThis default prompt will be used [/INST]".to_string();
let expected = "<s>[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today?</s> [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"arguments\":\"{\\\"type\\\":\\\"object\\\",\\\"properties\\\":{\\\"location\\\":{\\\"type\\\":\\\"string\\\",\\\"description\\\":\\\"The city and state, e.g. San Francisco, CA\\\"},\\\"format\\\":{\\\"type\\\":\\\"string\\\",\\\"enum\\\":[\\\"celsius\\\",\\\"fahrenheit\\\"],\\\"description\\\":\\\"The temperature unit to use. Infer this from the users location.\\\"}},\\\"required\\\":[\\\"location\\\",\\\"format\\\"]}\"}}]\nThis default prompt will be used [/INST]".to_string();
assert_eq!(result.unwrap(), expected);
}
@ -1264,7 +1227,7 @@ TOOL CALL ID: 0
let tool_prompt = "This default prompt will be used".to_string();
let tools_and_prompt = Some((tools, tool_prompt));
let result = ct.apply(msgs, tools_and_prompt);
let expected = "<s><|start_header_id|>system<|end_header_id|>\n\nEnvironment: ipython\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYoure a helpful assistant! Answer the users question best you can.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGiven the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n\nRespond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.Do not use variables.\n\n{\n \"function\": {\n \"description\": \"Get the current weather\",\n \"name\": \"get_current_weather\",\n \"parameters\": {\n \"properties\": {\n \"format\": {\n \"description\": \"The temperature unit to use. Infer this from the users location.\",\n \"enum\": [\n \"celsius\",\n \"fahrenheit\"\n ],\n \"type\": \"string\"\n },\n \"location\": {\n \"description\": \"The city and state, e.g. San Francisco, CA\",\n \"type\": \"string\"\n }\n },\n \"required\": [\n \"location\",\n \"format\"\n ],\n \"type\": \"object\"\n }\n },\n \"type\": \"function\"\n}\n\nWhat is the weather like in Brooklyn, New York?\n---\nThis default prompt will be used<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n".to_string();
let expected = "<s><|start_header_id|>system<|end_header_id|>\n\nEnvironment: ipython\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYoure a helpful assistant! Answer the users question best you can.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGiven the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n\nRespond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.Do not use variables.\n\n{\n \"function\": {\n \"arguments\": \"{\\\"type\\\":\\\"object\\\",\\\"properties\\\":{\\\"location\\\":{\\\"type\\\":\\\"string\\\",\\\"description\\\":\\\"The city and state, e.g. San Francisco, CA\\\"},\\\"format\\\":{\\\"type\\\":\\\"string\\\",\\\"enum\\\":[\\\"celsius\\\",\\\"fahrenheit\\\"],\\\"description\\\":\\\"The temperature unit to use. Infer this from the users location.\\\"}},\\\"required\\\":[\\\"location\\\",\\\"format\\\"]}\",\n \"description\": \"Get the current weather\",\n \"name\": \"get_current_weather\"\n },\n \"type\": \"function\"\n}\n\nWhat is the weather like in Brooklyn, New York?\n---\nThis default prompt will be used<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n".to_string();
assert_eq!(result.unwrap(), expected);
}
}

View File

@ -38,7 +38,7 @@ impl ToolGrammar {
description: Some(
"Open ended response with no specific tool selected".to_string(),
),
parameters: json!({
arguments: json!({
"type": "object",
"properties": {
"content": {
@ -83,7 +83,7 @@ impl ToolGrammar {
}),
);
if let Value::Object(args) = func.parameters {
if let Value::Object(args) = func.arguments {
if let Some(Value::Object(props)) = args.get("properties") {
properties.extend(props.clone());
}

View File

@ -745,11 +745,11 @@ pub(crate) struct DeltaToolCall {
pub index: u32,
pub id: String,
pub r#type: String,
pub function: FunctionCallChunk,
pub function: Function,
}
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)]
pub(crate) struct FunctionCallChunk {
pub(crate) struct Function {
pub name: Option<String>,
pub arguments: String,
}
@ -760,7 +760,7 @@ impl ChatCompletionChunk {
model: String,
system_fingerprint: String,
delta: Option<String>,
tool_calls: Option<FunctionCallChunk>,
tool_calls: Option<Vec<String>>,
created: u64,
logprobs: Option<ChatCompletionLogprobs>,
finish_reason: Option<String>,
@ -778,7 +778,10 @@ impl ChatCompletionChunk {
index: 0,
id: String::new(),
r#type: "function".to_string(),
function: tool_calls,
function: Function {
name: None,
arguments: tool_calls[0].to_string(),
},
}],
}),
(None, None) => ChatCompletionDelta::Chat(TextMessage {
@ -1135,14 +1138,15 @@ pub struct FunctionDefinition {
#[serde(default)]
pub description: Option<String>,
pub name: String,
#[serde(alias = "arguments")]
pub parameters: serde_json::Value,
#[serde(alias = "parameters", serialize_with = "serialize_as_string")]
pub arguments: serde_json::Value,
}
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default, PartialEq)]
pub(crate) struct FunctionCall {
pub name: String,
pub arguments: String,
fn serialize_as_string<S>(value: &serde_json::Value, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.serialize_str(&value.to_string())
}
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
@ -1168,7 +1172,7 @@ pub(crate) struct ChatTemplateInputs<'a> {
pub struct ToolCall {
pub id: String,
pub r#type: String,
pub function: FunctionCall,
pub function: FunctionDefinition,
}
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
@ -1721,19 +1725,19 @@ mod tests {
tool_calls: vec![ToolCall {
id: "0".to_string(),
r#type: "function".to_string(),
function: FunctionCall {
function: FunctionDefinition {
description: None,
name: "myfn".to_string(),
arguments: json!({
"format": "csv"
})
.to_string(),
}),
},
}],
});
let serialized = serde_json::to_string(&message).unwrap();
assert_eq!(
serialized,
r#"{"role":"assistant","tool_calls":[{"id":"0","type":"function","function":{"name":"myfn","arguments":"{\"format\":\"csv\"}"}}]}"#
r#"{"role":"assistant","tool_calls":[{"id":"0","type":"function","function":{"description":null,"name":"myfn","arguments":"{\"format\":\"csv\"}"}}]}"#
);
}

View File

@ -12,9 +12,10 @@ use crate::sagemaker::{
};
use crate::validation::ValidationError;
use crate::vertex::vertex_compatibility;
use crate::ChatTokenizeResponse;
use crate::{
usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionCallChunk,
FunctionName, GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName,
GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
HubProcessorConfig, HubTokenizerConfig, Info, Message, MessageChunk, MessageContent,
OutputMessage, PrefillToken, SimpleToken, StreamDetails, StreamOptions, StreamResponse,
TextMessage, Token, TokenizeResponse, Tokenizer, ToolCallDelta, ToolCallMessage, Url, Usage,
@ -24,9 +25,8 @@ use crate::{
ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
ChatCompletionDelta, ChatCompletionLogprob, ChatCompletionLogprobs, ChatCompletionTopLogprob,
ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal,
CompletionRequest, CompletionType, DeltaToolCall, Prompt, Tool,
CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool,
};
use crate::{ChatTokenizeResponse, FunctionCall};
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice};
use crate::{MessageBody, ModelInfo, ModelsInfo};
use async_stream::__private::AsyncStream;
@ -1117,7 +1117,6 @@ pub(crate) async fn completions(
enum StreamState {
Buffering,
BufferTrailing,
Arguments,
Content { skip_close_quote: bool },
}
@ -1127,7 +1126,6 @@ fn create_event_from_stream_token(
logprobs: bool,
stream_options: Option<StreamOptions>,
inner_using_tools: bool,
partial_call: Option<FunctionCallChunk>,
system_fingerprint: String,
model_id: String,
) -> Event {
@ -1143,16 +1141,7 @@ fn create_event_from_stream_token(
// replace the content with the tool calls if grammar is present
let (content, tool_calls) = if inner_using_tools {
match partial_call {
Some(partial_call) => (None, Some(partial_call)),
None => (
None,
Some(FunctionCallChunk {
name: None,
arguments: stream_token.token.text.clone(),
}),
),
}
(None, Some(vec![stream_token.token.text.clone()]))
} else {
let content = if !stream_token.token.special {
Some(stream_token.token.text.clone())
@ -1269,7 +1258,7 @@ pub(crate) async fn chat_completions(
generate_stream_internal(infer, compute_type, Json(generate_request), span).await;
// regex to match any function name
let function_name_regex = match Regex::new(r#"\{"function":\{"_name":"([^"]+)","#) {
let function_regex = match Regex::new(r#"\{"function":\{"_name":"([^"]+)""#) {
Ok(regex) => regex,
Err(e) => {
return Err((
@ -1284,6 +1273,7 @@ pub(crate) async fn chat_completions(
let response_stream = async_stream::stream! {
let mut response_stream = Box::pin(response_stream);
let mut buffer = Vec::new();
let mut json_buffer = String::new();
let mut state = if using_tools {
StreamState::Buffering
@ -1300,27 +1290,30 @@ pub(crate) async fn chat_completions(
match state {
StreamState::Buffering => {
json_buffer.push_str(&token_text.replace(" ", ""));
if let Some(captures) = function_name_regex.captures(&json_buffer) {
buffer.push(stream_token);
if let Some(captures) = function_regex.captures(&json_buffer) {
let function_name = captures[1].to_string();
if function_name == "no_tool" {
state = StreamState::BufferTrailing;
response_as_tool = false;
buffer.clear();
json_buffer.clear();
} else {
state = StreamState::Arguments;
let event = create_event_from_stream_token(
&stream_token,
logprobs,
stream_options.clone(),
response_as_tool,
Some(FunctionCallChunk {
name: Some(function_name),
arguments: "{".to_string()
}),
system_fingerprint.clone(),
model_id.clone(),
);
yield Ok::<Event, Infallible>(event);
state = StreamState::Content {
skip_close_quote: false,
};
// send all the buffered messages
for stream_token in &buffer {
let event = create_event_from_stream_token(
stream_token,
logprobs,
stream_options.clone(),
response_as_tool,
system_fingerprint.clone(),
model_id.clone(),
);
yield Ok::<Event, Infallible>(event);
}
}
}
}
@ -1361,32 +1354,12 @@ pub(crate) async fn chat_completions(
}));
}
// cleanup the buffers
buffer.clear();
json_buffer.clear();
state = StreamState::Content {
skip_close_quote: true,
};
}
StreamState::Arguments => {
json_buffer.push_str(&token_text.replace(" ", ""));
// If we are at the end of the json we can stop
let function: Result<serde::de::IgnoredAny, _> = serde_json::from_str(&json_buffer);
if let Ok(_) = function {
break;
}
// send the content
let event = create_event_from_stream_token(
&stream_token,
logprobs,
stream_options.clone(),
response_as_tool,
None,
system_fingerprint.clone(),
model_id.clone(),
);
yield Ok::<Event, Infallible>(event);
}
StreamState::Content { skip_close_quote } => {
if skip_close_quote && token_text.contains('"') {
break;
@ -1398,7 +1371,6 @@ pub(crate) async fn chat_completions(
logprobs,
stream_options.clone(),
response_as_tool,
None,
system_fingerprint.clone(),
model_id.clone(),
);
@ -1466,9 +1438,10 @@ pub(crate) async fn chat_completions(
let tool_calls = vec![ToolCall {
id: "0".to_string(),
r#type: "function".to_string(),
function: FunctionCall {
function: FunctionDefinition {
description: None,
name,
arguments: arguments.to_string(),
arguments,
},
}];
(Some(tool_calls), None)
@ -1599,8 +1572,8 @@ StreamOptions,
DeltaToolCall,
Tool,
ToolCall,
Function,
FunctionDefinition,
FunctionCall,
ToolChoice,
ModelInfo,
ChatTokenizeResponse,