mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-26 20:42:06 +00:00
feat: Make streaming for tool calling behave the same as the OpenAI API
The streaming API for tool calling now starts when the function name is parsed, then sends the arguments as tokens are generated, and stops properly.
This commit is contained in:
parent
9a9a763eee
commit
8542e2b746
@ -1189,7 +1189,7 @@ TOOL CALL ID: 0
|
||||
let tool_prompt = "This default prompt will be used".to_string();
|
||||
let tools_and_prompt = Some((tools, tool_prompt));
|
||||
let result = ct.apply(msgs, tools_and_prompt);
|
||||
let expected = "<s>[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today?</s> [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"arguments\":{\"type\":\"object\",\"properties\":{\"location\":{\"type\":\"string\",\"description\":\"The city and state, e.g. San Francisco, CA\"},\"format\":{\"type\":\"string\",\"enum\":[\"celsius\",\"fahrenheit\"],\"description\":\"The temperature unit to use. Infer this from the users location.\"}},\"required\":[\"location\",\"format\"]}}}]\nThis default prompt will be used [/INST]".to_string();
|
||||
let expected = "<s>[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today?</s> [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"parameters\":{\"type\":\"object\",\"properties\":{\"location\":{\"type\":\"string\",\"description\":\"The city and state, e.g. San Francisco, CA\"},\"format\":{\"type\":\"string\",\"enum\":[\"celsius\",\"fahrenheit\"],\"description\":\"The temperature unit to use. Infer this from the users location.\"}},\"required\":[\"location\",\"format\"]}}}]\nThis default prompt will be used [/INST]".to_string();
|
||||
assert_eq!(result.unwrap(), expected);
|
||||
}
|
||||
|
||||
@ -1227,7 +1227,7 @@ TOOL CALL ID: 0
|
||||
let tool_prompt = "This default prompt will be used".to_string();
|
||||
let tools_and_prompt = Some((tools, tool_prompt));
|
||||
let result = ct.apply(msgs, tools_and_prompt);
|
||||
let expected = "<s><|start_header_id|>system<|end_header_id|>\n\nEnvironment: ipython\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYoure a helpful assistant! Answer the users question best you can.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGiven the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n\nRespond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.Do not use variables.\n\n{\n \"function\": {\n \"arguments\": {\n \"properties\": {\n \"format\": {\n \"description\": \"The temperature unit to use. Infer this from the users location.\",\n \"enum\": [\n \"celsius\",\n \"fahrenheit\"\n ],\n \"type\": \"string\"\n },\n \"location\": {\n \"description\": \"The city and state, e.g. San Francisco, CA\",\n \"type\": \"string\"\n }\n },\n \"required\": [\n \"location\",\n \"format\"\n ],\n \"type\": \"object\"\n },\n \"description\": \"Get the current weather\",\n \"name\": \"get_current_weather\"\n },\n \"type\": \"function\"\n}\n\nWhat is the weather like in Brooklyn, New York?\n---\nThis default prompt will be used<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n".to_string();
|
||||
let expected = "<s><|start_header_id|>system<|end_header_id|>\n\nEnvironment: ipython\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYoure a helpful assistant! Answer the users question best you can.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGiven the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n\nRespond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.Do not use variables.\n\n{\n \"function\": {\n \"description\": \"Get the current weather\",\n \"name\": \"get_current_weather\",\n \"parameters\": {\n \"properties\": {\n \"format\": {\n \"description\": \"The temperature unit to use. Infer this from the users location.\",\n \"enum\": [\n \"celsius\",\n \"fahrenheit\"\n ],\n \"type\": \"string\"\n },\n \"location\": {\n \"description\": \"The city and state, e.g. San Francisco, CA\",\n \"type\": \"string\"\n }\n },\n \"required\": [\n \"location\",\n \"format\"\n ],\n \"type\": \"object\"\n }\n },\n \"type\": \"function\"\n}\n\nWhat is the weather like in Brooklyn, New York?\n---\nThis default prompt will be used<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n".to_string();
|
||||
assert_eq!(result.unwrap(), expected);
|
||||
}
|
||||
}
|
||||
|
@ -38,7 +38,7 @@ impl ToolGrammar {
|
||||
description: Some(
|
||||
"Open ended response with no specific tool selected".to_string(),
|
||||
),
|
||||
arguments: json!({
|
||||
parameters: json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"content": {
|
||||
@ -83,7 +83,7 @@ impl ToolGrammar {
|
||||
}),
|
||||
);
|
||||
|
||||
if let Value::Object(args) = func.arguments {
|
||||
if let Value::Object(args) = func.parameters {
|
||||
if let Some(Value::Object(props)) = args.get("properties") {
|
||||
properties.extend(props.clone());
|
||||
}
|
||||
|
@ -730,7 +730,7 @@ pub(crate) struct ChatCompletionChoice {
|
||||
pub struct ToolCallDelta {
|
||||
#[schema(example = "assistant")]
|
||||
role: String,
|
||||
tool_calls: DeltaToolCall,
|
||||
tool_calls: Vec<DeltaToolCall>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, ToSchema)]
|
||||
@ -745,11 +745,11 @@ pub(crate) struct DeltaToolCall {
|
||||
pub index: u32,
|
||||
pub id: String,
|
||||
pub r#type: String,
|
||||
pub function: Function,
|
||||
pub function: FunctionCallChunk,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)]
|
||||
pub(crate) struct Function {
|
||||
pub(crate) struct FunctionCallChunk {
|
||||
pub name: Option<String>,
|
||||
pub arguments: String,
|
||||
}
|
||||
@ -760,7 +760,7 @@ impl ChatCompletionChunk {
|
||||
model: String,
|
||||
system_fingerprint: String,
|
||||
delta: Option<String>,
|
||||
tool_calls: Option<Vec<String>>,
|
||||
tool_calls: Option<FunctionCallChunk>,
|
||||
created: u64,
|
||||
logprobs: Option<ChatCompletionLogprobs>,
|
||||
finish_reason: Option<String>,
|
||||
@ -774,15 +774,12 @@ impl ChatCompletionChunk {
|
||||
}),
|
||||
(None, Some(tool_calls)) => ChatCompletionDelta::Tool(ToolCallDelta {
|
||||
role: "assistant".to_string(),
|
||||
tool_calls: DeltaToolCall {
|
||||
tool_calls: vec![DeltaToolCall {
|
||||
index: 0,
|
||||
id: String::new(),
|
||||
r#type: "function".to_string(),
|
||||
function: Function {
|
||||
name: None,
|
||||
arguments: tool_calls[0].to_string(),
|
||||
},
|
||||
},
|
||||
function: tool_calls,
|
||||
}],
|
||||
}),
|
||||
(None, None) => ChatCompletionDelta::Chat(TextMessage {
|
||||
role: "assistant".to_string(),
|
||||
@ -1138,16 +1135,12 @@ pub struct FunctionDefinition {
|
||||
#[serde(default)]
|
||||
pub description: Option<String>,
|
||||
pub name: String,
|
||||
#[serde(alias = "parameters")]
|
||||
pub arguments: serde_json::Value,
|
||||
pub parameters: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default, PartialEq)]
|
||||
pub(crate) struct FunctionCall {
|
||||
#[serde(default)]
|
||||
pub description: Option<String>,
|
||||
pub name: String,
|
||||
#[serde(alias = "parameters")]
|
||||
pub arguments: String,
|
||||
}
|
||||
|
||||
@ -1728,7 +1721,6 @@ mod tests {
|
||||
id: "0".to_string(),
|
||||
r#type: "function".to_string(),
|
||||
function: FunctionCall {
|
||||
description: None,
|
||||
name: "myfn".to_string(),
|
||||
arguments: json!({
|
||||
"format": "csv"
|
||||
@ -1740,7 +1732,7 @@ mod tests {
|
||||
let serialized = serde_json::to_string(&message).unwrap();
|
||||
assert_eq!(
|
||||
serialized,
|
||||
r#"{"role":"assistant","tool_calls":[{"id":"0","type":"function","function":{"description":null,"name":"myfn","arguments":"{\"format\":\"csv\"}"}}]}"#
|
||||
r#"{"role":"assistant","tool_calls":[{"id":"0","type":"function","function":{"name":"myfn","arguments":"{\"format\":\"csv\"}"}}]}"#
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -13,8 +13,8 @@ use crate::sagemaker::{
|
||||
use crate::validation::ValidationError;
|
||||
use crate::vertex::vertex_compatibility;
|
||||
use crate::{
|
||||
usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName,
|
||||
GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
|
||||
usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionCallChunk,
|
||||
FunctionName, GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
|
||||
HubProcessorConfig, HubTokenizerConfig, Info, Message, MessageChunk, MessageContent,
|
||||
OutputMessage, PrefillToken, SimpleToken, StreamDetails, StreamOptions, StreamResponse,
|
||||
TextMessage, Token, TokenizeResponse, Tokenizer, ToolCallDelta, ToolCallMessage, Url, Usage,
|
||||
@ -24,7 +24,7 @@ use crate::{
|
||||
ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
|
||||
ChatCompletionDelta, ChatCompletionLogprob, ChatCompletionLogprobs, ChatCompletionTopLogprob,
|
||||
ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal,
|
||||
CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool,
|
||||
CompletionRequest, CompletionType, DeltaToolCall, Prompt, Tool,
|
||||
};
|
||||
use crate::{ChatTokenizeResponse, FunctionCall};
|
||||
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice};
|
||||
@ -1117,6 +1117,7 @@ pub(crate) async fn completions(
|
||||
enum StreamState {
|
||||
Buffering,
|
||||
BufferTrailing,
|
||||
Arguments,
|
||||
Content { skip_close_quote: bool },
|
||||
}
|
||||
|
||||
@ -1126,6 +1127,7 @@ fn create_event_from_stream_token(
|
||||
logprobs: bool,
|
||||
stream_options: Option<StreamOptions>,
|
||||
inner_using_tools: bool,
|
||||
partial_call: Option<FunctionCallChunk>,
|
||||
system_fingerprint: String,
|
||||
model_id: String,
|
||||
) -> Event {
|
||||
@ -1141,7 +1143,16 @@ fn create_event_from_stream_token(
|
||||
|
||||
// replace the content with the tool calls if grammar is present
|
||||
let (content, tool_calls) = if inner_using_tools {
|
||||
(None, Some(vec![stream_token.token.text.clone()]))
|
||||
match partial_call {
|
||||
Some(partial_call) => (None, Some(partial_call)),
|
||||
None => (
|
||||
None,
|
||||
Some(FunctionCallChunk {
|
||||
name: None,
|
||||
arguments: stream_token.token.text.clone(),
|
||||
}),
|
||||
),
|
||||
}
|
||||
} else {
|
||||
let content = if !stream_token.token.special {
|
||||
Some(stream_token.token.text.clone())
|
||||
@ -1258,7 +1269,7 @@ pub(crate) async fn chat_completions(
|
||||
generate_stream_internal(infer, compute_type, Json(generate_request), span).await;
|
||||
|
||||
// regex to match any function name
|
||||
let function_regex = match Regex::new(r#"\{"function":\{"_name":"([^"]+)""#) {
|
||||
let function_name_regex = match Regex::new(r#"\{"function":\{"_name":"([^"]+)","#) {
|
||||
Ok(regex) => regex,
|
||||
Err(e) => {
|
||||
return Err((
|
||||
@ -1273,7 +1284,6 @@ pub(crate) async fn chat_completions(
|
||||
|
||||
let response_stream = async_stream::stream! {
|
||||
let mut response_stream = Box::pin(response_stream);
|
||||
let mut buffer = Vec::new();
|
||||
let mut json_buffer = String::new();
|
||||
let mut state = if using_tools {
|
||||
StreamState::Buffering
|
||||
@ -1290,30 +1300,27 @@ pub(crate) async fn chat_completions(
|
||||
match state {
|
||||
StreamState::Buffering => {
|
||||
json_buffer.push_str(&token_text.replace(" ", ""));
|
||||
buffer.push(stream_token);
|
||||
if let Some(captures) = function_regex.captures(&json_buffer) {
|
||||
if let Some(captures) = function_name_regex.captures(&json_buffer) {
|
||||
let function_name = captures[1].to_string();
|
||||
if function_name == "no_tool" {
|
||||
state = StreamState::BufferTrailing;
|
||||
response_as_tool = false;
|
||||
buffer.clear();
|
||||
json_buffer.clear();
|
||||
} else {
|
||||
state = StreamState::Content {
|
||||
skip_close_quote: false,
|
||||
};
|
||||
// send all the buffered messages
|
||||
for stream_token in &buffer {
|
||||
let event = create_event_from_stream_token(
|
||||
stream_token,
|
||||
logprobs,
|
||||
stream_options.clone(),
|
||||
response_as_tool,
|
||||
system_fingerprint.clone(),
|
||||
model_id.clone(),
|
||||
);
|
||||
yield Ok::<Event, Infallible>(event);
|
||||
}
|
||||
state = StreamState::Arguments;
|
||||
let event = create_event_from_stream_token(
|
||||
&stream_token,
|
||||
logprobs,
|
||||
stream_options.clone(),
|
||||
response_as_tool,
|
||||
Some(FunctionCallChunk {
|
||||
name: Some(function_name),
|
||||
arguments: "{".to_string()
|
||||
}),
|
||||
system_fingerprint.clone(),
|
||||
model_id.clone(),
|
||||
);
|
||||
yield Ok::<Event, Infallible>(event);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1354,12 +1361,32 @@ pub(crate) async fn chat_completions(
|
||||
}));
|
||||
}
|
||||
// cleanup the buffers
|
||||
buffer.clear();
|
||||
json_buffer.clear();
|
||||
state = StreamState::Content {
|
||||
skip_close_quote: true,
|
||||
};
|
||||
}
|
||||
StreamState::Arguments => {
|
||||
json_buffer.push_str(&token_text.replace(" ", ""));
|
||||
|
||||
// If we are at the end of the json we can stop
|
||||
let function: Result<serde::de::IgnoredAny, _> = serde_json::from_str(&json_buffer);
|
||||
if let Ok(_) = function {
|
||||
break;
|
||||
}
|
||||
|
||||
// send the content
|
||||
let event = create_event_from_stream_token(
|
||||
&stream_token,
|
||||
logprobs,
|
||||
stream_options.clone(),
|
||||
response_as_tool,
|
||||
None,
|
||||
system_fingerprint.clone(),
|
||||
model_id.clone(),
|
||||
);
|
||||
yield Ok::<Event, Infallible>(event);
|
||||
}
|
||||
StreamState::Content { skip_close_quote } => {
|
||||
if skip_close_quote && token_text.contains('"') {
|
||||
break;
|
||||
@ -1371,6 +1398,7 @@ pub(crate) async fn chat_completions(
|
||||
logprobs,
|
||||
stream_options.clone(),
|
||||
response_as_tool,
|
||||
None,
|
||||
system_fingerprint.clone(),
|
||||
model_id.clone(),
|
||||
);
|
||||
@ -1439,7 +1467,6 @@ pub(crate) async fn chat_completions(
|
||||
id: "0".to_string(),
|
||||
r#type: "function".to_string(),
|
||||
function: FunctionCall {
|
||||
description: None,
|
||||
name,
|
||||
arguments: arguments.to_string(),
|
||||
},
|
||||
@ -1572,7 +1599,6 @@ StreamOptions,
|
||||
DeltaToolCall,
|
||||
Tool,
|
||||
ToolCall,
|
||||
Function,
|
||||
FunctionDefinition,
|
||||
FunctionCall,
|
||||
ToolChoice,
|
||||
|
Loading…
Reference in New Issue
Block a user