fix: adjust streaming tool response

This commit is contained in:
drbh 2025-02-11 14:22:03 +00:00
parent 5f030140be
commit dbce04e4d3
2 changed files with 28 additions and 27 deletions

View File

@ -765,6 +765,7 @@ impl ChatCompletionChunk {
logprobs: Option<ChatCompletionLogprobs>,
finish_reason: Option<String>,
usage: Option<Usage>,
tool_name: Option<String>,
) -> Self {
let delta = match (delta, tool_calls) {
(Some(delta), _) => ChatCompletionDelta::Chat(TextMessage {
@ -779,7 +780,7 @@ impl ChatCompletionChunk {
id: String::new(),
r#type: "function".to_string(),
function: Function {
name: None,
name: tool_name,
arguments: tool_calls[0].to_string(),
},
}],
@ -1364,7 +1365,7 @@ pub struct SimpleToken {
stop: usize,
}
#[derive(Debug, Serialize, ToSchema, Clone)]
#[derive(Debug, Serialize, ToSchema, Clone, PartialEq)]
#[serde(rename_all(serialize = "snake_case"))]
#[schema(example = "Length")]
pub enum FinishReason {

View File

@ -1128,6 +1128,7 @@ fn create_event_from_stream_token(
inner_using_tools: bool,
system_fingerprint: String,
model_id: String,
tool_name: Option<String>,
) -> Event {
let event = Event::default();
let current_time = std::time::SystemTime::now()
@ -1141,7 +1142,9 @@ fn create_event_from_stream_token(
// replace the content with the tool calls if grammar is present
let (content, tool_calls) = if inner_using_tools {
(None, Some(vec![stream_token.token.text.clone()]))
// escape the token text so its a json string
let escaped_text = stream_token.token.text.replace(r#"""#, r#"\""#);
(None, Some(vec![escaped_text]))
} else {
let content = if !stream_token.token.special {
Some(stream_token.token.text.clone())
@ -1184,6 +1187,7 @@ fn create_event_from_stream_token(
logprobs,
finish_reason,
usage,
tool_name,
));
event.json_data(chat_complete).unwrap_or_else(|e| {
@ -1283,6 +1287,7 @@ pub(crate) async fn chat_completions(
}
};
let mut response_as_tool = using_tools;
let mut global_function_name = String::new();
while let Some(result) = response_stream.next().await {
match result{
Ok(stream_token) => {
@ -1292,8 +1297,8 @@ pub(crate) async fn chat_completions(
json_buffer.push_str(&token_text.replace(" ", ""));
buffer.push(stream_token);
if let Some(captures) = function_regex.captures(&json_buffer) {
let function_name = captures[1].to_string();
if function_name == "no_tool" {
global_function_name = captures[1].to_string();
if global_function_name == "no_tool" {
state = StreamState::BufferTrailing;
response_as_tool = false;
buffer.clear();
@ -1302,18 +1307,7 @@ pub(crate) async fn chat_completions(
state = StreamState::Content {
skip_close_quote: false,
};
// send all the buffered messages
for stream_token in &buffer {
let event = create_event_from_stream_token(
stream_token,
logprobs,
stream_options.clone(),
response_as_tool,
system_fingerprint.clone(),
model_id.clone(),
);
yield Ok::<Event, Infallible>(event);
}
buffer = buffer.drain(0..1).collect();
}
}
}
@ -1348,6 +1342,7 @@ pub(crate) async fn chat_completions(
None,
None,
None,
Some(global_function_name.clone()),
));
yield Ok(event.json_data(chat_complete).unwrap_or_else(|e| {
InferError::StreamSerializationError(e.to_string()).into()
@ -1365,17 +1360,22 @@ pub(crate) async fn chat_completions(
break;
}
// send the content
let event = create_event_from_stream_token(
&stream_token,
logprobs,
stream_options.clone(),
response_as_tool,
system_fingerprint.clone(),
model_id.clone(),
);
buffer.push(stream_token);
// FIFO send the buffer but left the last two elements (closing '}' and EOS token)
for stream_token in &buffer[..buffer.len() - 2] {
let event = create_event_from_stream_token(
stream_token,
logprobs,
stream_options.clone(),
response_as_tool,
system_fingerprint.clone(),
model_id.clone(),
Some(global_function_name.clone()),
);
yield Ok::<Event, Infallible>(event);
yield Ok::<Event, Infallible>(event);
}
buffer = buffer.drain(buffer.len() - 2..).collect();
}
}
}