diff --git a/router/src/lib.rs b/router/src/lib.rs index 1b1e66d2..dfb68b49 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -1505,6 +1505,70 @@ impl Default for ModelsInfo { } } +// balance the started json with closing braces and quotes +fn complete_json(partial: &str) -> (String, bool, bool) { + let mut brace_count = 0; + let mut quote_open = false; + let mut escaped = false; + let mut last_char = '\0'; + + for c in partial.chars() { + match (escaped, quote_open, c) { + (true, _, _) => escaped = false, + (false, _, '\\') => escaped = true, + (false, _, '"') => quote_open = !quote_open, + (false, false, '{') => brace_count += 1, + (false, false, '}') if brace_count > 0 => brace_count -= 1, + _ => {} + } + if !c.is_whitespace() { + last_char = c; + } + } + + let mut completed = partial.to_string(); + + if last_char == ',' { + if let Some(pos) = completed.rfind(',') { + completed.replace_range(pos..pos + 1, ""); + } + } + + if quote_open { + completed.push('"'); + } + + if brace_count > 0 { + completed.push_str(&"}".repeat(brace_count)); + } + + (completed, quote_open, brace_count > 0) +} + +/// Result type that includes both the parsed value and a completion status +#[derive(Debug)] +pub struct ParseResult { + pub value: T, + pub last_value_whole: bool, +} + +/// Parse partial JSON into a generic serializable type T +/// Returns the parsed value along with a flag indicating if completion was needed +pub fn parse_partial_json(partial: &str) -> Result, String> +where + T: for<'de> Deserialize<'de> + std::fmt::Debug, +{ + let (completed, needed_close_quote, _need_close_brace) = complete_json(partial); + let obj = serde_json::from_str::(&completed); + match obj { + Ok(value) => Ok(ParseResult { + value, + last_value_whole: !needed_close_quote, + }), + Err(e) => Err(format!("Failed to parse JSON: {}", e)), + } +} + #[cfg(test)] mod tests { use super::*; @@ -1779,3 +1843,214 @@ mod tests { ); } } + +#[cfg(test)] +mod tool_streaming_tests { + use super::*; + + // Test json balancing and completion + #[test] + fn test_complete_json_basic_cases() { + // Already complete + let (completed, needed_close_quote, needed_close_brace) = + complete_json(r#"{"name":"test"}"#); + assert_eq!(completed, r#"{"name":"test"}"#); + assert_eq!(needed_close_quote, false); + assert_eq!(needed_close_brace, false); + + // Missing brace + let (completed, needed_close_quote, needed_close_brace) = + complete_json(r#"{"name":"test""#); + assert_eq!(completed, r#"{"name":"test"}"#); + assert_eq!(needed_close_quote, false); + assert_eq!(needed_close_brace, true); + + // Missing quote + let (completed, needed_close_quote, needed_close_brace) = complete_json(r#"{"name":"test"#); + assert_eq!(completed, r#"{"name":"test"}"#); + assert_eq!(needed_close_quote, true); + assert_eq!(needed_close_brace, true); + } + + #[test] + fn test_complete_json_complex_cases() { + // Nested objects + let (completed, needed_close_quote, needed_close_brace) = + complete_json(r#"{"user":{"name":"test","age":30"#); + assert_eq!(completed, r#"{"user":{"name":"test","age":30}}"#); + assert_eq!(needed_close_quote, false); + assert_eq!(needed_close_brace, true); + + // Trailing comma + let (completed, needed_close_quote, needed_close_brace) = + complete_json(r#"{"name":"test","#); + assert_eq!(completed, r#"{"name":"test"}"#); + assert_eq!(needed_close_quote, false); + assert_eq!(needed_close_brace, true); + + // Escaped quotes + let (completed, needed_close_quote, needed_close_brace) = + complete_json(r#"{"message":"This is a \"quoted\" text"#); + assert_eq!(completed, r#"{"message":"This is a \"quoted\" text"}"#); + assert_eq!(needed_close_quote, true); + assert_eq!(needed_close_brace, true); + } + + #[derive(Debug, Deserialize, Serialize, PartialEq)] + struct User { + name: String, + age: Option, + } + + #[test] + fn test_parse_partial_json() { + // Complete + let result = parse_partial_json::(r#"{"name":"Alice","age":30}"#); + assert!(result.is_ok()); + let parsed = result.unwrap(); + assert_eq!( + parsed.value, + User { + name: "Alice".to_string(), + age: Some(30) + } + ); + assert_eq!(parsed.last_value_whole, true); + + // Incomplete + let result = parse_partial_json::(r#"{"name":"Bob","age":25"#); + assert!(result.is_ok()); + let parsed = result.unwrap(); + assert_eq!( + parsed.value, + User { + name: "Bob".to_string(), + age: Some(25) + } + ); + assert_eq!(parsed.last_value_whole, true); + + // Invalid + let result = parse_partial_json::(r#"{"name":invalid}"#); + assert!(result.is_err()); + } + + #[test] + fn test_nested_escaped_quotes() { + let json = r#"{"message":"This has \"nested\" quotes"#; + let (completed, needed_close_quote, needed_close_brace) = complete_json(json); + assert_eq!(completed, r#"{"message":"This has \"nested\" quotes"}"#); + assert_eq!(needed_close_quote, true); + assert_eq!(needed_close_brace, true); + + let json = r#"{"message":"This has \"nested\" quotes""#; + let (completed, needed_close_quote, needed_close_brace) = complete_json(json); + assert_eq!(completed, r#"{"message":"This has \"nested\" quotes"}"#); + assert_eq!(needed_close_quote, false); + assert_eq!(needed_close_brace, true); + } + + // Test incremental JSON parsing for a tool decision + #[derive(Debug, Deserialize, Serialize, PartialEq)] + struct ToolDecision { + function: Function, + } + + #[derive(Debug, Deserialize, Serialize, PartialEq)] + struct Function { + #[serde(rename = "_name")] + name: String, + } + + #[test] + fn test_streaming_no_tool_flow() { + let json_buffers = [ + r#"{"function": {"_name": "no_tool""#, + r#"{ "content": "I am a helpful assistant""#, + ]; + + // Function decision + let result = parse_partial_json::(&json_buffers[0]); + println!("{:?}", result); + assert!(result.is_ok()); + assert_eq!(result.unwrap().value.function.name, "no_tool"); + + // Content + let result = parse_partial_json::(&json_buffers[1]); + assert!(result.is_ok()); + assert_eq!(result.unwrap().value["content"], "I am a helpful assistant"); + } + + #[test] + fn test_streaming_no_tool_flow_with_commas() { + let json_buffers = [ + r#"{"function": {"_name": "no_tool","#, + r#"{ "content": "I am a helpful assistant""#, + ]; + + // Function decision + let result = parse_partial_json::(&json_buffers[0]); + println!("{:?}", result); + assert!(result.is_ok()); + assert_eq!(result.unwrap().value.function.name, "no_tool"); + + // Content + let result = parse_partial_json::(&json_buffers[1]); + assert!(result.is_ok()); + assert_eq!(result.unwrap().value["content"], "I am a helpful assistant"); + } + + #[test] + fn test_streaming_weather_function() { + // Test the incremental JSON parsing for the get_current_weather function + let json_buffers = [ + r#"{"function": {"_name": "get_current_weather","#, + r#"{ "location": "San Francisco, CA", "format": "fahrenheit"}}"#, + ]; + + // Function name + let result = parse_partial_json::(&json_buffers[0]); + assert!(result.is_ok()); + assert_eq!(result.unwrap().value.function.name, "get_current_weather"); + + // Function arguments + let result = parse_partial_json::(&json_buffers[1]); + assert!(result.is_err()); + } + + #[test] + fn test_incremental_json_buffers() { + // Test individual incremental steps for get_current_weather + let steps = [ + r#"{"function"#, + r#"{"function": {"_name""#, + r#"{"function": {"_name": "get_current_weather""#, + r#"{"function": {"_name": "get_current_weather","#, + r#"{ "location": "San"#, + r#"{ "location": "San Francisco, CA", "format": "f"#, + r#"{ "location": "San Francisco, CA", "format": "fahrenheit"}"#, + ]; + + // Early step should fail to parse + let result = parse_partial_json::(&steps[0]); + assert!(result.is_err()); + + // Middle step should parse name but be incomplete + let result = parse_partial_json::(&steps[2]); + assert!(result.is_ok()); + let parsed = result.unwrap(); + assert_eq!(parsed.value.function.name, "get_current_weather"); + assert_eq!(parsed.last_value_whole, true); + + // Last step should be complete and parse correctly + let result = parse_partial_json::(&steps[6]); + assert!(result.is_ok()); + let parsed = result.unwrap(); + assert_eq!( + parsed.value.as_object().unwrap()["location"], + "San Francisco, CA" + ); + assert_eq!(parsed.value.as_object().unwrap()["format"], "fahrenheit"); + assert_eq!(parsed.last_value_whole, true); + } +} diff --git a/router/src/server.rs b/router/src/server.rs index 99a87b7a..de0cf53d 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -13,6 +13,7 @@ use crate::sagemaker::{ use crate::validation::ValidationError; use crate::vertex::vertex_compatibility; use crate::ChatTokenizeResponse; +use crate::{parse_partial_json, ParseResult}; use crate::{ usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName, GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo, @@ -47,7 +48,6 @@ use http::header::AUTHORIZATION; use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle}; use pyo3::prelude::*; use pyo3::types::IntoPyDict; -use serde_json::Map; use serde_json::Value; use std::convert::Infallible; use std::fs::File; @@ -1114,75 +1114,6 @@ pub(crate) async fn completions( } } -// balance the started json with closing braces and quotes -fn complete_json(partial: &str) -> (String, bool) { - let mut brace_count = 0; - let mut quote_open = false; - let mut escaped = false; - let mut last_char = '\0'; - - for c in partial.chars() { - match (escaped, quote_open, c) { - (true, _, _) => escaped = false, - (false, _, '\\') => escaped = true, - (false, _, '"') => quote_open = !quote_open, - (false, false, '{') => brace_count += 1, - (false, false, '}') => brace_count -= 1, - _ => {} - } - if !c.is_whitespace() { - last_char = c; - } - } - - let mut completed = partial.to_string(); - - if last_char == ',' { - if let Some(pos) = completed.rfind(',') { - completed.replace_range(pos..pos + 1, ""); - } - } - - if quote_open { - completed.push('"'); - } - completed.push_str(&"}".repeat(brace_count.max(0))); - - (completed, quote_open) -} - -// Generic function that parses any partial structure into a Map -fn parse_generic_structure(partial: &str) -> Result<(Map, bool), String> { - let (completed, quote_open) = complete_json(partial); - match serde_json::from_str::(&completed) { - Ok(Value::Object(obj)) => Ok((obj, quote_open)), - _ => Err("Failed to parse as object".to_string()), - } -} - -// Parse partial JSON into a Map with a function object -fn parse_partial_json(partial: &str) -> Result, String> { - let (completed, was_quote_open) = complete_json(partial); - match serde_json::from_str::(&completed) { - Ok(Value::Object(obj)) => { - if let Some(Value::Object(function)) = obj.get("function") { - let name_is_only_key = function.len() == 1; - if was_quote_open && name_is_only_key { - let mut function = function.clone(); - if let Some(Value::String(ref mut name)) = function.get_mut("_name") { - name.clear(); - } - return Err("Missing *name in function".to_string()); - } - Ok(function.clone()) - } else { - Err("Missing function object".to_string()) - } - } - _ => Err("Failed to parse as object".to_string()), - } -} - /// Creates an event based on the token text and event type parameters. /// `token_text` - The text to include (extract from StreamResponse.token.text or str) /// `model_id` - Model identifier string @@ -1361,6 +1292,22 @@ pub(crate) async fn chat_completions( .as_secs() }; + #[derive(serde::Serialize, serde::Deserialize, Debug)] + pub struct Function { + #[serde(rename = "_name")] + name: String, + } + + #[derive(serde::Serialize, serde::Deserialize, Debug)] + pub struct ToolDecision { + function: Function, + } + + #[derive(serde::Serialize, serde::Deserialize, Debug)] + pub struct NoToolDecision { + content: String, + } + if stream { let (headers, response_stream) = generate_stream_internal(infer, compute_type, Json(generate_request), span).await; @@ -1389,13 +1336,23 @@ pub(crate) async fn chat_completions( // Phase 1: Function name discovery if !name_found { - if let Ok(function) = parse_partial_json(&json_buffer) { + // NOTE: when tools are supplied `name_found` is false until the generated buffer contains + // a partial JSON object with $.function._name value. This name determines the type + // of events to emit. If the name is "no_tool", we'll emit the "content" field as a chat + // completion event. Otherwise, we'll emit a tool call name event followed by a tool call + // argument event. In both cases we'll buffer tokens to get the name and then reset the buffer + // to collect the arguments. + if let Ok(ParseResult { + value: ToolDecision { + function: Function { name }, + }, + last_value_whole, + }) = parse_partial_json(&json_buffer) + { + if !last_value_whole { + continue; + } name_found = true; - - let name = function - .get("_name") - .and_then(|n| n.as_str()) - .unwrap_or_default(); if name == "no_tool" { no_tool_chosen = true; } else { @@ -1403,7 +1360,7 @@ pub(crate) async fn chat_completions( &token_text, &model_id, &system_fingerprint, - Some(name), + Some(name.as_str()), false, None, )); @@ -1437,38 +1394,40 @@ pub(crate) async fn chat_completions( if using_tools { if no_tool_chosen && !is_complete_json { // Content-only flow - if let Ok((function, quote_open)) = parse_generic_structure(&json_buffer) { - if let Some(_content) = function.get("content").and_then(|c| c.as_str()) { - let cleaned_token = if !first_quote_removed { - // trim start unil the first quote - first_quote_removed = true; - edited_token - .trim_start() - .strip_prefix('"') - .unwrap_or(&edited_token) - .to_string() - } else if !quote_open { - should_break = true; - // trim end until the last quote - edited_token - .trim_end() - .strip_suffix('"') - .unwrap_or(&edited_token) - .to_string() - } else { - edited_token.to_string() - }; + if let Ok(ParseResult { + value: _, + last_value_whole, + }) = parse_partial_json::(&json_buffer) + { + let cleaned_token = if !first_quote_removed { + // trim start unil the first quote + first_quote_removed = true; + edited_token + .trim_start() + .strip_prefix('"') + .unwrap_or(&edited_token) + .to_string() + } else if last_value_whole { + should_break = true; + // trim end until the last quote + edited_token + .trim_end() + .strip_suffix('"') + .unwrap_or(&edited_token) + .to_string() + } else { + edited_token.to_string() + }; - if !cleaned_token.is_empty() { - events.push(create_event( - &cleaned_token, - &model_id, - &system_fingerprint, - None, - false, - None, - )); - } + if !cleaned_token.is_empty() { + events.push(create_event( + &cleaned_token, + &model_id, + &system_fingerprint, + None, + false, + None, + )); } } } else {