feat: Make streaming for tool calling behave the same as the OpenAI API

The streaming API for tool calling now starts when the function name is parsed, then sends arguments as tokens are generated, and stops properly at the end of the JSON payload.
This commit is contained in:
Nicolas Casademont 2025-01-24 14:42:25 +01:00 committed by drbh
parent 9a9a763eee
commit 8542e2b746
4 changed files with 66 additions and 48 deletions

View File

@ -1189,7 +1189,7 @@ TOOL CALL ID: 0
let tool_prompt = "This default prompt will be used".to_string();
let tools_and_prompt = Some((tools, tool_prompt));
let result = ct.apply(msgs, tools_and_prompt);
let expected = "<s>[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today?</s> [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"arguments\":{\"type\":\"object\",\"properties\":{\"location\":{\"type\":\"string\",\"description\":\"The city and state, e.g. San Francisco, CA\"},\"format\":{\"type\":\"string\",\"enum\":[\"celsius\",\"fahrenheit\"],\"description\":\"The temperature unit to use. Infer this from the users location.\"}},\"required\":[\"location\",\"format\"]}}}]\nThis default prompt will be used [/INST]".to_string();
let expected = "<s>[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today?</s> [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"parameters\":{\"type\":\"object\",\"properties\":{\"location\":{\"type\":\"string\",\"description\":\"The city and state, e.g. San Francisco, CA\"},\"format\":{\"type\":\"string\",\"enum\":[\"celsius\",\"fahrenheit\"],\"description\":\"The temperature unit to use. Infer this from the users location.\"}},\"required\":[\"location\",\"format\"]}}}]\nThis default prompt will be used [/INST]".to_string();
assert_eq!(result.unwrap(), expected);
}
@ -1227,7 +1227,7 @@ TOOL CALL ID: 0
let tool_prompt = "This default prompt will be used".to_string();
let tools_and_prompt = Some((tools, tool_prompt));
let result = ct.apply(msgs, tools_and_prompt);
let expected = "<s><|start_header_id|>system<|end_header_id|>\n\nEnvironment: ipython\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYoure a helpful assistant! Answer the users question best you can.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGiven the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n\nRespond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.Do not use variables.\n\n{\n \"function\": {\n \"arguments\": {\n \"properties\": {\n \"format\": {\n \"description\": \"The temperature unit to use. Infer this from the users location.\",\n \"enum\": [\n \"celsius\",\n \"fahrenheit\"\n ],\n \"type\": \"string\"\n },\n \"location\": {\n \"description\": \"The city and state, e.g. San Francisco, CA\",\n \"type\": \"string\"\n }\n },\n \"required\": [\n \"location\",\n \"format\"\n ],\n \"type\": \"object\"\n },\n \"description\": \"Get the current weather\",\n \"name\": \"get_current_weather\"\n },\n \"type\": \"function\"\n}\n\nWhat is the weather like in Brooklyn, New York?\n---\nThis default prompt will be used<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n".to_string();
let expected = "<s><|start_header_id|>system<|end_header_id|>\n\nEnvironment: ipython\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYoure a helpful assistant! Answer the users question best you can.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGiven the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n\nRespond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.Do not use variables.\n\n{\n \"function\": {\n \"description\": \"Get the current weather\",\n \"name\": \"get_current_weather\",\n \"parameters\": {\n \"properties\": {\n \"format\": {\n \"description\": \"The temperature unit to use. Infer this from the users location.\",\n \"enum\": [\n \"celsius\",\n \"fahrenheit\"\n ],\n \"type\": \"string\"\n },\n \"location\": {\n \"description\": \"The city and state, e.g. San Francisco, CA\",\n \"type\": \"string\"\n }\n },\n \"required\": [\n \"location\",\n \"format\"\n ],\n \"type\": \"object\"\n }\n },\n \"type\": \"function\"\n}\n\nWhat is the weather like in Brooklyn, New York?\n---\nThis default prompt will be used<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n".to_string();
assert_eq!(result.unwrap(), expected);
}
}

View File

@ -38,7 +38,7 @@ impl ToolGrammar {
description: Some(
"Open ended response with no specific tool selected".to_string(),
),
arguments: json!({
parameters: json!({
"type": "object",
"properties": {
"content": {
@ -83,7 +83,7 @@ impl ToolGrammar {
}),
);
if let Value::Object(args) = func.arguments {
if let Value::Object(args) = func.parameters {
if let Some(Value::Object(props)) = args.get("properties") {
properties.extend(props.clone());
}

View File

@ -730,7 +730,7 @@ pub(crate) struct ChatCompletionChoice {
pub struct ToolCallDelta {
#[schema(example = "assistant")]
role: String,
tool_calls: DeltaToolCall,
tool_calls: Vec<DeltaToolCall>,
}
#[derive(Clone, Debug, Serialize, ToSchema)]
@ -745,11 +745,11 @@ pub(crate) struct DeltaToolCall {
pub index: u32,
pub id: String,
pub r#type: String,
pub function: Function,
pub function: FunctionCallChunk,
}
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)]
pub(crate) struct Function {
pub(crate) struct FunctionCallChunk {
pub name: Option<String>,
pub arguments: String,
}
@ -760,7 +760,7 @@ impl ChatCompletionChunk {
model: String,
system_fingerprint: String,
delta: Option<String>,
tool_calls: Option<Vec<String>>,
tool_calls: Option<FunctionCallChunk>,
created: u64,
logprobs: Option<ChatCompletionLogprobs>,
finish_reason: Option<String>,
@ -774,15 +774,12 @@ impl ChatCompletionChunk {
}),
(None, Some(tool_calls)) => ChatCompletionDelta::Tool(ToolCallDelta {
role: "assistant".to_string(),
tool_calls: DeltaToolCall {
tool_calls: vec![DeltaToolCall {
index: 0,
id: String::new(),
r#type: "function".to_string(),
function: Function {
name: None,
arguments: tool_calls[0].to_string(),
},
},
function: tool_calls,
}],
}),
(None, None) => ChatCompletionDelta::Chat(TextMessage {
role: "assistant".to_string(),
@ -1138,16 +1135,12 @@ pub struct FunctionDefinition {
#[serde(default)]
pub description: Option<String>,
pub name: String,
#[serde(alias = "parameters")]
pub arguments: serde_json::Value,
pub parameters: serde_json::Value,
}
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default, PartialEq)]
pub(crate) struct FunctionCall {
#[serde(default)]
pub description: Option<String>,
pub name: String,
#[serde(alias = "parameters")]
pub arguments: String,
}
@ -1728,7 +1721,6 @@ mod tests {
id: "0".to_string(),
r#type: "function".to_string(),
function: FunctionCall {
description: None,
name: "myfn".to_string(),
arguments: json!({
"format": "csv"
@ -1740,7 +1732,7 @@ mod tests {
let serialized = serde_json::to_string(&message).unwrap();
assert_eq!(
serialized,
r#"{"role":"assistant","tool_calls":[{"id":"0","type":"function","function":{"description":null,"name":"myfn","arguments":"{\"format\":\"csv\"}"}}]}"#
r#"{"role":"assistant","tool_calls":[{"id":"0","type":"function","function":{"name":"myfn","arguments":"{\"format\":\"csv\"}"}}]}"#
);
}

View File

@ -13,8 +13,8 @@ use crate::sagemaker::{
use crate::validation::ValidationError;
use crate::vertex::vertex_compatibility;
use crate::{
usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName,
GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionCallChunk,
FunctionName, GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
HubProcessorConfig, HubTokenizerConfig, Info, Message, MessageChunk, MessageContent,
OutputMessage, PrefillToken, SimpleToken, StreamDetails, StreamOptions, StreamResponse,
TextMessage, Token, TokenizeResponse, Tokenizer, ToolCallDelta, ToolCallMessage, Url, Usage,
@ -24,7 +24,7 @@ use crate::{
ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
ChatCompletionDelta, ChatCompletionLogprob, ChatCompletionLogprobs, ChatCompletionTopLogprob,
ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal,
CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool,
CompletionRequest, CompletionType, DeltaToolCall, Prompt, Tool,
};
use crate::{ChatTokenizeResponse, FunctionCall};
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice};
@ -1117,6 +1117,7 @@ pub(crate) async fn completions(
enum StreamState {
Buffering,
BufferTrailing,
Arguments,
Content { skip_close_quote: bool },
}
@ -1126,6 +1127,7 @@ fn create_event_from_stream_token(
logprobs: bool,
stream_options: Option<StreamOptions>,
inner_using_tools: bool,
partial_call: Option<FunctionCallChunk>,
system_fingerprint: String,
model_id: String,
) -> Event {
@ -1141,7 +1143,16 @@ fn create_event_from_stream_token(
// replace the content with the tool calls if grammar is present
let (content, tool_calls) = if inner_using_tools {
(None, Some(vec![stream_token.token.text.clone()]))
match partial_call {
Some(partial_call) => (None, Some(partial_call)),
None => (
None,
Some(FunctionCallChunk {
name: None,
arguments: stream_token.token.text.clone(),
}),
),
}
} else {
let content = if !stream_token.token.special {
Some(stream_token.token.text.clone())
@ -1258,7 +1269,7 @@ pub(crate) async fn chat_completions(
generate_stream_internal(infer, compute_type, Json(generate_request), span).await;
// regex to match any function name
let function_regex = match Regex::new(r#"\{"function":\{"_name":"([^"]+)""#) {
let function_name_regex = match Regex::new(r#"\{"function":\{"_name":"([^"]+)","#) {
Ok(regex) => regex,
Err(e) => {
return Err((
@ -1273,7 +1284,6 @@ pub(crate) async fn chat_completions(
let response_stream = async_stream::stream! {
let mut response_stream = Box::pin(response_stream);
let mut buffer = Vec::new();
let mut json_buffer = String::new();
let mut state = if using_tools {
StreamState::Buffering
@ -1290,30 +1300,27 @@ pub(crate) async fn chat_completions(
match state {
StreamState::Buffering => {
json_buffer.push_str(&token_text.replace(" ", ""));
buffer.push(stream_token);
if let Some(captures) = function_regex.captures(&json_buffer) {
if let Some(captures) = function_name_regex.captures(&json_buffer) {
let function_name = captures[1].to_string();
if function_name == "no_tool" {
state = StreamState::BufferTrailing;
response_as_tool = false;
buffer.clear();
json_buffer.clear();
} else {
state = StreamState::Content {
skip_close_quote: false,
};
// send all the buffered messages
for stream_token in &buffer {
let event = create_event_from_stream_token(
stream_token,
logprobs,
stream_options.clone(),
response_as_tool,
system_fingerprint.clone(),
model_id.clone(),
);
yield Ok::<Event, Infallible>(event);
}
state = StreamState::Arguments;
let event = create_event_from_stream_token(
&stream_token,
logprobs,
stream_options.clone(),
response_as_tool,
Some(FunctionCallChunk {
name: Some(function_name),
arguments: "{".to_string()
}),
system_fingerprint.clone(),
model_id.clone(),
);
yield Ok::<Event, Infallible>(event);
}
}
}
@ -1354,12 +1361,32 @@ pub(crate) async fn chat_completions(
}));
}
// cleanup the buffers
buffer.clear();
json_buffer.clear();
state = StreamState::Content {
skip_close_quote: true,
};
}
StreamState::Arguments => {
json_buffer.push_str(&token_text.replace(" ", ""));
// If we are at the end of the json we can stop
let function: Result<serde::de::IgnoredAny, _> = serde_json::from_str(&json_buffer);
if let Ok(_) = function {
break;
}
// send the content
let event = create_event_from_stream_token(
&stream_token,
logprobs,
stream_options.clone(),
response_as_tool,
None,
system_fingerprint.clone(),
model_id.clone(),
);
yield Ok::<Event, Infallible>(event);
}
StreamState::Content { skip_close_quote } => {
if skip_close_quote && token_text.contains('"') {
break;
@ -1371,6 +1398,7 @@ pub(crate) async fn chat_completions(
logprobs,
stream_options.clone(),
response_as_tool,
None,
system_fingerprint.clone(),
model_id.clone(),
);
@ -1439,7 +1467,6 @@ pub(crate) async fn chat_completions(
id: "0".to_string(),
r#type: "function".to_string(),
function: FunctionCall {
description: None,
name,
arguments: arguments.to_string(),
},
@ -1572,7 +1599,6 @@ StreamOptions,
DeltaToolCall,
Tool,
ToolCall,
Function,
FunctionDefinition,
FunctionCall,
ToolChoice,