feat: improve partial parsing types and add test for balancing and partial parsing

This commit is contained in:
drbh 2025-02-26 18:53:12 +00:00
parent a5ddc9db52
commit 330f2e419f
2 changed files with 342 additions and 108 deletions

View File

@ -1505,6 +1505,70 @@ impl Default for ModelsInfo {
} }
} }
/// Balances a truncated JSON object by appending the closing quote and braces it still needs.
///
/// The input is scanned once, tracking whether the cursor is inside a string literal
/// (honouring backslash escapes) and how many `{` are still unmatched outside of strings.
/// A closing `}` with no matching `{` is ignored rather than counted.
///
/// A dangling trailing comma (ignoring trailing whitespace) is removed so the completed
/// text does not end in `,}`. The comma is only removed when the input ends *outside* a
/// string literal; a comma at the end of an unterminated string is string content and
/// is kept.
///
/// Returns `(completed, needed_close_quote, needed_close_brace)`:
/// * `completed` - `partial` with the dangling comma removed and any closing `"` / `}` appended.
/// * `needed_close_quote` - `true` when `partial` ended inside a string literal.
/// * `needed_close_brace` - `true` when at least one `}` had to be appended.
///
/// Only quotes and braces are balanced; arrays (`[`) are not closed.
fn complete_json(partial: &str) -> (String, bool, bool) {
    let mut open_braces: usize = 0;
    let mut in_string = false;
    let mut escaped = false;

    for c in partial.chars() {
        match (escaped, in_string, c) {
            // The character after a backslash is consumed verbatim.
            (true, _, _) => escaped = false,
            (false, _, '\\') => escaped = true,
            (false, _, '"') => in_string = !in_string,
            (false, false, '{') => open_braces += 1,
            // Saturate so a stray `}` never drives the counter below zero.
            (false, false, '}') => open_braces = open_braces.saturating_sub(1),
            _ => {}
        }
    }

    let mut completed = partial.to_string();

    // Drop a dangling `,` only when it is structural, i.e. we are not inside a string.
    if !in_string {
        let trimmed_len = partial.trim_end().len();
        if partial[..trimmed_len].ends_with(',') {
            // `,` is a single ASCII byte, so this index is a valid char boundary.
            completed.remove(trimmed_len - 1);
        }
    }

    if in_string {
        completed.push('"');
    }
    completed.push_str(&"}".repeat(open_braces));

    (completed, in_string, open_braces > 0)
}
/// Outcome of parsing a possibly truncated JSON document.
///
/// Pairs the deserialized value with a flag telling the caller whether the input
/// stopped in the middle of a string literal.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParseResult<T> {
    /// The value deserialized from the completed JSON text.
    pub value: T,
    /// `false` when the input ended inside an unterminated string literal (so the
    /// last string value may still grow); `true` otherwise.
    pub last_value_whole: bool,
}
/// Parses a possibly truncated JSON document into any deserializable type `T`.
///
/// The input is first balanced with `complete_json` (closing an open string and any
/// unmatched `{`), then deserialized with `serde_json`.
///
/// # Returns
/// A [`ParseResult`] whose `last_value_whole` is `false` only when the input ended
/// inside an unterminated string literal. Missing closing braces alone do not clear
/// the flag.
///
/// # Errors
/// Returns `Err("Failed to parse JSON: <serde error>")` when the completed text
/// cannot be deserialized into `T`.
pub fn parse_partial_json<T>(partial: &str) -> Result<ParseResult<T>, String>
where
    T: for<'de> Deserialize<'de>,
{
    let (completed, needed_close_quote, _needed_close_brace) = complete_json(partial);
    serde_json::from_str::<T>(&completed)
        .map(|value| ParseResult {
            value,
            last_value_whole: !needed_close_quote,
        })
        .map_err(|e| format!("Failed to parse JSON: {}", e))
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
@ -1779,3 +1843,214 @@ mod tests {
); );
} }
} }
#[cfg(test)]
mod tool_streaming_tests {
    use super::*;

    /// Builds the `(completed, needed_close_quote, needed_close_brace)` tuple that
    /// `complete_json` is expected to return, so each case is a single assertion.
    fn balanced(
        completed: &str,
        needed_close_quote: bool,
        needed_close_brace: bool,
    ) -> (String, bool, bool) {
        (completed.to_string(), needed_close_quote, needed_close_brace)
    }

    // Test json balancing and completion
    #[test]
    fn test_complete_json_basic_cases() {
        // Already complete
        assert_eq!(
            complete_json(r#"{"name":"test"}"#),
            balanced(r#"{"name":"test"}"#, false, false)
        );

        // Missing brace
        assert_eq!(
            complete_json(r#"{"name":"test""#),
            balanced(r#"{"name":"test"}"#, false, true)
        );

        // Missing quote
        assert_eq!(
            complete_json(r#"{"name":"test"#),
            balanced(r#"{"name":"test"}"#, true, true)
        );
    }

    #[test]
    fn test_complete_json_complex_cases() {
        // Nested objects
        assert_eq!(
            complete_json(r#"{"user":{"name":"test","age":30"#),
            balanced(r#"{"user":{"name":"test","age":30}}"#, false, true)
        );

        // Trailing comma
        assert_eq!(
            complete_json(r#"{"name":"test","#),
            balanced(r#"{"name":"test"}"#, false, true)
        );

        // Escaped quotes
        assert_eq!(
            complete_json(r#"{"message":"This is a \"quoted\" text"#),
            balanced(r#"{"message":"This is a \"quoted\" text"}"#, true, true)
        );
    }

    #[derive(Debug, Deserialize, Serialize, PartialEq)]
    struct User {
        name: String,
        age: Option<u32>,
    }

    #[test]
    fn test_parse_partial_json() {
        // Complete
        let parsed = parse_partial_json::<User>(r#"{"name":"Alice","age":30}"#)
            .expect("complete JSON should parse");
        assert_eq!(
            parsed.value,
            User {
                name: "Alice".to_string(),
                age: Some(30)
            }
        );
        assert!(parsed.last_value_whole);

        // Incomplete: only the closing brace is missing
        let parsed = parse_partial_json::<User>(r#"{"name":"Bob","age":25"#)
            .expect("JSON missing only a closing brace should parse");
        assert_eq!(
            parsed.value,
            User {
                name: "Bob".to_string(),
                age: Some(25)
            }
        );
        assert!(parsed.last_value_whole);

        // Invalid
        let result = parse_partial_json::<User>(r#"{"name":invalid}"#);
        assert!(result.is_err());
    }

    #[test]
    fn test_nested_escaped_quotes() {
        // String still open: both the quote and the brace are added
        assert_eq!(
            complete_json(r#"{"message":"This has \"nested\" quotes"#),
            balanced(r#"{"message":"This has \"nested\" quotes"}"#, true, true)
        );

        // String already closed: only the brace is added
        assert_eq!(
            complete_json(r#"{"message":"This has \"nested\" quotes""#),
            balanced(r#"{"message":"This has \"nested\" quotes"}"#, false, true)
        );
    }

    // Test incremental JSON parsing for a tool decision
    #[derive(Debug, Deserialize, Serialize, PartialEq)]
    struct ToolDecision {
        function: Function,
    }

    #[derive(Debug, Deserialize, Serialize, PartialEq)]
    struct Function {
        #[serde(rename = "_name")]
        name: String,
    }

    #[test]
    fn test_streaming_no_tool_flow() {
        let json_buffers = [
            r#"{"function": {"_name": "no_tool""#,
            r#"{ "content": "I am a helpful assistant""#,
        ];

        // Function decision
        let result = parse_partial_json::<ToolDecision>(json_buffers[0]);
        assert!(result.is_ok());
        assert_eq!(result.unwrap().value.function.name, "no_tool");

        // Content
        let result = parse_partial_json::<serde_json::Value>(json_buffers[1]);
        assert!(result.is_ok());
        assert_eq!(result.unwrap().value["content"], "I am a helpful assistant");
    }

    #[test]
    fn test_streaming_no_tool_flow_with_commas() {
        let json_buffers = [
            r#"{"function": {"_name": "no_tool","#,
            r#"{ "content": "I am a helpful assistant""#,
        ];

        // Function decision
        let result = parse_partial_json::<ToolDecision>(json_buffers[0]);
        assert!(result.is_ok());
        assert_eq!(result.unwrap().value.function.name, "no_tool");

        // Content
        let result = parse_partial_json::<serde_json::Value>(json_buffers[1]);
        assert!(result.is_ok());
        assert_eq!(result.unwrap().value["content"], "I am a helpful assistant");
    }

    #[test]
    fn test_streaming_weather_function() {
        // Test the incremental JSON parsing for the get_current_weather function
        let json_buffers = [
            r#"{"function": {"_name": "get_current_weather","#,
            r#"{ "location": "San Francisco, CA", "format": "fahrenheit"}}"#,
        ];

        // Function name
        let result = parse_partial_json::<ToolDecision>(json_buffers[0]);
        assert!(result.is_ok());
        assert_eq!(result.unwrap().value.function.name, "get_current_weather");

        // Function arguments: the buffer carries one closing brace too many, so
        // parsing is expected to fail.
        let result = parse_partial_json::<serde_json::Value>(json_buffers[1]);
        assert!(result.is_err());
    }

    #[test]
    fn test_incremental_json_buffers() {
        // Test individual incremental steps for get_current_weather
        let steps = [
            r#"{"function"#,
            r#"{"function": {"_name""#,
            r#"{"function": {"_name": "get_current_weather""#,
            r#"{"function": {"_name": "get_current_weather","#,
            r#"{ "location": "San"#,
            r#"{ "location": "San Francisco, CA", "format": "f"#,
            r#"{ "location": "San Francisco, CA", "format": "fahrenheit"}"#,
        ];

        // Early step should fail to parse
        let result = parse_partial_json::<ToolDecision>(steps[0]);
        assert!(result.is_err());

        // Middle step: the name string is already closed, so the name parses and
        // the last value is reported as whole even though braces are missing.
        let parsed = parse_partial_json::<ToolDecision>(steps[2])
            .expect("buffer with a closed name string should parse");
        assert_eq!(parsed.value.function.name, "get_current_weather");
        assert!(parsed.last_value_whole);

        // Last step should be complete and parse correctly
        let parsed = parse_partial_json::<serde_json::Value>(steps[6])
            .expect("complete arguments object should parse");
        let args = parsed.value.as_object().unwrap();
        assert_eq!(args["location"], "San Francisco, CA");
        assert_eq!(args["format"], "fahrenheit");
        assert!(parsed.last_value_whole);
    }
}

View File

@ -13,6 +13,7 @@ use crate::sagemaker::{
use crate::validation::ValidationError; use crate::validation::ValidationError;
use crate::vertex::vertex_compatibility; use crate::vertex::vertex_compatibility;
use crate::ChatTokenizeResponse; use crate::ChatTokenizeResponse;
use crate::{parse_partial_json, ParseResult};
use crate::{ use crate::{
usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName, usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName,
GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo, GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
@ -47,7 +48,6 @@ use http::header::AUTHORIZATION;
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle}; use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
use pyo3::prelude::*; use pyo3::prelude::*;
use pyo3::types::IntoPyDict; use pyo3::types::IntoPyDict;
use serde_json::Map;
use serde_json::Value; use serde_json::Value;
use std::convert::Infallible; use std::convert::Infallible;
use std::fs::File; use std::fs::File;
@ -1114,75 +1114,6 @@ pub(crate) async fn completions(
} }
} }
/// Balances a truncated JSON object by appending the closing quote and braces it still needs.
///
/// The input is scanned once, tracking whether the cursor is inside a string literal
/// (honouring backslash escapes) and how many `{` are still unmatched outside of strings.
/// A closing `}` with no matching `{` is ignored, so the unsigned counter can never
/// underflow on malformed input.
///
/// A dangling trailing comma (ignoring trailing whitespace) is removed, but only when
/// the input ends outside a string literal; a comma at the end of an unterminated
/// string is string content and is kept.
///
/// Returns `(completed, quote_open)` where `quote_open` is `true` when `partial`
/// ended inside a string literal and a closing `"` was appended.
fn complete_json(partial: &str) -> (String, bool) {
    let mut open_braces: usize = 0;
    let mut in_string = false;
    let mut escaped = false;

    for c in partial.chars() {
        match (escaped, in_string, c) {
            // The character after a backslash is consumed verbatim.
            (true, _, _) => escaped = false,
            (false, _, '\\') => escaped = true,
            (false, _, '"') => in_string = !in_string,
            (false, false, '{') => open_braces += 1,
            // Saturate: an unmatched `}` must not wrap or panic the counter.
            (false, false, '}') => open_braces = open_braces.saturating_sub(1),
            _ => {}
        }
    }

    let mut completed = partial.to_string();

    // Drop a dangling `,` only when it is structural, i.e. we are not inside a string.
    if !in_string {
        let trimmed_len = partial.trim_end().len();
        if partial[..trimmed_len].ends_with(',') {
            // `,` is a single ASCII byte, so this index is a valid char boundary.
            completed.remove(trimmed_len - 1);
        }
    }

    if in_string {
        completed.push('"');
    }
    completed.push_str(&"}".repeat(open_braces));

    (completed, in_string)
}
/// Parses a partial JSON buffer into a top-level object map.
///
/// The buffer is balanced with `complete_json` before parsing. On success the map is
/// returned together with the flag reporting whether the buffer ended inside an
/// open string literal.
///
/// Returns `Err("Failed to parse as object")` when the balanced text is not valid
/// JSON or its top-level value is not an object.
fn parse_generic_structure(partial: &str) -> Result<(Map<String, Value>, bool), String> {
    let (balanced, ended_in_string) = complete_json(partial);

    if let Ok(Value::Object(fields)) = serde_json::from_str::<Value>(&balanced) {
        return Ok((fields, ended_in_string));
    }

    Err("Failed to parse as object".to_string())
}
/// Parses a partial JSON buffer and returns the contents of its `function` object.
///
/// The buffer is balanced with `complete_json` before parsing.
///
/// # Errors
/// * `"Failed to parse as object"` - the balanced text is not valid JSON or is not an object.
/// * `"Missing function object"` - there is no `function` key holding an object.
/// * `"Missing *name in function"` - the buffer ended inside an open string while the
///   `function` object has a single key, i.e. the only value present is still
///   being streamed.
fn parse_partial_json(partial: &str) -> Result<Map<String, Value>, String> {
    let (completed, was_quote_open) = complete_json(partial);

    let obj = match serde_json::from_str::<Value>(&completed) {
        Ok(Value::Object(obj)) => obj,
        _ => return Err("Failed to parse as object".to_string()),
    };

    match obj.get("function") {
        Some(Value::Object(function)) => {
            // A single key plus an open quote means that value is not finished yet.
            let name_is_only_key = function.len() == 1;
            if was_quote_open && name_is_only_key {
                return Err("Missing *name in function".to_string());
            }
            Ok(function.clone())
        }
        _ => Err("Missing function object".to_string()),
    }
}
/// Creates an event based on the token text and event type parameters. /// Creates an event based on the token text and event type parameters.
/// `token_text` - The text to include (extract from StreamResponse.token.text or str) /// `token_text` - The text to include (extract from StreamResponse.token.text or str)
/// `model_id` - Model identifier string /// `model_id` - Model identifier string
@ -1361,6 +1292,22 @@ pub(crate) async fn chat_completions(
.as_secs() .as_secs()
}; };
/// Function part of a tool decision; (de)serialized with its `name` field stored
/// under the JSON key `_name`.
#[derive(serde::Serialize, serde::Deserialize, Debug)]
pub struct Function {
// Mapped to the `_name` key in JSON via the rename below.
#[serde(rename = "_name")]
name: String,
}
/// JSON object with a single `function` field, i.e. the shape
/// `{"function": {"_name": ...}}`.
#[derive(serde::Serialize, serde::Deserialize, Debug)]
pub struct ToolDecision {
// The function object carrying the `_name` value.
function: Function,
}
/// JSON object with a single string `content` field, i.e. the shape
/// `{"content": "..."}`.
#[derive(serde::Serialize, serde::Deserialize, Debug)]
pub struct NoToolDecision {
// Text stored under the `content` key.
content: String,
}
if stream { if stream {
let (headers, response_stream) = let (headers, response_stream) =
generate_stream_internal(infer, compute_type, Json(generate_request), span).await; generate_stream_internal(infer, compute_type, Json(generate_request), span).await;
@ -1389,13 +1336,23 @@ pub(crate) async fn chat_completions(
// Phase 1: Function name discovery // Phase 1: Function name discovery
if !name_found { if !name_found {
if let Ok(function) = parse_partial_json(&json_buffer) { // NOTE: when tools are supplied `name_found` is false until the generated buffer contains
// a partial JSON object with $.function._name value. This name determines the type
// of events to emit. If the name is "no_tool", we'll emit the "content" field as a chat
// completion event. Otherwise, we'll emit a tool call name event followed by a tool call
// argument event. In both cases we'll buffer tokens to get the name and then reset the buffer
// to collect the arguments.
if let Ok(ParseResult {
value: ToolDecision {
function: Function { name },
},
last_value_whole,
}) = parse_partial_json(&json_buffer)
{
if !last_value_whole {
continue;
}
name_found = true; name_found = true;
let name = function
.get("_name")
.and_then(|n| n.as_str())
.unwrap_or_default();
if name == "no_tool" { if name == "no_tool" {
no_tool_chosen = true; no_tool_chosen = true;
} else { } else {
@ -1403,7 +1360,7 @@ pub(crate) async fn chat_completions(
&token_text, &token_text,
&model_id, &model_id,
&system_fingerprint, &system_fingerprint,
Some(name), Some(name.as_str()),
false, false,
None, None,
)); ));
@ -1437,38 +1394,40 @@ pub(crate) async fn chat_completions(
if using_tools { if using_tools {
if no_tool_chosen && !is_complete_json { if no_tool_chosen && !is_complete_json {
// Content-only flow // Content-only flow
if let Ok((function, quote_open)) = parse_generic_structure(&json_buffer) { if let Ok(ParseResult {
if let Some(_content) = function.get("content").and_then(|c| c.as_str()) { value: _,
let cleaned_token = if !first_quote_removed { last_value_whole,
// trim start unil the first quote }) = parse_partial_json::<NoToolDecision>(&json_buffer)
first_quote_removed = true; {
edited_token let cleaned_token = if !first_quote_removed {
.trim_start() // trim start unil the first quote
.strip_prefix('"') first_quote_removed = true;
.unwrap_or(&edited_token) edited_token
.to_string() .trim_start()
} else if !quote_open { .strip_prefix('"')
should_break = true; .unwrap_or(&edited_token)
// trim end until the last quote .to_string()
edited_token } else if last_value_whole {
.trim_end() should_break = true;
.strip_suffix('"') // trim end until the last quote
.unwrap_or(&edited_token) edited_token
.to_string() .trim_end()
} else { .strip_suffix('"')
edited_token.to_string() .unwrap_or(&edited_token)
}; .to_string()
} else {
edited_token.to_string()
};
if !cleaned_token.is_empty() { if !cleaned_token.is_empty() {
events.push(create_event( events.push(create_event(
&cleaned_token, &cleaned_token,
&model_id, &model_id,
&system_fingerprint, &system_fingerprint,
None, None,
false, false,
None, None,
)); ));
}
} }
} }
} else { } else {