mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-21 14:52:20 +00:00
feat: improve partial parsing types and add test for balancing and partial parsing
This commit is contained in:
parent
a5ddc9db52
commit
330f2e419f
@ -1505,6 +1505,70 @@ impl Default for ModelsInfo {
|
||||
}
|
||||
}
|
||||
|
||||
/// Balances a truncated JSON object so it can be handed to a JSON parser.
///
/// `partial` is scanned once while tracking whether the cursor is inside a
/// string, whether the previous character was an escaping backslash, and how
/// many `{` are still unclosed. The returned text is `partial` with:
///
/// * a trailing structural comma removed (only when no string is open, so a
///   comma that is part of an unfinished string value is preserved),
/// * a closing `"` appended if a string was left open,
/// * one `}` appended per unclosed `{`.
///
/// Returns `(completed, needed_close_quote, needed_close_brace)`.
fn complete_json(partial: &str) -> (String, bool, bool) {
    let mut brace_count: usize = 0;
    let mut quote_open = false;
    let mut escaped = false;
    // Byte offset and value of the last non-whitespace character seen.
    let mut last_non_ws: Option<(usize, char)> = None;

    for (idx, c) in partial.char_indices() {
        match (escaped, quote_open, c) {
            // The character after a backslash is consumed verbatim.
            (true, _, _) => escaped = false,
            (false, _, '\\') => escaped = true,
            (false, _, '"') => quote_open = !quote_open,
            (false, false, '{') => brace_count += 1,
            // Unmatched closing braces are ignored rather than underflowing.
            (false, false, '}') if brace_count > 0 => brace_count -= 1,
            _ => {}
        }
        if !c.is_whitespace() {
            last_non_ws = Some((idx, c));
        }
    }

    let mut completed = partial.to_string();

    // A dangling comma is only structural when we are outside of a string;
    // inside an open string it is content and must not be dropped.
    if !quote_open {
        if let Some((pos, ',')) = last_non_ws {
            completed.remove(pos);
        }
    }

    if quote_open {
        completed.push('"');
    }

    completed.push_str(&"}".repeat(brace_count));

    (completed, quote_open, brace_count > 0)
}
|
||||
|
||||
/// Outcome of parsing a partial JSON buffer: the parsed value together with a
/// flag describing whether the last string in the buffer was already terminated.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParseResult<T> {
    /// The value deserialized from the balanced buffer.
    pub value: T,
    /// `true` when no closing quote had to be appended to the buffer, i.e. the
    /// last string seen was complete; `false` when it was cut off mid-string.
    pub last_value_whole: bool,
}
|
||||
|
||||
/// Parse partial JSON into a generic serializable type T
|
||||
/// Returns the parsed value along with a flag indicating if completion was needed
|
||||
pub fn parse_partial_json<T>(partial: &str) -> Result<ParseResult<T>, String>
|
||||
where
|
||||
T: for<'de> Deserialize<'de> + std::fmt::Debug,
|
||||
{
|
||||
let (completed, needed_close_quote, _need_close_brace) = complete_json(partial);
|
||||
let obj = serde_json::from_str::<T>(&completed);
|
||||
match obj {
|
||||
Ok(value) => Ok(ParseResult {
|
||||
value,
|
||||
last_value_whole: !needed_close_quote,
|
||||
}),
|
||||
Err(e) => Err(format!("Failed to parse JSON: {}", e)),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@ -1779,3 +1843,214 @@ mod tests {
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tool_streaming_tests {
    //! Tests for balancing truncated JSON (`complete_json`) and for typed
    //! parsing of partial buffers (`parse_partial_json`).
    use super::*;

    // Test json balancing and completion
    #[test]
    fn test_complete_json_basic_cases() {
        // Already complete
        let (completed, needed_close_quote, needed_close_brace) =
            complete_json(r#"{"name":"test"}"#);
        assert_eq!(completed, r#"{"name":"test"}"#);
        assert!(!needed_close_quote);
        assert!(!needed_close_brace);

        // Missing brace
        let (completed, needed_close_quote, needed_close_brace) =
            complete_json(r#"{"name":"test""#);
        assert_eq!(completed, r#"{"name":"test"}"#);
        assert!(!needed_close_quote);
        assert!(needed_close_brace);

        // Missing quote
        let (completed, needed_close_quote, needed_close_brace) =
            complete_json(r#"{"name":"test"#);
        assert_eq!(completed, r#"{"name":"test"}"#);
        assert!(needed_close_quote);
        assert!(needed_close_brace);
    }

    #[test]
    fn test_complete_json_complex_cases() {
        // Nested objects
        let (completed, needed_close_quote, needed_close_brace) =
            complete_json(r#"{"user":{"name":"test","age":30"#);
        assert_eq!(completed, r#"{"user":{"name":"test","age":30}}"#);
        assert!(!needed_close_quote);
        assert!(needed_close_brace);

        // Trailing comma
        let (completed, needed_close_quote, needed_close_brace) =
            complete_json(r#"{"name":"test","#);
        assert_eq!(completed, r#"{"name":"test"}"#);
        assert!(!needed_close_quote);
        assert!(needed_close_brace);

        // Escaped quotes
        let (completed, needed_close_quote, needed_close_brace) =
            complete_json(r#"{"message":"This is a \"quoted\" text"#);
        assert_eq!(completed, r#"{"message":"This is a \"quoted\" text"}"#);
        assert!(needed_close_quote);
        assert!(needed_close_brace);
    }

    #[derive(Debug, Deserialize, Serialize, PartialEq)]
    struct User {
        name: String,
        age: Option<u32>,
    }

    #[test]
    fn test_parse_partial_json() {
        // Complete
        let result = parse_partial_json::<User>(r#"{"name":"Alice","age":30}"#);
        assert!(result.is_ok());
        let parsed = result.unwrap();
        assert_eq!(
            parsed.value,
            User {
                name: "Alice".to_string(),
                age: Some(30)
            }
        );
        assert!(parsed.last_value_whole);

        // Incomplete (missing closing brace only)
        let result = parse_partial_json::<User>(r#"{"name":"Bob","age":25"#);
        assert!(result.is_ok());
        let parsed = result.unwrap();
        assert_eq!(
            parsed.value,
            User {
                name: "Bob".to_string(),
                age: Some(25)
            }
        );
        assert!(parsed.last_value_whole);

        // Invalid
        let result = parse_partial_json::<User>(r#"{"name":invalid}"#);
        assert!(result.is_err());
    }

    #[test]
    fn test_nested_escaped_quotes() {
        let json = r#"{"message":"This has \"nested\" quotes"#;
        let (completed, needed_close_quote, needed_close_brace) = complete_json(json);
        assert_eq!(completed, r#"{"message":"This has \"nested\" quotes"}"#);
        assert!(needed_close_quote);
        assert!(needed_close_brace);

        let json = r#"{"message":"This has \"nested\" quotes""#;
        let (completed, needed_close_quote, needed_close_brace) = complete_json(json);
        assert_eq!(completed, r#"{"message":"This has \"nested\" quotes"}"#);
        assert!(!needed_close_quote);
        assert!(needed_close_brace);
    }

    // Test incremental JSON parsing for a tool decision
    #[derive(Debug, Deserialize, Serialize, PartialEq)]
    struct ToolDecision {
        function: Function,
    }

    #[derive(Debug, Deserialize, Serialize, PartialEq)]
    struct Function {
        #[serde(rename = "_name")]
        name: String,
    }

    #[test]
    fn test_streaming_no_tool_flow() {
        let json_buffers = [
            r#"{"function": {"_name": "no_tool""#,
            r#"{ "content": "I am a helpful assistant""#,
        ];

        // Function decision
        let result = parse_partial_json::<ToolDecision>(json_buffers[0]);
        assert!(result.is_ok());
        assert_eq!(result.unwrap().value.function.name, "no_tool");

        // Content
        let result = parse_partial_json::<serde_json::Value>(json_buffers[1]);
        assert!(result.is_ok());
        assert_eq!(result.unwrap().value["content"], "I am a helpful assistant");
    }

    #[test]
    fn test_streaming_no_tool_flow_with_commas() {
        let json_buffers = [
            r#"{"function": {"_name": "no_tool","#,
            r#"{ "content": "I am a helpful assistant""#,
        ];

        // Function decision
        let result = parse_partial_json::<ToolDecision>(json_buffers[0]);
        assert!(result.is_ok());
        assert_eq!(result.unwrap().value.function.name, "no_tool");

        // Content
        let result = parse_partial_json::<serde_json::Value>(json_buffers[1]);
        assert!(result.is_ok());
        assert_eq!(result.unwrap().value["content"], "I am a helpful assistant");
    }

    #[test]
    fn test_streaming_weather_function() {
        // Test the incremental JSON parsing for the get_current_weather function
        let json_buffers = [
            r#"{"function": {"_name": "get_current_weather","#,
            r#"{ "location": "San Francisco, CA", "format": "fahrenheit"}}"#,
        ];

        // Function name
        let result = parse_partial_json::<ToolDecision>(json_buffers[0]);
        assert!(result.is_ok());
        assert_eq!(result.unwrap().value.function.name, "get_current_weather");

        // Function arguments: the surplus closing brace makes the buffer invalid
        let result = parse_partial_json::<serde_json::Value>(json_buffers[1]);
        assert!(result.is_err());
    }

    #[test]
    fn test_incremental_json_buffers() {
        // Test individual incremental steps for get_current_weather
        let steps = [
            r#"{"function"#,
            r#"{"function": {"_name""#,
            r#"{"function": {"_name": "get_current_weather""#,
            r#"{"function": {"_name": "get_current_weather","#,
            r#"{ "location": "San"#,
            r#"{ "location": "San Francisco, CA", "format": "f"#,
            r#"{ "location": "San Francisco, CA", "format": "fahrenheit"}"#,
        ];

        // Early step should fail to parse
        let result = parse_partial_json::<ToolDecision>(steps[0]);
        assert!(result.is_err());

        // Middle step: the name string is terminated, only braces are missing,
        // so the name parses and the last value is reported as whole
        let result = parse_partial_json::<ToolDecision>(steps[2]);
        assert!(result.is_ok());
        let parsed = result.unwrap();
        assert_eq!(parsed.value.function.name, "get_current_weather");
        assert!(parsed.last_value_whole);

        // Argument step cut off mid-string: parses, but the last value is not whole
        let result = parse_partial_json::<serde_json::Value>(steps[4]);
        assert!(result.is_ok());
        let parsed = result.unwrap();
        assert_eq!(parsed.value["location"], "San");
        assert!(!parsed.last_value_whole);

        // Last step should be complete and parse correctly
        let result = parse_partial_json::<serde_json::Value>(steps[6]);
        assert!(result.is_ok());
        let parsed = result.unwrap();
        assert_eq!(
            parsed.value.as_object().unwrap()["location"],
            "San Francisco, CA"
        );
        assert_eq!(parsed.value.as_object().unwrap()["format"], "fahrenheit");
        assert!(parsed.last_value_whole);
    }
}
|
||||
|
@ -13,6 +13,7 @@ use crate::sagemaker::{
|
||||
use crate::validation::ValidationError;
|
||||
use crate::vertex::vertex_compatibility;
|
||||
use crate::ChatTokenizeResponse;
|
||||
use crate::{parse_partial_json, ParseResult};
|
||||
use crate::{
|
||||
usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName,
|
||||
GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
|
||||
@ -47,7 +48,6 @@ use http::header::AUTHORIZATION;
|
||||
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::types::IntoPyDict;
|
||||
use serde_json::Map;
|
||||
use serde_json::Value;
|
||||
use std::convert::Infallible;
|
||||
use std::fs::File;
|
||||
@ -1114,75 +1114,6 @@ pub(crate) async fn completions(
|
||||
}
|
||||
}
|
||||
|
||||
// Balance the started json with closing braces and quotes.
//
// Scans `partial` once, tracking string/escape state and the number of
// unclosed `{`, then returns `(completed, quote_open)` where `completed` is
// `partial` with a trailing structural comma removed, a closing `"` appended
// if a string was left open, and one `}` appended per unclosed `{`.
fn complete_json(partial: &str) -> (String, bool) {
    let mut brace_count: usize = 0;
    let mut quote_open = false;
    let mut escaped = false;
    // Byte offset and value of the last non-whitespace character seen.
    let mut last_non_ws: Option<(usize, char)> = None;

    for (idx, c) in partial.char_indices() {
        match (escaped, quote_open, c) {
            // The character after a backslash is consumed verbatim.
            (true, _, _) => escaped = false,
            (false, _, '\\') => escaped = true,
            (false, _, '"') => quote_open = !quote_open,
            (false, false, '{') => brace_count += 1,
            // The counter is unsigned: an unmatched `}` must not underflow it.
            (false, false, '}') => brace_count = brace_count.saturating_sub(1),
            _ => {}
        }
        if !c.is_whitespace() {
            last_non_ws = Some((idx, c));
        }
    }

    let mut completed = partial.to_string();

    // A dangling comma is only structural outside of a string; inside an open
    // string it is part of the value and is kept.
    if !quote_open {
        if let Some((pos, ',')) = last_non_ws {
            completed.remove(pos);
        }
    }

    if quote_open {
        completed.push('"');
    }
    completed.push_str(&"}".repeat(brace_count));

    (completed, quote_open)
}
|
||||
|
||||
// Generic function that parses any partial structure into a Map
|
||||
fn parse_generic_structure(partial: &str) -> Result<(Map<String, Value>, bool), String> {
|
||||
let (completed, quote_open) = complete_json(partial);
|
||||
match serde_json::from_str::<Value>(&completed) {
|
||||
Ok(Value::Object(obj)) => Ok((obj, quote_open)),
|
||||
_ => Err("Failed to parse as object".to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
// Parse partial JSON into a Map with a function object
|
||||
fn parse_partial_json(partial: &str) -> Result<Map<String, Value>, String> {
|
||||
let (completed, was_quote_open) = complete_json(partial);
|
||||
match serde_json::from_str::<Value>(&completed) {
|
||||
Ok(Value::Object(obj)) => {
|
||||
if let Some(Value::Object(function)) = obj.get("function") {
|
||||
let name_is_only_key = function.len() == 1;
|
||||
if was_quote_open && name_is_only_key {
|
||||
let mut function = function.clone();
|
||||
if let Some(Value::String(ref mut name)) = function.get_mut("_name") {
|
||||
name.clear();
|
||||
}
|
||||
return Err("Missing *name in function".to_string());
|
||||
}
|
||||
Ok(function.clone())
|
||||
} else {
|
||||
Err("Missing function object".to_string())
|
||||
}
|
||||
}
|
||||
_ => Err("Failed to parse as object".to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates an event based on the token text and event type parameters.
|
||||
/// `token_text` - The text to include (extract from StreamResponse.token.text or str)
|
||||
/// `model_id` - Model identifier string
|
||||
@ -1361,6 +1292,22 @@ pub(crate) async fn chat_completions(
|
||||
.as_secs()
|
||||
};
|
||||
|
||||
#[derive(serde::Serialize, serde::Deserialize, Debug)]
|
||||
pub struct Function {
|
||||
#[serde(rename = "_name")]
|
||||
name: String,
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize, serde::Deserialize, Debug)]
|
||||
pub struct ToolDecision {
|
||||
function: Function,
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize, serde::Deserialize, Debug)]
|
||||
pub struct NoToolDecision {
|
||||
content: String,
|
||||
}
|
||||
|
||||
if stream {
|
||||
let (headers, response_stream) =
|
||||
generate_stream_internal(infer, compute_type, Json(generate_request), span).await;
|
||||
@ -1389,13 +1336,23 @@ pub(crate) async fn chat_completions(
|
||||
|
||||
// Phase 1: Function name discovery
|
||||
if !name_found {
|
||||
if let Ok(function) = parse_partial_json(&json_buffer) {
|
||||
// NOTE: when tools are supplied `name_found` is false until the generated buffer contains
|
||||
// a partial JSON object with $.function._name value. This name determines the type
|
||||
// of events to emit. If the name is "no_tool", we'll emit the "content" field as a chat
|
||||
// completion event. Otherwise, we'll emit a tool call name event followed by a tool call
|
||||
// argument event. In both cases we'll buffer tokens to get the name and then reset the buffer
|
||||
// to collect the arguments.
|
||||
if let Ok(ParseResult {
|
||||
value: ToolDecision {
|
||||
function: Function { name },
|
||||
},
|
||||
last_value_whole,
|
||||
}) = parse_partial_json(&json_buffer)
|
||||
{
|
||||
if !last_value_whole {
|
||||
continue;
|
||||
}
|
||||
name_found = true;
|
||||
|
||||
let name = function
|
||||
.get("_name")
|
||||
.and_then(|n| n.as_str())
|
||||
.unwrap_or_default();
|
||||
if name == "no_tool" {
|
||||
no_tool_chosen = true;
|
||||
} else {
|
||||
@ -1403,7 +1360,7 @@ pub(crate) async fn chat_completions(
|
||||
&token_text,
|
||||
&model_id,
|
||||
&system_fingerprint,
|
||||
Some(name),
|
||||
Some(name.as_str()),
|
||||
false,
|
||||
None,
|
||||
));
|
||||
@ -1437,38 +1394,40 @@ pub(crate) async fn chat_completions(
|
||||
if using_tools {
|
||||
if no_tool_chosen && !is_complete_json {
|
||||
// Content-only flow
|
||||
if let Ok((function, quote_open)) = parse_generic_structure(&json_buffer) {
|
||||
if let Some(_content) = function.get("content").and_then(|c| c.as_str()) {
|
||||
let cleaned_token = if !first_quote_removed {
|
||||
// trim start until the first quote
|
||||
first_quote_removed = true;
|
||||
edited_token
|
||||
.trim_start()
|
||||
.strip_prefix('"')
|
||||
.unwrap_or(&edited_token)
|
||||
.to_string()
|
||||
} else if !quote_open {
|
||||
should_break = true;
|
||||
// trim end until the last quote
|
||||
edited_token
|
||||
.trim_end()
|
||||
.strip_suffix('"')
|
||||
.unwrap_or(&edited_token)
|
||||
.to_string()
|
||||
} else {
|
||||
edited_token.to_string()
|
||||
};
|
||||
if let Ok(ParseResult {
|
||||
value: _,
|
||||
last_value_whole,
|
||||
}) = parse_partial_json::<NoToolDecision>(&json_buffer)
|
||||
{
|
||||
let cleaned_token = if !first_quote_removed {
|
||||
// trim start until the first quote
|
||||
first_quote_removed = true;
|
||||
edited_token
|
||||
.trim_start()
|
||||
.strip_prefix('"')
|
||||
.unwrap_or(&edited_token)
|
||||
.to_string()
|
||||
} else if last_value_whole {
|
||||
should_break = true;
|
||||
// trim end until the last quote
|
||||
edited_token
|
||||
.trim_end()
|
||||
.strip_suffix('"')
|
||||
.unwrap_or(&edited_token)
|
||||
.to_string()
|
||||
} else {
|
||||
edited_token.to_string()
|
||||
};
|
||||
|
||||
if !cleaned_token.is_empty() {
|
||||
events.push(create_event(
|
||||
&cleaned_token,
|
||||
&model_id,
|
||||
&system_fingerprint,
|
||||
None,
|
||||
false,
|
||||
None,
|
||||
));
|
||||
}
|
||||
if !cleaned_token.is_empty() {
|
||||
events.push(create_event(
|
||||
&cleaned_token,
|
||||
&model_id,
|
||||
&system_fingerprint,
|
||||
None,
|
||||
false,
|
||||
None,
|
||||
));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
Loading…
Reference in New Issue
Block a user