diff --git a/router/src/lib.rs b/router/src/lib.rs index 6d4814b13..1b1e66d2a 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -765,6 +765,7 @@ impl ChatCompletionChunk { logprobs: Option, finish_reason: Option, usage: Option, + tool_name: Option, ) -> Self { let delta = match (delta, tool_calls) { (Some(delta), _) => ChatCompletionDelta::Chat(TextMessage { @@ -779,7 +780,7 @@ impl ChatCompletionChunk { id: String::new(), r#type: "function".to_string(), function: Function { - name: None, + name: tool_name, arguments: tool_calls[0].to_string(), }, }], @@ -1364,7 +1365,7 @@ pub struct SimpleToken { stop: usize, } -#[derive(Debug, Serialize, ToSchema, Clone)] +#[derive(Debug, Serialize, ToSchema, Clone, PartialEq)] #[serde(rename_all(serialize = "snake_case"))] #[schema(example = "Length")] pub enum FinishReason { diff --git a/router/src/server.rs b/router/src/server.rs index e9aa4612b..925ad6249 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1128,6 +1128,7 @@ fn create_event_from_stream_token( inner_using_tools: bool, system_fingerprint: String, model_id: String, + tool_name: Option, ) -> Event { let event = Event::default(); let current_time = std::time::SystemTime::now() @@ -1141,7 +1142,9 @@ fn create_event_from_stream_token( // replace the content with the tool calls if grammar is present let (content, tool_calls) = if inner_using_tools { - (None, Some(vec![stream_token.token.text.clone()])) + // escape the token text so its a json string + let escaped_text = stream_token.token.text.replace(r#"""#, r#"\""#); + (None, Some(vec![escaped_text])) } else { let content = if !stream_token.token.special { Some(stream_token.token.text.clone()) @@ -1184,6 +1187,7 @@ fn create_event_from_stream_token( logprobs, finish_reason, usage, + tool_name, )); event.json_data(chat_complete).unwrap_or_else(|e| { @@ -1283,6 +1287,7 @@ pub(crate) async fn chat_completions( } }; let mut response_as_tool = using_tools; + let mut global_function_name = String::new(); while let Some(result) = response_stream.next().await { match result{ Ok(stream_token) => { @@ -1292,8 +1297,8 @@ pub(crate) async fn chat_completions( json_buffer.push_str(&token_text.replace(" ", "")); buffer.push(stream_token); if let Some(captures) = function_regex.captures(&json_buffer) { - let function_name = captures[1].to_string(); - if function_name == "no_tool" { + global_function_name = captures[1].to_string(); + if global_function_name == "no_tool" { state = StreamState::BufferTrailing; response_as_tool = false; buffer.clear(); @@ -1302,18 +1307,7 @@ pub(crate) async fn chat_completions( state = StreamState::Content { skip_close_quote: false, }; - // send all the buffered messages - for stream_token in &buffer { - let event = create_event_from_stream_token( - stream_token, - logprobs, - stream_options.clone(), - response_as_tool, - system_fingerprint.clone(), - model_id.clone(), - ); - yield Ok::(event); - } + buffer = buffer.drain(0..1).collect(); } } } @@ -1348,6 +1342,7 @@ pub(crate) async fn chat_completions( None, None, None, + Some(global_function_name.clone()), )); yield Ok(event.json_data(chat_complete).unwrap_or_else(|e| { InferError::StreamSerializationError(e.to_string()).into() @@ -1365,17 +1360,22 @@ pub(crate) async fn chat_completions( break; } - // send the content - let event = create_event_from_stream_token( - &stream_token, - logprobs, - stream_options.clone(), - response_as_tool, - system_fingerprint.clone(), - model_id.clone(), - ); + buffer.push(stream_token); + // FIFO send the buffer but left the last two elements (closing '}' and EOS token) + for stream_token in &buffer[..buffer.len() - 2] { + let event = create_event_from_stream_token( + stream_token, + logprobs, + stream_options.clone(), + response_as_tool, + system_fingerprint.clone(), + model_id.clone(), + Some(global_function_name.clone()), + ); - yield Ok::(event); + yield Ok::(event); + } + buffer = buffer.drain(buffer.len() - 2..).collect(); } } }