fix: adjust streaming tool response

drbh 2025-02-11 14:22:03 +00:00
parent 5f030140be
commit dbce04e4d3
2 changed files with 28 additions and 27 deletions

@@ -765,6 +765,7 @@ impl ChatCompletionChunk {
         logprobs: Option<ChatCompletionLogprobs>,
         finish_reason: Option<String>,
         usage: Option<Usage>,
+        tool_name: Option<String>,
     ) -> Self {
         let delta = match (delta, tool_calls) {
             (Some(delta), _) => ChatCompletionDelta::Chat(TextMessage {
@@ -779,7 +780,7 @@ impl ChatCompletionChunk {
                     id: String::new(),
                     r#type: "function".to_string(),
                     function: Function {
-                        name: None,
+                        name: tool_name,
                         arguments: tool_calls[0].to_string(),
                     },
                 }],
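
With the extra tool_name argument, the streamed tool-call delta can carry the captured function name instead of a hard-coded None. A minimal sketch of the construction in the hunk above, using simplified stand-ins for the router's Function and ToolCall types (only the fields visible in the diff are reproduced; the example tool name is made up):

    // Simplified stand-ins; the real router types carry more fields.
    #[derive(Debug)]
    struct Function {
        name: Option<String>,
        arguments: String,
    }

    #[derive(Debug)]
    struct ToolCall {
        id: String,
        r#type: String,
        function: Function,
    }

    // Mirrors the construction above: one tool call per chunk, with the current
    // token text as the next fragment of the JSON arguments.
    fn tool_call_delta(tool_name: Option<String>, token_text: String) -> ToolCall {
        ToolCall {
            id: String::new(),
            r#type: "function".to_string(),
            function: Function {
                name: tool_name, // previously always None
                arguments: token_text,
            },
        }
    }

    fn main() {
        // "get_weather" is only an illustrative tool name.
        let call = tool_call_delta(Some("get_weather".to_string()), "{\"loc".to_string());
        println!("{call:?}");
    }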
@@ -1364,7 +1365,7 @@ pub struct SimpleToken {
     stop: usize,
 }

-#[derive(Debug, Serialize, ToSchema, Clone)]
+#[derive(Debug, Serialize, ToSchema, Clone, PartialEq)]
 #[serde(rename_all(serialize = "snake_case"))]
 #[schema(example = "Length")]
 pub enum FinishReason {

@@ -1128,6 +1128,7 @@ fn create_event_from_stream_token(
     inner_using_tools: bool,
     system_fingerprint: String,
     model_id: String,
+    tool_name: Option<String>,
 ) -> Event {
     let event = Event::default();
     let current_time = std::time::SystemTime::now()
@@ -1141,7 +1142,9 @@ fn create_event_from_stream_token(
     // replace the content with the tool calls if grammar is present
     let (content, tool_calls) = if inner_using_tools {
-        (None, Some(vec![stream_token.token.text.clone()]))
+        // escape the token text so its a json string
+        let escaped_text = stream_token.token.text.replace(r#"""#, r#"\""#);
+        (None, Some(vec![escaped_text]))
     } else {
         let content = if !stream_token.token.special {
             Some(stream_token.token.text.clone())
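
Because the raw token text is now forwarded as a fragment of the tool-call arguments, any double quote inside a token has to be escaped so the fragment remains a valid piece of a JSON string. A small sketch of the same replace call on a made-up token:

    fn main() {
        // Hypothetical token emitted while streaming tool arguments.
        let token_text = r#""location":"#;

        // Same escaping as in the hunk above: every " becomes \" so the fragment
        // can be embedded in the JSON string sent to the client.
        let escaped_text = token_text.replace(r#"""#, r#"\""#);

        assert_eq!(escaped_text, r#"\"location\":"#);
        println!("{escaped_text}");
    }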
@@ -1184,6 +1187,7 @@ fn create_event_from_stream_token(
         logprobs,
         finish_reason,
         usage,
+        tool_name,
     ));

     event.json_data(chat_complete).unwrap_or_else(|e| {
@@ -1283,6 +1287,7 @@ pub(crate) async fn chat_completions(
             }
         };
         let mut response_as_tool = using_tools;
+        let mut global_function_name = String::new();
         while let Some(result) = response_stream.next().await {
             match result{
                 Ok(stream_token) => {
@@ -1292,8 +1297,8 @@ pub(crate) async fn chat_completions(
                         json_buffer.push_str(&token_text.replace(" ", ""));
                         buffer.push(stream_token);
                         if let Some(captures) = function_regex.captures(&json_buffer) {
-                            let function_name = captures[1].to_string();
-                            if function_name == "no_tool" {
+                            global_function_name = captures[1].to_string();
+                            if global_function_name == "no_tool" {
                                 state = StreamState::BufferTrailing;
                                 response_as_tool = false;
                                 buffer.clear();
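
The captured group is now stored in the longer-lived global_function_name so it can be attached to every subsequent chunk. The definition of function_regex sits outside this hunk; the sketch below assumes a pattern that pulls the tool name out of the whitespace-stripped JSON buffer (both the pattern and the buffer contents are illustrative, not the router's actual values):

    use regex::Regex;

    fn main() {
        // Assumed pattern: captures the value of a "_name"-style key from the
        // buffered, space-stripped JSON prefix.
        let function_regex = Regex::new(r#""_name":"([^"]+)""#).unwrap();
        let json_buffer = r#"{"function":{"_name":"get_weather","location":"#;

        if let Some(captures) = function_regex.captures(json_buffer) {
            let global_function_name = captures[1].to_string();
            assert_eq!(global_function_name, "get_weather");
            println!("captured tool name: {global_function_name}");
        }
    }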
@@ -1302,18 +1307,7 @@ pub(crate) async fn chat_completions(
                                 state = StreamState::Content {
                                     skip_close_quote: false,
                                 };
-                                // send all the buffered messages
-                                for stream_token in &buffer {
-                                    let event = create_event_from_stream_token(
-                                        stream_token,
-                                        logprobs,
-                                        stream_options.clone(),
-                                        response_as_tool,
-                                        system_fingerprint.clone(),
-                                        model_id.clone(),
-                                    );
-                                    yield Ok::<Event, Infallible>(event);
-                                }
+                                buffer = buffer.drain(0..1).collect();
                             }
                         }
                     }
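
buffer = buffer.drain(0..1).collect(); replaces the buffer with just its first element, and the mirror-image buffer.drain(buffer.len() - 2..) used later in this commit keeps only the last two elements (the closing '}' and the EOS token). A small sketch of both idioms on a plain Vec<String> (the token values are made up):

    fn main() {
        // Stand-in for the stream-token buffer; the real buffer holds stream
        // responses, but strings are enough to show the idiom.
        let mut buffer: Vec<String> = vec!["{\"f".into(), "unc".into(), "tion\"".into()];

        // Keep only the first buffered element, as in the StreamState::Content transition.
        buffer = buffer.drain(0..1).collect();
        assert_eq!(buffer, vec!["{\"f".to_string()]);

        // Keep only the last two elements, as in the FIFO send loop at the end of the stream.
        let mut buffer: Vec<String> =
            vec!["\"Par".into(), "is\"".into(), "}".into(), "<eos>".into()];
        buffer = buffer.drain(buffer.len() - 2..).collect();
        assert_eq!(buffer, vec!["}".to_string(), "<eos>".to_string()]);

        println!("{buffer:?}");
    }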
@@ -1348,6 +1342,7 @@ pub(crate) async fn chat_completions(
                         None,
                         None,
                         None,
+                        Some(global_function_name.clone()),
                     ));
                     yield Ok(event.json_data(chat_complete).unwrap_or_else(|e| {
                         InferError::StreamSerializationError(e.to_string()).into()
@@ -1365,18 +1360,23 @@ pub(crate) async fn chat_completions(
                             break;
                         }
-                        // send the content
+                        buffer.push(stream_token);
+                        // FIFO send the buffer but left the last two elements (closing '}' and EOS token)
+                        for stream_token in &buffer[..buffer.len() - 2] {
                             let event = create_event_from_stream_token(
-                                &stream_token,
+                                stream_token,
                                 logprobs,
                                 stream_options.clone(),
                                 response_as_tool,
                                 system_fingerprint.clone(),
                                 model_id.clone(),
+                                Some(global_function_name.clone()),
                             );
                             yield Ok::<Event, Infallible>(event);
                         }
+                        buffer = buffer.drain(buffer.len() - 2..).collect();
+                    }
                     }
                 }
                 Err(err) => yield Ok(err.into_openai_event())