feat: allow json_schema in response format and add test

This commit is contained in:
drbh 2025-06-25 19:43:49 +00:00
parent 9f38d93051
commit 5f70fbdc2a
7 changed files with 247 additions and 34 deletions

View File

@ -1738,9 +1738,10 @@
"oneOf": [ "oneOf": [
{ {
"type": "object", "type": "object",
"description": "A string that represents a [JSON Schema](https://json-schema.org/).\n\nJSON Schema is a declarative language that allows to annotate JSON documents\nwith types and descriptions.",
"required": [ "required": [
"type", "value",
"value" "type"
], ],
"properties": { "properties": {
"type": { "type": {
@ -1749,16 +1750,21 @@
"json" "json"
] ]
}, },
"value": { "value": {}
"description": "A string that represents a [JSON Schema](https://json-schema.org/).\n\nJSON Schema is a declarative language that allows to annotate JSON documents\nwith types and descriptions." },
"example": {
"properties": {
"location": {
"type": "string"
}
} }
} }
}, },
{ {
"type": "object", "type": "object",
"required": [ "required": [
"type", "value",
"value" "type"
], ],
"properties": { "properties": {
"type": { "type": {
@ -1773,22 +1779,25 @@
} }
}, },
{ {
"type": "object", "allOf": [
"required": [ {
"type", "$ref": "#/components/schemas/JsonSchemaFormat"
"value"
],
"properties": {
"type": {
"type": "string",
"enum": [
"json_schema"
]
}, },
"value": { {
"$ref": "#/components/schemas/JsonSchemaConfig" "type": "object",
"required": [
"type"
],
"properties": {
"type": {
"type": "string",
"enum": [
"json_schema"
]
}
}
} }
} ]
} }
], ],
"discriminator": { "discriminator": {
@ -1898,6 +1907,40 @@
} }
} }
}, },
"JsonSchemaFormat": {
"oneOf": [
{
"type": "object",
"required": [
"json_schema"
],
"properties": {
"json_schema": {
"$ref": "#/components/schemas/JsonSchemaOrConfig"
}
}
},
{
"type": "object",
"required": [
"value"
],
"properties": {
"value": {
"$ref": "#/components/schemas/JsonSchemaOrConfig"
}
}
}
]
},
"JsonSchemaOrConfig": {
"oneOf": [
{
"$ref": "#/components/schemas/JsonSchemaConfig"
},
{}
]
},
"Message": { "Message": {
"allOf": [ "allOf": [
{ {

View File

@ -0,0 +1,23 @@
{
"choices": [
{
"finish_reason": "stop",
"index": 0,
"logprobs": null,
"message": {
"content": "{\"status\":\".OK.\"}",
"role": "assistant"
}
}
],
"created": 1750877897,
"id": "",
"model": "google/gemma-3-4b-it",
"object": "chat.completion",
"system_fingerprint": "3.3.4-dev0-native",
"usage": {
"completion_tokens": 8,
"prompt_tokens": 36,
"total_tokens": 44
}
}

View File

@ -0,0 +1,23 @@
{
"choices": [
{
"finish_reason": "stop",
"index": 0,
"logprobs": null,
"message": {
"content": "{\"status\":\".OK.\"}",
"role": "assistant"
}
}
],
"created": 1750877897,
"id": "",
"model": "google/gemma-3-4b-it",
"object": "chat.completion",
"system_fingerprint": "3.3.4-dev0-native",
"usage": {
"completion_tokens": 8,
"prompt_tokens": 36,
"total_tokens": 44
}
}

View File

@ -207,3 +207,87 @@ async def test_json_schema_stream(model_fixture, response_snapshot):
assert isinstance(parsed_content["numCats"], int) assert isinstance(parsed_content["numCats"], int)
assert parsed_content["numCats"] >= 0 assert parsed_content["numCats"] >= 0
assert chunks == response_snapshot assert chunks == response_snapshot
status_schema = {
"type": "object",
"properties": {"status": {"type": "string"}},
"required": ["status"],
"additionalProperties": False,
}
@pytest.mark.asyncio
@pytest.mark.private
async def test_json_schema_simple_status(model_fixture, response_snapshot):
"""Test simple status JSON schema - TGI format."""
response = requests.post(
f"{model_fixture.base_url}/v1/chat/completions",
json={
"model": "tgi",
"messages": [
{
"role": "system",
"content": "You are a helpful assistant. You answer with a JSON output with a status string containing the value 'OK'",
},
{"role": "user", "content": "Please tell me OK"},
],
"seed": 42,
"temperature": 0.0,
"response_format": {
"type": "json_schema",
"value": {"name": "test", "schema": status_schema},
},
"max_completion_tokens": 8192,
},
)
result = response.json()
# Validate response format
content = result["choices"][0]["message"]["content"]
parsed_content = json.loads(content)
assert "status" in parsed_content
assert isinstance(parsed_content["status"], str)
assert result == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_json_schema_openai_style_format(model_fixture, response_snapshot):
"""Test OpenAI-style JSON schema format (should also work now)."""
response = requests.post(
f"{model_fixture.base_url}/v1/chat/completions",
json={
"model": "tgi",
"messages": [
{
"role": "system",
"content": "You are a helpful assistant. You answer with a JSON output with a status string containing the value 'OK'",
},
{"role": "user", "content": "Please tell me OK"},
],
"seed": 42,
"temperature": 0.0,
"response_format": {
"json_schema": {
"name": "test",
"strict": True,
"schema": status_schema,
},
"type": "json_schema",
},
"max_completion_tokens": 8192,
},
)
result = response.json()
# Validate response format
content = result["choices"][0]["message"]["content"]
parsed_content = json.loads(content)
assert "status" in parsed_content
assert isinstance(parsed_content["status"], str)
assert result == response_snapshot

View File

@ -224,7 +224,7 @@ impl HubProcessorConfig {
#[derive(Clone, Debug, Deserialize, ToSchema, Serialize)] #[derive(Clone, Debug, Deserialize, ToSchema, Serialize)]
#[cfg_attr(test, derive(PartialEq))] #[cfg_attr(test, derive(PartialEq))]
struct JsonSchemaConfig { pub struct JsonSchemaConfig {
/// Optional name identifier for the schema /// Optional name identifier for the schema
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
name: Option<String>, name: Option<String>,
@ -235,7 +235,7 @@ struct JsonSchemaConfig {
#[derive(Clone, Debug, Deserialize, ToSchema, Serialize)] #[derive(Clone, Debug, Deserialize, ToSchema, Serialize)]
#[cfg_attr(test, derive(PartialEq))] #[cfg_attr(test, derive(PartialEq))]
#[serde(tag = "type", content = "value")] #[serde(tag = "type")]
pub(crate) enum GrammarType { pub(crate) enum GrammarType {
/// A string that represents a [JSON Schema](https://json-schema.org/). /// A string that represents a [JSON Schema](https://json-schema.org/).
/// ///
@ -244,17 +244,53 @@ pub(crate) enum GrammarType {
#[serde(rename = "json")] #[serde(rename = "json")]
#[serde(alias = "json_object")] #[serde(alias = "json_object")]
#[schema(example = json ! ({"properties": {"location":{"type": "string"}}}))] #[schema(example = json ! ({"properties": {"location":{"type": "string"}}}))]
Json(serde_json::Value), Json { value: serde_json::Value },
#[serde(rename = "regex")] #[serde(rename = "regex")]
Regex(String), Regex { value: String },
/// A JSON Schema specification with additional metadata. /// A JSON Schema specification with additional metadata.
/// ///
/// Includes an optional name for the schema, an optional strict flag, and the required schema definition. /// Includes an optional name for the schema, an optional strict flag, and the required schema definition.
#[serde(rename = "json_schema")] #[serde(rename = "json_schema")]
#[schema(example = json ! ({"schema": {"properties": {"name": {"type": "string"}, "age": {"type": "integer"}}}, "name": "person_info", "strict": true}))] JsonSchema(JsonSchemaFormat),
JsonSchema(JsonSchemaConfig), }
#[derive(Clone, Debug, Deserialize, ToSchema, Serialize)]
#[cfg_attr(test, derive(PartialEq))]
#[serde(untagged)]
pub enum JsonSchemaFormat {
JsonSchema { json_schema: JsonSchemaOrConfig },
Value { value: JsonSchemaOrConfig },
}
#[derive(Clone, Debug, Deserialize, ToSchema, Serialize)]
#[cfg_attr(test, derive(PartialEq))]
#[serde(untagged)]
pub enum JsonSchemaOrConfig {
Config(JsonSchemaConfig),
Value(serde_json::Value),
}
impl JsonSchemaOrConfig {
pub fn schema_value(&self) -> &serde_json::Value {
match self {
JsonSchemaOrConfig::Config(config) => &config.schema,
JsonSchemaOrConfig::Value(value) => value,
}
}
}
impl JsonSchemaFormat {
pub fn schema_value(&self) -> &serde_json::Value {
let config = match self {
Self::JsonSchema { json_schema } | Self::Value { value: json_schema } => json_schema,
};
match config {
JsonSchemaOrConfig::Config(config) => &config.schema,
JsonSchemaOrConfig::Value(value) => value,
}
}
} }
#[derive(Clone, Debug, Serialize, ToSchema)] #[derive(Clone, Debug, Serialize, ToSchema)]
@ -984,7 +1020,9 @@ impl ChatRequest {
if let Some(tools) = tools { if let Some(tools) = tools {
match ToolGrammar::apply(tools, tool_choice)? { match ToolGrammar::apply(tools, tool_choice)? {
Some((updated_tools, tool_schema)) => { Some((updated_tools, tool_schema)) => {
let grammar = GrammarType::Json(serde_json::json!(tool_schema)); let grammar = GrammarType::Json {
value: serde_json::json!(tool_schema),
};
let inputs: String = infer.apply_chat_template( let inputs: String = infer.apply_chat_template(
messages, messages,
Some((updated_tools, tool_prompt)), Some((updated_tools, tool_prompt)),

View File

@ -28,7 +28,7 @@ use crate::{
ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal, ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal,
CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool,
}; };
use crate::{ChatTokenizeResponse, JsonSchemaConfig}; use crate::{ChatTokenizeResponse, JsonSchemaConfig, JsonSchemaFormat, JsonSchemaOrConfig};
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice}; use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice};
use crate::{MessageBody, ModelInfo, ModelsInfo}; use crate::{MessageBody, ModelInfo, ModelsInfo};
use async_stream::__private::AsyncStream; use async_stream::__private::AsyncStream;
@ -1363,6 +1363,8 @@ SagemakerRequest,
GenerateRequest, GenerateRequest,
GrammarType, GrammarType,
JsonSchemaConfig, JsonSchemaConfig,
JsonSchemaOrConfig,
JsonSchemaFormat,
ChatRequest, ChatRequest,
Message, Message,
MessageContent, MessageContent,

View File

@ -350,13 +350,13 @@ impl Validation {
return Err(ValidationError::Grammar); return Err(ValidationError::Grammar);
} }
let valid_grammar = match grammar { let valid_grammar = match grammar {
GrammarType::Json(json) => { GrammarType::Json { value } => {
let json = match json { let json = match value {
// if value is a string, we need to parse it again to make sure its // if value is a string, we need to parse it again to make sure its
// a valid json // a valid json
Value::String(s) => serde_json::from_str(&s) Value::String(s) => serde_json::from_str(&s)
.map_err(|e| ValidationError::InvalidGrammar(e.to_string())), .map_err(|e| ValidationError::InvalidGrammar(e.to_string())),
Value::Object(_) => Ok(json), Value::Object(_) => Ok(value),
_ => Err(ValidationError::Grammar), _ => Err(ValidationError::Grammar),
}?; }?;
@ -380,9 +380,9 @@ impl Validation {
ValidGrammar::Regex(grammar_regex.to_string()) ValidGrammar::Regex(grammar_regex.to_string())
} }
GrammarType::JsonSchema(schema_config) => { GrammarType::JsonSchema(json_schema) => {
// Extract the actual schema for validation // Extract the actual schema for validation
let json = &schema_config.schema; let json = json_schema.schema_value();
// Check if the json is a valid JSONSchema // Check if the json is a valid JSONSchema
jsonschema::draft202012::meta::validate(json) jsonschema::draft202012::meta::validate(json)
@ -402,7 +402,7 @@ impl Validation {
ValidGrammar::Regex(grammar_regex.to_string()) ValidGrammar::Regex(grammar_regex.to_string())
} }
GrammarType::Regex(regex) => ValidGrammar::Regex(regex), GrammarType::Regex { value } => ValidGrammar::Regex(value),
}; };
Some(valid_grammar) Some(valid_grammar)
} }