feat: Make streaming for tool calling behave the same as the OpenAI API

The streaming API for tool calling now starts when the function name is parsed, then sends arguments as tokens are generated, and stops properly.
This commit is contained in:
Nicolas Casademont 2025-01-24 14:42:25 +01:00 committed by drbh
parent 9a9a763eee
commit 8542e2b746
4 changed files with 66 additions and 48 deletions

View File

@ -1189,7 +1189,7 @@ TOOL CALL ID: 0
let tool_prompt = "This default prompt will be used".to_string(); let tool_prompt = "This default prompt will be used".to_string();
let tools_and_prompt = Some((tools, tool_prompt)); let tools_and_prompt = Some((tools, tool_prompt));
let result = ct.apply(msgs, tools_and_prompt); let result = ct.apply(msgs, tools_and_prompt);
let expected = "<s>[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today?</s> [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"arguments\":{\"type\":\"object\",\"properties\":{\"location\":{\"type\":\"string\",\"description\":\"The city and state, e.g. San Francisco, CA\"},\"format\":{\"type\":\"string\",\"enum\":[\"celsius\",\"fahrenheit\"],\"description\":\"The temperature unit to use. Infer this from the users location.\"}},\"required\":[\"location\",\"format\"]}}}]\nThis default prompt will be used [/INST]".to_string(); let expected = "<s>[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today?</s> [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"parameters\":{\"type\":\"object\",\"properties\":{\"location\":{\"type\":\"string\",\"description\":\"The city and state, e.g. San Francisco, CA\"},\"format\":{\"type\":\"string\",\"enum\":[\"celsius\",\"fahrenheit\"],\"description\":\"The temperature unit to use. Infer this from the users location.\"}},\"required\":[\"location\",\"format\"]}}}]\nThis default prompt will be used [/INST]".to_string();
assert_eq!(result.unwrap(), expected); assert_eq!(result.unwrap(), expected);
} }
@ -1227,7 +1227,7 @@ TOOL CALL ID: 0
let tool_prompt = "This default prompt will be used".to_string(); let tool_prompt = "This default prompt will be used".to_string();
let tools_and_prompt = Some((tools, tool_prompt)); let tools_and_prompt = Some((tools, tool_prompt));
let result = ct.apply(msgs, tools_and_prompt); let result = ct.apply(msgs, tools_and_prompt);
let expected = "<s><|start_header_id|>system<|end_header_id|>\n\nEnvironment: ipython\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYoure a helpful assistant! Answer the users question best you can.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGiven the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n\nRespond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.Do not use variables.\n\n{\n \"function\": {\n \"arguments\": {\n \"properties\": {\n \"format\": {\n \"description\": \"The temperature unit to use. Infer this from the users location.\",\n \"enum\": [\n \"celsius\",\n \"fahrenheit\"\n ],\n \"type\": \"string\"\n },\n \"location\": {\n \"description\": \"The city and state, e.g. San Francisco, CA\",\n \"type\": \"string\"\n }\n },\n \"required\": [\n \"location\",\n \"format\"\n ],\n \"type\": \"object\"\n },\n \"description\": \"Get the current weather\",\n \"name\": \"get_current_weather\"\n },\n \"type\": \"function\"\n}\n\nWhat is the weather like in Brooklyn, New York?\n---\nThis default prompt will be used<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n".to_string(); let expected = "<s><|start_header_id|>system<|end_header_id|>\n\nEnvironment: ipython\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYoure a helpful assistant! 
Answer the users question best you can.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGiven the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n\nRespond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.Do not use variables.\n\n{\n \"function\": {\n \"description\": \"Get the current weather\",\n \"name\": \"get_current_weather\",\n \"parameters\": {\n \"properties\": {\n \"format\": {\n \"description\": \"The temperature unit to use. Infer this from the users location.\",\n \"enum\": [\n \"celsius\",\n \"fahrenheit\"\n ],\n \"type\": \"string\"\n },\n \"location\": {\n \"description\": \"The city and state, e.g. San Francisco, CA\",\n \"type\": \"string\"\n }\n },\n \"required\": [\n \"location\",\n \"format\"\n ],\n \"type\": \"object\"\n }\n },\n \"type\": \"function\"\n}\n\nWhat is the weather like in Brooklyn, New York?\n---\nThis default prompt will be used<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n".to_string();
assert_eq!(result.unwrap(), expected); assert_eq!(result.unwrap(), expected);
} }
} }

View File

@ -38,7 +38,7 @@ impl ToolGrammar {
description: Some( description: Some(
"Open ended response with no specific tool selected".to_string(), "Open ended response with no specific tool selected".to_string(),
), ),
arguments: json!({ parameters: json!({
"type": "object", "type": "object",
"properties": { "properties": {
"content": { "content": {
@ -83,7 +83,7 @@ impl ToolGrammar {
}), }),
); );
if let Value::Object(args) = func.arguments { if let Value::Object(args) = func.parameters {
if let Some(Value::Object(props)) = args.get("properties") { if let Some(Value::Object(props)) = args.get("properties") {
properties.extend(props.clone()); properties.extend(props.clone());
} }

View File

@ -730,7 +730,7 @@ pub(crate) struct ChatCompletionChoice {
pub struct ToolCallDelta { pub struct ToolCallDelta {
#[schema(example = "assistant")] #[schema(example = "assistant")]
role: String, role: String,
tool_calls: DeltaToolCall, tool_calls: Vec<DeltaToolCall>,
} }
#[derive(Clone, Debug, Serialize, ToSchema)] #[derive(Clone, Debug, Serialize, ToSchema)]
@ -745,11 +745,11 @@ pub(crate) struct DeltaToolCall {
pub index: u32, pub index: u32,
pub id: String, pub id: String,
pub r#type: String, pub r#type: String,
pub function: Function, pub function: FunctionCallChunk,
} }
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)] #[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)]
pub(crate) struct Function { pub(crate) struct FunctionCallChunk {
pub name: Option<String>, pub name: Option<String>,
pub arguments: String, pub arguments: String,
} }
@ -760,7 +760,7 @@ impl ChatCompletionChunk {
model: String, model: String,
system_fingerprint: String, system_fingerprint: String,
delta: Option<String>, delta: Option<String>,
tool_calls: Option<Vec<String>>, tool_calls: Option<FunctionCallChunk>,
created: u64, created: u64,
logprobs: Option<ChatCompletionLogprobs>, logprobs: Option<ChatCompletionLogprobs>,
finish_reason: Option<String>, finish_reason: Option<String>,
@ -774,15 +774,12 @@ impl ChatCompletionChunk {
}), }),
(None, Some(tool_calls)) => ChatCompletionDelta::Tool(ToolCallDelta { (None, Some(tool_calls)) => ChatCompletionDelta::Tool(ToolCallDelta {
role: "assistant".to_string(), role: "assistant".to_string(),
tool_calls: DeltaToolCall { tool_calls: vec![DeltaToolCall {
index: 0, index: 0,
id: String::new(), id: String::new(),
r#type: "function".to_string(), r#type: "function".to_string(),
function: Function { function: tool_calls,
name: None, }],
arguments: tool_calls[0].to_string(),
},
},
}), }),
(None, None) => ChatCompletionDelta::Chat(TextMessage { (None, None) => ChatCompletionDelta::Chat(TextMessage {
role: "assistant".to_string(), role: "assistant".to_string(),
@ -1138,16 +1135,12 @@ pub struct FunctionDefinition {
#[serde(default)] #[serde(default)]
pub description: Option<String>, pub description: Option<String>,
pub name: String, pub name: String,
#[serde(alias = "parameters")] pub parameters: serde_json::Value,
pub arguments: serde_json::Value,
} }
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default, PartialEq)] #[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default, PartialEq)]
pub(crate) struct FunctionCall { pub(crate) struct FunctionCall {
#[serde(default)]
pub description: Option<String>,
pub name: String, pub name: String,
#[serde(alias = "parameters")]
pub arguments: String, pub arguments: String,
} }
@ -1728,7 +1721,6 @@ mod tests {
id: "0".to_string(), id: "0".to_string(),
r#type: "function".to_string(), r#type: "function".to_string(),
function: FunctionCall { function: FunctionCall {
description: None,
name: "myfn".to_string(), name: "myfn".to_string(),
arguments: json!({ arguments: json!({
"format": "csv" "format": "csv"
@ -1740,7 +1732,7 @@ mod tests {
let serialized = serde_json::to_string(&message).unwrap(); let serialized = serde_json::to_string(&message).unwrap();
assert_eq!( assert_eq!(
serialized, serialized,
r#"{"role":"assistant","tool_calls":[{"id":"0","type":"function","function":{"description":null,"name":"myfn","arguments":"{\"format\":\"csv\"}"}}]}"# r#"{"role":"assistant","tool_calls":[{"id":"0","type":"function","function":{"name":"myfn","arguments":"{\"format\":\"csv\"}"}}]}"#
); );
} }

View File

@ -13,8 +13,8 @@ use crate::sagemaker::{
use crate::validation::ValidationError; use crate::validation::ValidationError;
use crate::vertex::vertex_compatibility; use crate::vertex::vertex_compatibility;
use crate::{ use crate::{
usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName, usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionCallChunk,
GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo, FunctionName, GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
HubProcessorConfig, HubTokenizerConfig, Info, Message, MessageChunk, MessageContent, HubProcessorConfig, HubTokenizerConfig, Info, Message, MessageChunk, MessageContent,
OutputMessage, PrefillToken, SimpleToken, StreamDetails, StreamOptions, StreamResponse, OutputMessage, PrefillToken, SimpleToken, StreamDetails, StreamOptions, StreamResponse,
TextMessage, Token, TokenizeResponse, Tokenizer, ToolCallDelta, ToolCallMessage, Url, Usage, TextMessage, Token, TokenizeResponse, Tokenizer, ToolCallDelta, ToolCallMessage, Url, Usage,
@ -24,7 +24,7 @@ use crate::{
ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete, ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
ChatCompletionDelta, ChatCompletionLogprob, ChatCompletionLogprobs, ChatCompletionTopLogprob, ChatCompletionDelta, ChatCompletionLogprob, ChatCompletionLogprobs, ChatCompletionTopLogprob,
ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal, ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal,
CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, CompletionRequest, CompletionType, DeltaToolCall, Prompt, Tool,
}; };
use crate::{ChatTokenizeResponse, FunctionCall}; use crate::{ChatTokenizeResponse, FunctionCall};
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice}; use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice};
@ -1117,6 +1117,7 @@ pub(crate) async fn completions(
enum StreamState { enum StreamState {
Buffering, Buffering,
BufferTrailing, BufferTrailing,
Arguments,
Content { skip_close_quote: bool }, Content { skip_close_quote: bool },
} }
@ -1126,6 +1127,7 @@ fn create_event_from_stream_token(
logprobs: bool, logprobs: bool,
stream_options: Option<StreamOptions>, stream_options: Option<StreamOptions>,
inner_using_tools: bool, inner_using_tools: bool,
partial_call: Option<FunctionCallChunk>,
system_fingerprint: String, system_fingerprint: String,
model_id: String, model_id: String,
) -> Event { ) -> Event {
@ -1141,7 +1143,16 @@ fn create_event_from_stream_token(
// replace the content with the tool calls if grammar is present // replace the content with the tool calls if grammar is present
let (content, tool_calls) = if inner_using_tools { let (content, tool_calls) = if inner_using_tools {
(None, Some(vec![stream_token.token.text.clone()])) match partial_call {
Some(partial_call) => (None, Some(partial_call)),
None => (
None,
Some(FunctionCallChunk {
name: None,
arguments: stream_token.token.text.clone(),
}),
),
}
} else { } else {
let content = if !stream_token.token.special { let content = if !stream_token.token.special {
Some(stream_token.token.text.clone()) Some(stream_token.token.text.clone())
@ -1258,7 +1269,7 @@ pub(crate) async fn chat_completions(
generate_stream_internal(infer, compute_type, Json(generate_request), span).await; generate_stream_internal(infer, compute_type, Json(generate_request), span).await;
// regex to match any function name // regex to match any function name
let function_regex = match Regex::new(r#"\{"function":\{"_name":"([^"]+)""#) { let function_name_regex = match Regex::new(r#"\{"function":\{"_name":"([^"]+)","#) {
Ok(regex) => regex, Ok(regex) => regex,
Err(e) => { Err(e) => {
return Err(( return Err((
@ -1273,7 +1284,6 @@ pub(crate) async fn chat_completions(
let response_stream = async_stream::stream! { let response_stream = async_stream::stream! {
let mut response_stream = Box::pin(response_stream); let mut response_stream = Box::pin(response_stream);
let mut buffer = Vec::new();
let mut json_buffer = String::new(); let mut json_buffer = String::new();
let mut state = if using_tools { let mut state = if using_tools {
StreamState::Buffering StreamState::Buffering
@ -1290,25 +1300,23 @@ pub(crate) async fn chat_completions(
match state { match state {
StreamState::Buffering => { StreamState::Buffering => {
json_buffer.push_str(&token_text.replace(" ", "")); json_buffer.push_str(&token_text.replace(" ", ""));
buffer.push(stream_token); if let Some(captures) = function_name_regex.captures(&json_buffer) {
if let Some(captures) = function_regex.captures(&json_buffer) {
let function_name = captures[1].to_string(); let function_name = captures[1].to_string();
if function_name == "no_tool" { if function_name == "no_tool" {
state = StreamState::BufferTrailing; state = StreamState::BufferTrailing;
response_as_tool = false; response_as_tool = false;
buffer.clear();
json_buffer.clear(); json_buffer.clear();
} else { } else {
state = StreamState::Content { state = StreamState::Arguments;
skip_close_quote: false,
};
// send all the buffered messages
for stream_token in &buffer {
let event = create_event_from_stream_token( let event = create_event_from_stream_token(
stream_token, &stream_token,
logprobs, logprobs,
stream_options.clone(), stream_options.clone(),
response_as_tool, response_as_tool,
Some(FunctionCallChunk {
name: Some(function_name),
arguments: "{".to_string()
}),
system_fingerprint.clone(), system_fingerprint.clone(),
model_id.clone(), model_id.clone(),
); );
@ -1316,7 +1324,6 @@ pub(crate) async fn chat_completions(
} }
} }
} }
}
// if we skipped sending the buffer we need to avoid sending the following json key and quotes // if we skipped sending the buffer we need to avoid sending the following json key and quotes
StreamState::BufferTrailing => { StreamState::BufferTrailing => {
let infix_text = "\"content\":\""; let infix_text = "\"content\":\"";
@ -1354,12 +1361,32 @@ pub(crate) async fn chat_completions(
})); }));
} }
// cleanup the buffers // cleanup the buffers
buffer.clear();
json_buffer.clear(); json_buffer.clear();
state = StreamState::Content { state = StreamState::Content {
skip_close_quote: true, skip_close_quote: true,
}; };
} }
StreamState::Arguments => {
json_buffer.push_str(&token_text.replace(" ", ""));
// If we are at the end of the json we can stop
let function: Result<serde::de::IgnoredAny, _> = serde_json::from_str(&json_buffer);
if let Ok(_) = function {
break;
}
// send the content
let event = create_event_from_stream_token(
&stream_token,
logprobs,
stream_options.clone(),
response_as_tool,
None,
system_fingerprint.clone(),
model_id.clone(),
);
yield Ok::<Event, Infallible>(event);
}
StreamState::Content { skip_close_quote } => { StreamState::Content { skip_close_quote } => {
if skip_close_quote && token_text.contains('"') { if skip_close_quote && token_text.contains('"') {
break; break;
@ -1371,6 +1398,7 @@ pub(crate) async fn chat_completions(
logprobs, logprobs,
stream_options.clone(), stream_options.clone(),
response_as_tool, response_as_tool,
None,
system_fingerprint.clone(), system_fingerprint.clone(),
model_id.clone(), model_id.clone(),
); );
@ -1439,7 +1467,6 @@ pub(crate) async fn chat_completions(
id: "0".to_string(), id: "0".to_string(),
r#type: "function".to_string(), r#type: "function".to_string(),
function: FunctionCall { function: FunctionCall {
description: None,
name, name,
arguments: arguments.to_string(), arguments: arguments.to_string(),
}, },
@ -1572,7 +1599,6 @@ StreamOptions,
DeltaToolCall, DeltaToolCall,
Tool, Tool,
ToolCall, ToolCall,
Function,
FunctionDefinition, FunctionDefinition,
FunctionCall, FunctionCall,
ToolChoice, ToolChoice,