From 0ca7af8830e506a0213eef46388019fb4528ad1a Mon Sep 17 00:00:00 2001 From: drbh Date: Fri, 7 Feb 2025 22:27:24 +0000 Subject: [PATCH] feat: serialize function definition with serialize_as_string --- docs/openapi.json | 51 +---------- docs/source/basic_tutorials/using_guidance.md | 2 +- router/src/infer/chat_template.rs | 41 +-------- router/src/infer/tool_grammar.rs | 4 +- router/src/lib.rs | 34 ++++---- router/src/server.rs | 87 +++++++------------ 6 files changed, 58 insertions(+), 161 deletions(-) diff --git a/docs/openapi.json b/docs/openapi.json index 7ffea3fa..9de76e47 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -1508,26 +1508,7 @@ } } }, - "FunctionCall": { - "type": "object", - "required": [ - "name", - "arguments" - ], - "properties": { - "arguments": { - "type": "string" - }, - "description": { - "type": "string", - "nullable": true - }, - "name": { - "type": "string" - } - } - }, - "FunctionDefinitionDeprecated": { + "FunctionDefinition": { "type": "object", "required": [ "name", @@ -1544,23 +1525,6 @@ } } }, - "FunctionDefinition": { - "type": "object", - "required": [ - "name", - "parameters" - ], - "properties": { - "parameters": {}, - "description": { - "type": "string", - "nullable": true - }, - "name": { - "type": "string" - } - } - }, "FunctionName": { "type": "object", "required": [ @@ -2299,14 +2263,7 @@ ], "properties": { "function": { - "oneOf": [ - { - "$ref": "#/components/schemas/FunctionDefinition" - }, - { - "$ref": "#/components/schemas/FunctionDefinitionDeprecated" - } - ] + "$ref": "#/components/schemas/FunctionDefinition" }, "type": { "type": "string", @@ -2323,7 +2280,7 @@ ], "properties": { "function": { - "$ref": "#/components/schemas/FunctionCall" + "$ref": "#/components/schemas/FunctionDefinition" }, "id": { "type": "string" @@ -2449,4 +2406,4 @@ "description": "Hugging Face Text Generation Inference API" } ] -} \ No newline at end of file +} diff --git a/docs/source/basic_tutorials/using_guidance.md b/docs/source/basic_tutorials/using_guidance.md index 6540cb6d..e389fbbc 100644 --- a/docs/source/basic_tutorials/using_guidance.md +++ b/docs/source/basic_tutorials/using_guidance.md @@ -305,7 +305,7 @@ chat = client.chat_completion( ) print(chat.choices[0].message.tool_calls) -# [ChatCompletionOutputToolCall(function=ChatCompletionOutputFunctionCall(arguments="{\"format\": \"fahrenheit\", \"location\": \"Brooklyn, New York\", \"num_days\": 7}", name='get_n_day_weather_forecast', description=None), id=0, type='function')] +# [ChatCompletionOutputToolCall(function=ChatCompletionOutputFunctionDefinition(arguments={'format': 'fahrenheit', 'location': 'Brooklyn, New York', 'num_days': 7}, name='get_n_day_weather_forecast', description=None), id=0, type='function')] ``` diff --git a/router/src/infer/chat_template.rs b/router/src/infer/chat_template.rs index 56a8616c..b179dd4d 100644 --- a/router/src/infer/chat_template.rs +++ b/router/src/infer/chat_template.rs @@ -1189,44 +1189,7 @@ TOOL CALL ID: 0 let tool_prompt = "This default prompt will be used".to_string(); let tools_and_prompt = Some((tools, tool_prompt)); let result = ct.apply(msgs, tools_and_prompt); - let expected = "[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today? [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"parameters\":{\"type\":\"object\",\"properties\":{\"location\":{\"type\":\"string\",\"description\":\"The city and state, e.g. San Francisco, CA\"},\"format\":{\"type\":\"string\",\"enum\":[\"celsius\",\"fahrenheit\"],\"description\":\"The temperature unit to use. Infer this from the users location.\"}},\"required\":[\"location\",\"format\"]}}}]\nThis default prompt will be used [/INST]".to_string(); - assert_eq!(result.unwrap(), expected); - } - - #[test] - fn test_chat_template_with_default_tool_template_arguments_deprecated() { - let ct = ChatTemplate::new( - "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}".to_string(), - Some(TokenizerConfigToken::String("".to_string())), - Some(TokenizerConfigToken::String("".to_string())), - ); - - // convert TextMessage to Message - let msgs: Vec = vec![ - Message { - name: None, - role: "user".to_string(), - content: MessageContent::SingleText( - "I'd like to show off how chat templating works!".to_string(), - ), - }, - Message { - name: None, - role: "assistant".to_string(), - content: MessageContent::SingleText("Great! How can I help you today?".to_string()), - }, - Message { - name: None, - role: "user".to_string(), - content: MessageContent::SingleText("Just testing".to_string()), - }, - ]; - let tools_string = r#"[{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather","arguments": {"type": "object","properties": {"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"format": {"type": "string","enum": ["celsius", "fahrenheit"],"description": "The temperature unit to use. Infer this from the users location."}},"required": ["location", "format"]}}}]"#.to_string(); - let tools: Vec = serde_json::from_str(&tools_string).unwrap(); - let tool_prompt = "This default prompt will be used".to_string(); - let tools_and_prompt = Some((tools, tool_prompt)); - let result = ct.apply(msgs, tools_and_prompt); - let expected = "[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today? [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"parameters\":{\"type\":\"object\",\"properties\":{\"location\":{\"type\":\"string\",\"description\":\"The city and state, e.g. San Francisco, CA\"},\"format\":{\"type\":\"string\",\"enum\":[\"celsius\",\"fahrenheit\"],\"description\":\"The temperature unit to use. Infer this from the users location.\"}},\"required\":[\"location\",\"format\"]}}}]\nThis default prompt will be used [/INST]".to_string(); + let expected = "[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today? [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"arguments\":\"{\\\"type\\\":\\\"object\\\",\\\"properties\\\":{\\\"location\\\":{\\\"type\\\":\\\"string\\\",\\\"description\\\":\\\"The city and state, e.g. San Francisco, CA\\\"},\\\"format\\\":{\\\"type\\\":\\\"string\\\",\\\"enum\\\":[\\\"celsius\\\",\\\"fahrenheit\\\"],\\\"description\\\":\\\"The temperature unit to use. Infer this from the users location.\\\"}},\\\"required\\\":[\\\"location\\\",\\\"format\\\"]}\"}}]\nThis default prompt will be used [/INST]".to_string(); assert_eq!(result.unwrap(), expected); } @@ -1264,7 +1227,7 @@ TOOL CALL ID: 0 let tool_prompt = "This default prompt will be used".to_string(); let tools_and_prompt = Some((tools, tool_prompt)); let result = ct.apply(msgs, tools_and_prompt); - let expected = "<|start_header_id|>system<|end_header_id|>\n\nEnvironment: ipython\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYoure a helpful assistant! Answer the users question best you can.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGiven the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n\nRespond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.Do not use variables.\n\n{\n \"function\": {\n \"description\": \"Get the current weather\",\n \"name\": \"get_current_weather\",\n \"parameters\": {\n \"properties\": {\n \"format\": {\n \"description\": \"The temperature unit to use. Infer this from the users location.\",\n \"enum\": [\n \"celsius\",\n \"fahrenheit\"\n ],\n \"type\": \"string\"\n },\n \"location\": {\n \"description\": \"The city and state, e.g. San Francisco, CA\",\n \"type\": \"string\"\n }\n },\n \"required\": [\n \"location\",\n \"format\"\n ],\n \"type\": \"object\"\n }\n },\n \"type\": \"function\"\n}\n\nWhat is the weather like in Brooklyn, New York?\n---\nThis default prompt will be used<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n".to_string(); + let expected = "<|start_header_id|>system<|end_header_id|>\n\nEnvironment: ipython\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYoure a helpful assistant! Answer the users question best you can.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGiven the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n\nRespond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.Do not use variables.\n\n{\n \"function\": {\n \"arguments\": \"{\\\"type\\\":\\\"object\\\",\\\"properties\\\":{\\\"location\\\":{\\\"type\\\":\\\"string\\\",\\\"description\\\":\\\"The city and state, e.g. San Francisco, CA\\\"},\\\"format\\\":{\\\"type\\\":\\\"string\\\",\\\"enum\\\":[\\\"celsius\\\",\\\"fahrenheit\\\"],\\\"description\\\":\\\"The temperature unit to use. Infer this from the users location.\\\"}},\\\"required\\\":[\\\"location\\\",\\\"format\\\"]}\",\n \"description\": \"Get the current weather\",\n \"name\": \"get_current_weather\"\n },\n \"type\": \"function\"\n}\n\nWhat is the weather like in Brooklyn, New York?\n---\nThis default prompt will be used<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n".to_string(); assert_eq!(result.unwrap(), expected); } } diff --git a/router/src/infer/tool_grammar.rs b/router/src/infer/tool_grammar.rs index 6b6099f3..7770cd9d 100644 --- a/router/src/infer/tool_grammar.rs +++ b/router/src/infer/tool_grammar.rs @@ -38,7 +38,7 @@ impl ToolGrammar { description: Some( "Open ended response with no specific tool selected".to_string(), ), - parameters: json!({ + arguments: json!({ "type": "object", "properties": { "content": { @@ -83,7 +83,7 @@ impl ToolGrammar { }), ); - if let Value::Object(args) = func.parameters { + if let Value::Object(args) = func.arguments { if let Some(Value::Object(props)) = args.get("properties") { properties.extend(props.clone()); } diff --git a/router/src/lib.rs b/router/src/lib.rs index 1932b06b..6d4814b1 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -745,11 +745,11 @@ pub(crate) struct DeltaToolCall { pub index: u32, pub id: String, pub r#type: String, - pub function: FunctionCallChunk, + pub function: Function, } #[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)] -pub(crate) struct FunctionCallChunk { +pub(crate) struct Function { pub name: Option, pub arguments: String, } @@ -760,7 +760,7 @@ impl ChatCompletionChunk { model: String, system_fingerprint: String, delta: Option, - tool_calls: Option, + tool_calls: Option>, created: u64, logprobs: Option, finish_reason: Option, @@ -778,7 +778,10 @@ impl ChatCompletionChunk { index: 0, id: String::new(), r#type: "function".to_string(), - function: tool_calls, + function: Function { + name: None, + arguments: tool_calls[0].to_string(), + }, }], }), (None, None) => ChatCompletionDelta::Chat(TextMessage { @@ -1135,14 +1138,15 @@ pub struct FunctionDefinition { #[serde(default)] pub description: Option, pub name: String, - #[serde(alias = "arguments")] - pub parameters: serde_json::Value, + #[serde(alias = "parameters", serialize_with = "serialize_as_string")] + pub arguments: serde_json::Value, } -#[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default, PartialEq)] -pub(crate) struct FunctionCall { - pub name: String, - pub arguments: String, +fn serialize_as_string(value: &serde_json::Value, serializer: S) -> Result +where + S: serde::Serializer, +{ + serializer.serialize_str(&value.to_string()) } #[derive(Clone, Debug, Deserialize, Serialize, ToSchema)] @@ -1168,7 +1172,7 @@ pub(crate) struct ChatTemplateInputs<'a> { pub struct ToolCall { pub id: String, pub r#type: String, - pub function: FunctionCall, + pub function: FunctionDefinition, } #[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)] @@ -1721,19 +1725,19 @@ mod tests { tool_calls: vec![ToolCall { id: "0".to_string(), r#type: "function".to_string(), - function: FunctionCall { + function: FunctionDefinition { + description: None, name: "myfn".to_string(), arguments: json!({ "format": "csv" - }) - .to_string(), + }), }, }], }); let serialized = serde_json::to_string(&message).unwrap(); assert_eq!( serialized, - r#"{"role":"assistant","tool_calls":[{"id":"0","type":"function","function":{"name":"myfn","arguments":"{\"format\":\"csv\"}"}}]}"# + r#"{"role":"assistant","tool_calls":[{"id":"0","type":"function","function":{"description":null,"name":"myfn","arguments":"{\"format\":\"csv\"}"}}]}"# ); } diff --git a/router/src/server.rs b/router/src/server.rs index e267a951..e9aa4612 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -12,9 +12,10 @@ use crate::sagemaker::{ }; use crate::validation::ValidationError; use crate::vertex::vertex_compatibility; +use crate::ChatTokenizeResponse; use crate::{ - usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionCallChunk, - FunctionName, GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo, + usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName, + GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo, HubProcessorConfig, HubTokenizerConfig, Info, Message, MessageChunk, MessageContent, OutputMessage, PrefillToken, SimpleToken, StreamDetails, StreamOptions, StreamResponse, TextMessage, Token, TokenizeResponse, Tokenizer, ToolCallDelta, ToolCallMessage, Url, Usage, @@ -24,9 +25,8 @@ use crate::{ ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete, ChatCompletionDelta, ChatCompletionLogprob, ChatCompletionLogprobs, ChatCompletionTopLogprob, ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal, - CompletionRequest, CompletionType, DeltaToolCall, Prompt, Tool, + CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, }; -use crate::{ChatTokenizeResponse, FunctionCall}; use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice}; use crate::{MessageBody, ModelInfo, ModelsInfo}; use async_stream::__private::AsyncStream; @@ -1117,7 +1117,6 @@ pub(crate) async fn completions( enum StreamState { Buffering, BufferTrailing, - Arguments, Content { skip_close_quote: bool }, } @@ -1127,7 +1126,6 @@ fn create_event_from_stream_token( logprobs: bool, stream_options: Option, inner_using_tools: bool, - partial_call: Option, system_fingerprint: String, model_id: String, ) -> Event { @@ -1143,16 +1141,7 @@ fn create_event_from_stream_token( // replace the content with the tool calls if grammar is present let (content, tool_calls) = if inner_using_tools { - match partial_call { - Some(partial_call) => (None, Some(partial_call)), - None => ( - None, - Some(FunctionCallChunk { - name: None, - arguments: stream_token.token.text.clone(), - }), - ), - } + (None, Some(vec![stream_token.token.text.clone()])) } else { let content = if !stream_token.token.special { Some(stream_token.token.text.clone()) @@ -1269,7 +1258,7 @@ pub(crate) async fn chat_completions( generate_stream_internal(infer, compute_type, Json(generate_request), span).await; // regex to match any function name - let function_name_regex = match Regex::new(r#"\{"function":\{"_name":"([^"]+)","#) { + let function_regex = match Regex::new(r#"\{"function":\{"_name":"([^"]+)""#) { Ok(regex) => regex, Err(e) => { return Err(( @@ -1284,6 +1273,7 @@ pub(crate) async fn chat_completions( let response_stream = async_stream::stream! { let mut response_stream = Box::pin(response_stream); + let mut buffer = Vec::new(); let mut json_buffer = String::new(); let mut state = if using_tools { StreamState::Buffering @@ -1300,27 +1290,30 @@ pub(crate) async fn chat_completions( match state { StreamState::Buffering => { json_buffer.push_str(&token_text.replace(" ", "")); - if let Some(captures) = function_name_regex.captures(&json_buffer) { + buffer.push(stream_token); + if let Some(captures) = function_regex.captures(&json_buffer) { let function_name = captures[1].to_string(); if function_name == "no_tool" { state = StreamState::BufferTrailing; response_as_tool = false; + buffer.clear(); json_buffer.clear(); } else { - state = StreamState::Arguments; - let event = create_event_from_stream_token( - &stream_token, - logprobs, - stream_options.clone(), - response_as_tool, - Some(FunctionCallChunk { - name: Some(function_name), - arguments: "{".to_string() - }), - system_fingerprint.clone(), - model_id.clone(), - ); - yield Ok::(event); + state = StreamState::Content { + skip_close_quote: false, + }; + // send all the buffered messages + for stream_token in &buffer { + let event = create_event_from_stream_token( + stream_token, + logprobs, + stream_options.clone(), + response_as_tool, + system_fingerprint.clone(), + model_id.clone(), + ); + yield Ok::(event); + } } } } @@ -1361,32 +1354,12 @@ pub(crate) async fn chat_completions( })); } // cleanup the buffers + buffer.clear(); json_buffer.clear(); state = StreamState::Content { skip_close_quote: true, }; } - StreamState::Arguments => { - json_buffer.push_str(&token_text.replace(" ", "")); - - // If we are at the end of the json we can stop - let function: Result = serde_json::from_str(&json_buffer); - if let Ok(_) = function { - break; - } - - // send the content - let event = create_event_from_stream_token( - &stream_token, - logprobs, - stream_options.clone(), - response_as_tool, - None, - system_fingerprint.clone(), - model_id.clone(), - ); - yield Ok::(event); - } StreamState::Content { skip_close_quote } => { if skip_close_quote && token_text.contains('"') { break; @@ -1398,7 +1371,6 @@ pub(crate) async fn chat_completions( logprobs, stream_options.clone(), response_as_tool, - None, system_fingerprint.clone(), model_id.clone(), ); @@ -1466,9 +1438,10 @@ pub(crate) async fn chat_completions( let tool_calls = vec![ToolCall { id: "0".to_string(), r#type: "function".to_string(), - function: FunctionCall { + function: FunctionDefinition { + description: None, name, - arguments: arguments.to_string(), + arguments, }, }]; (Some(tool_calls), None) @@ -1599,8 +1572,8 @@ StreamOptions, DeltaToolCall, Tool, ToolCall, +Function, FunctionDefinition, -FunctionCall, ToolChoice, ModelInfo, ChatTokenizeResponse,