mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 19:34:53 +00:00
feat: allow json_schema in response format and add test
This commit is contained in:
parent
9f38d93051
commit
5f70fbdc2a
@ -1738,9 +1738,10 @@
|
|||||||
"oneOf": [
|
"oneOf": [
|
||||||
{
|
{
|
||||||
"type": "object",
|
"type": "object",
|
||||||
|
"description": "A string that represents a [JSON Schema](https://json-schema.org/).\n\nJSON Schema is a declarative language that allows to annotate JSON documents\nwith types and descriptions.",
|
||||||
"required": [
|
"required": [
|
||||||
"type",
|
"value",
|
||||||
"value"
|
"type"
|
||||||
],
|
],
|
||||||
"properties": {
|
"properties": {
|
||||||
"type": {
|
"type": {
|
||||||
@ -1749,16 +1750,21 @@
|
|||||||
"json"
|
"json"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"value": {
|
"value": {}
|
||||||
"description": "A string that represents a [JSON Schema](https://json-schema.org/).\n\nJSON Schema is a declarative language that allows to annotate JSON documents\nwith types and descriptions."
|
},
|
||||||
|
"example": {
|
||||||
|
"properties": {
|
||||||
|
"location": {
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"required": [
|
"required": [
|
||||||
"type",
|
"value",
|
||||||
"value"
|
"type"
|
||||||
],
|
],
|
||||||
"properties": {
|
"properties": {
|
||||||
"type": {
|
"type": {
|
||||||
@ -1773,22 +1779,25 @@
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"type": "object",
|
"allOf": [
|
||||||
"required": [
|
{
|
||||||
"type",
|
"$ref": "#/components/schemas/JsonSchemaFormat"
|
||||||
"value"
|
|
||||||
],
|
|
||||||
"properties": {
|
|
||||||
"type": {
|
|
||||||
"type": "string",
|
|
||||||
"enum": [
|
|
||||||
"json_schema"
|
|
||||||
]
|
|
||||||
},
|
},
|
||||||
"value": {
|
{
|
||||||
"$ref": "#/components/schemas/JsonSchemaConfig"
|
"type": "object",
|
||||||
|
"required": [
|
||||||
|
"type"
|
||||||
|
],
|
||||||
|
"properties": {
|
||||||
|
"type": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": [
|
||||||
|
"json_schema"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"discriminator": {
|
"discriminator": {
|
||||||
@ -1898,6 +1907,40 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"JsonSchemaFormat": {
|
||||||
|
"oneOf": [
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"required": [
|
||||||
|
"json_schema"
|
||||||
|
],
|
||||||
|
"properties": {
|
||||||
|
"json_schema": {
|
||||||
|
"$ref": "#/components/schemas/JsonSchemaOrConfig"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"required": [
|
||||||
|
"value"
|
||||||
|
],
|
||||||
|
"properties": {
|
||||||
|
"value": {
|
||||||
|
"$ref": "#/components/schemas/JsonSchemaOrConfig"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"JsonSchemaOrConfig": {
|
||||||
|
"oneOf": [
|
||||||
|
{
|
||||||
|
"$ref": "#/components/schemas/JsonSchemaConfig"
|
||||||
|
},
|
||||||
|
{}
|
||||||
|
]
|
||||||
|
},
|
||||||
"Message": {
|
"Message": {
|
||||||
"allOf": [
|
"allOf": [
|
||||||
{
|
{
|
||||||
|
@ -0,0 +1,23 @@
|
|||||||
|
{
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"finish_reason": "stop",
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null,
|
||||||
|
"message": {
|
||||||
|
"content": "{\"status\":\".OK.\"}",
|
||||||
|
"role": "assistant"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1750877897,
|
||||||
|
"id": "",
|
||||||
|
"model": "google/gemma-3-4b-it",
|
||||||
|
"object": "chat.completion",
|
||||||
|
"system_fingerprint": "3.3.4-dev0-native",
|
||||||
|
"usage": {
|
||||||
|
"completion_tokens": 8,
|
||||||
|
"prompt_tokens": 36,
|
||||||
|
"total_tokens": 44
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,23 @@
|
|||||||
|
{
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"finish_reason": "stop",
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null,
|
||||||
|
"message": {
|
||||||
|
"content": "{\"status\":\".OK.\"}",
|
||||||
|
"role": "assistant"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1750877897,
|
||||||
|
"id": "",
|
||||||
|
"model": "google/gemma-3-4b-it",
|
||||||
|
"object": "chat.completion",
|
||||||
|
"system_fingerprint": "3.3.4-dev0-native",
|
||||||
|
"usage": {
|
||||||
|
"completion_tokens": 8,
|
||||||
|
"prompt_tokens": 36,
|
||||||
|
"total_tokens": 44
|
||||||
|
}
|
||||||
|
}
|
@ -207,3 +207,87 @@ async def test_json_schema_stream(model_fixture, response_snapshot):
|
|||||||
assert isinstance(parsed_content["numCats"], int)
|
assert isinstance(parsed_content["numCats"], int)
|
||||||
assert parsed_content["numCats"] >= 0
|
assert parsed_content["numCats"] >= 0
|
||||||
assert chunks == response_snapshot
|
assert chunks == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
status_schema = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"status": {"type": "string"}},
|
||||||
|
"required": ["status"],
|
||||||
|
"additionalProperties": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.private
|
||||||
|
async def test_json_schema_simple_status(model_fixture, response_snapshot):
|
||||||
|
"""Test simple status JSON schema - TGI format."""
|
||||||
|
response = requests.post(
|
||||||
|
f"{model_fixture.base_url}/v1/chat/completions",
|
||||||
|
json={
|
||||||
|
"model": "tgi",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "You are a helpful assistant. You answer with a JSON output with a status string containing the value 'OK'",
|
||||||
|
},
|
||||||
|
{"role": "user", "content": "Please tell me OK"},
|
||||||
|
],
|
||||||
|
"seed": 42,
|
||||||
|
"temperature": 0.0,
|
||||||
|
"response_format": {
|
||||||
|
"type": "json_schema",
|
||||||
|
"value": {"name": "test", "schema": status_schema},
|
||||||
|
},
|
||||||
|
"max_completion_tokens": 8192,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
result = response.json()
|
||||||
|
|
||||||
|
# Validate response format
|
||||||
|
content = result["choices"][0]["message"]["content"]
|
||||||
|
parsed_content = json.loads(content)
|
||||||
|
|
||||||
|
assert "status" in parsed_content
|
||||||
|
assert isinstance(parsed_content["status"], str)
|
||||||
|
assert result == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.private
|
||||||
|
async def test_json_schema_openai_style_format(model_fixture, response_snapshot):
|
||||||
|
"""Test OpenAI-style JSON schema format (should also work now)."""
|
||||||
|
response = requests.post(
|
||||||
|
f"{model_fixture.base_url}/v1/chat/completions",
|
||||||
|
json={
|
||||||
|
"model": "tgi",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "You are a helpful assistant. You answer with a JSON output with a status string containing the value 'OK'",
|
||||||
|
},
|
||||||
|
{"role": "user", "content": "Please tell me OK"},
|
||||||
|
],
|
||||||
|
"seed": 42,
|
||||||
|
"temperature": 0.0,
|
||||||
|
"response_format": {
|
||||||
|
"json_schema": {
|
||||||
|
"name": "test",
|
||||||
|
"strict": True,
|
||||||
|
"schema": status_schema,
|
||||||
|
},
|
||||||
|
"type": "json_schema",
|
||||||
|
},
|
||||||
|
"max_completion_tokens": 8192,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
result = response.json()
|
||||||
|
|
||||||
|
# Validate response format
|
||||||
|
content = result["choices"][0]["message"]["content"]
|
||||||
|
parsed_content = json.loads(content)
|
||||||
|
|
||||||
|
assert "status" in parsed_content
|
||||||
|
assert isinstance(parsed_content["status"], str)
|
||||||
|
assert result == response_snapshot
|
||||||
|
@ -224,7 +224,7 @@ impl HubProcessorConfig {
|
|||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize, ToSchema, Serialize)]
|
#[derive(Clone, Debug, Deserialize, ToSchema, Serialize)]
|
||||||
#[cfg_attr(test, derive(PartialEq))]
|
#[cfg_attr(test, derive(PartialEq))]
|
||||||
struct JsonSchemaConfig {
|
pub struct JsonSchemaConfig {
|
||||||
/// Optional name identifier for the schema
|
/// Optional name identifier for the schema
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
name: Option<String>,
|
name: Option<String>,
|
||||||
@ -235,7 +235,7 @@ struct JsonSchemaConfig {
|
|||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize, ToSchema, Serialize)]
|
#[derive(Clone, Debug, Deserialize, ToSchema, Serialize)]
|
||||||
#[cfg_attr(test, derive(PartialEq))]
|
#[cfg_attr(test, derive(PartialEq))]
|
||||||
#[serde(tag = "type", content = "value")]
|
#[serde(tag = "type")]
|
||||||
pub(crate) enum GrammarType {
|
pub(crate) enum GrammarType {
|
||||||
/// A string that represents a [JSON Schema](https://json-schema.org/).
|
/// A string that represents a [JSON Schema](https://json-schema.org/).
|
||||||
///
|
///
|
||||||
@ -244,17 +244,53 @@ pub(crate) enum GrammarType {
|
|||||||
#[serde(rename = "json")]
|
#[serde(rename = "json")]
|
||||||
#[serde(alias = "json_object")]
|
#[serde(alias = "json_object")]
|
||||||
#[schema(example = json ! ({"properties": {"location":{"type": "string"}}}))]
|
#[schema(example = json ! ({"properties": {"location":{"type": "string"}}}))]
|
||||||
Json(serde_json::Value),
|
Json { value: serde_json::Value },
|
||||||
|
|
||||||
#[serde(rename = "regex")]
|
#[serde(rename = "regex")]
|
||||||
Regex(String),
|
Regex { value: String },
|
||||||
|
|
||||||
/// A JSON Schema specification with additional metadata.
|
/// A JSON Schema specification with additional metadata.
|
||||||
///
|
///
|
||||||
/// Includes an optional name for the schema, an optional strict flag, and the required schema definition.
|
/// Includes an optional name for the schema, an optional strict flag, and the required schema definition.
|
||||||
#[serde(rename = "json_schema")]
|
#[serde(rename = "json_schema")]
|
||||||
#[schema(example = json ! ({"schema": {"properties": {"name": {"type": "string"}, "age": {"type": "integer"}}}, "name": "person_info", "strict": true}))]
|
JsonSchema(JsonSchemaFormat),
|
||||||
JsonSchema(JsonSchemaConfig),
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize, ToSchema, Serialize)]
|
||||||
|
#[cfg_attr(test, derive(PartialEq))]
|
||||||
|
#[serde(untagged)]
|
||||||
|
pub enum JsonSchemaFormat {
|
||||||
|
JsonSchema { json_schema: JsonSchemaOrConfig },
|
||||||
|
Value { value: JsonSchemaOrConfig },
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize, ToSchema, Serialize)]
|
||||||
|
#[cfg_attr(test, derive(PartialEq))]
|
||||||
|
#[serde(untagged)]
|
||||||
|
pub enum JsonSchemaOrConfig {
|
||||||
|
Config(JsonSchemaConfig),
|
||||||
|
Value(serde_json::Value),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl JsonSchemaOrConfig {
|
||||||
|
pub fn schema_value(&self) -> &serde_json::Value {
|
||||||
|
match self {
|
||||||
|
JsonSchemaOrConfig::Config(config) => &config.schema,
|
||||||
|
JsonSchemaOrConfig::Value(value) => value,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl JsonSchemaFormat {
|
||||||
|
pub fn schema_value(&self) -> &serde_json::Value {
|
||||||
|
let config = match self {
|
||||||
|
Self::JsonSchema { json_schema } | Self::Value { value: json_schema } => json_schema,
|
||||||
|
};
|
||||||
|
match config {
|
||||||
|
JsonSchemaOrConfig::Config(config) => &config.schema,
|
||||||
|
JsonSchemaOrConfig::Value(value) => value,
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Serialize, ToSchema)]
|
#[derive(Clone, Debug, Serialize, ToSchema)]
|
||||||
@ -984,7 +1020,9 @@ impl ChatRequest {
|
|||||||
if let Some(tools) = tools {
|
if let Some(tools) = tools {
|
||||||
match ToolGrammar::apply(tools, tool_choice)? {
|
match ToolGrammar::apply(tools, tool_choice)? {
|
||||||
Some((updated_tools, tool_schema)) => {
|
Some((updated_tools, tool_schema)) => {
|
||||||
let grammar = GrammarType::Json(serde_json::json!(tool_schema));
|
let grammar = GrammarType::Json {
|
||||||
|
value: serde_json::json!(tool_schema),
|
||||||
|
};
|
||||||
let inputs: String = infer.apply_chat_template(
|
let inputs: String = infer.apply_chat_template(
|
||||||
messages,
|
messages,
|
||||||
Some((updated_tools, tool_prompt)),
|
Some((updated_tools, tool_prompt)),
|
||||||
|
@ -28,7 +28,7 @@ use crate::{
|
|||||||
ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal,
|
ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal,
|
||||||
CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool,
|
CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool,
|
||||||
};
|
};
|
||||||
use crate::{ChatTokenizeResponse, JsonSchemaConfig};
|
use crate::{ChatTokenizeResponse, JsonSchemaConfig, JsonSchemaFormat, JsonSchemaOrConfig};
|
||||||
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice};
|
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice};
|
||||||
use crate::{MessageBody, ModelInfo, ModelsInfo};
|
use crate::{MessageBody, ModelInfo, ModelsInfo};
|
||||||
use async_stream::__private::AsyncStream;
|
use async_stream::__private::AsyncStream;
|
||||||
@ -1363,6 +1363,8 @@ SagemakerRequest,
|
|||||||
GenerateRequest,
|
GenerateRequest,
|
||||||
GrammarType,
|
GrammarType,
|
||||||
JsonSchemaConfig,
|
JsonSchemaConfig,
|
||||||
|
JsonSchemaOrConfig,
|
||||||
|
JsonSchemaFormat,
|
||||||
ChatRequest,
|
ChatRequest,
|
||||||
Message,
|
Message,
|
||||||
MessageContent,
|
MessageContent,
|
||||||
|
@ -350,13 +350,13 @@ impl Validation {
|
|||||||
return Err(ValidationError::Grammar);
|
return Err(ValidationError::Grammar);
|
||||||
}
|
}
|
||||||
let valid_grammar = match grammar {
|
let valid_grammar = match grammar {
|
||||||
GrammarType::Json(json) => {
|
GrammarType::Json { value } => {
|
||||||
let json = match json {
|
let json = match value {
|
||||||
// if value is a string, we need to parse it again to make sure its
|
// if value is a string, we need to parse it again to make sure its
|
||||||
// a valid json
|
// a valid json
|
||||||
Value::String(s) => serde_json::from_str(&s)
|
Value::String(s) => serde_json::from_str(&s)
|
||||||
.map_err(|e| ValidationError::InvalidGrammar(e.to_string())),
|
.map_err(|e| ValidationError::InvalidGrammar(e.to_string())),
|
||||||
Value::Object(_) => Ok(json),
|
Value::Object(_) => Ok(value),
|
||||||
_ => Err(ValidationError::Grammar),
|
_ => Err(ValidationError::Grammar),
|
||||||
}?;
|
}?;
|
||||||
|
|
||||||
@ -380,9 +380,9 @@ impl Validation {
|
|||||||
|
|
||||||
ValidGrammar::Regex(grammar_regex.to_string())
|
ValidGrammar::Regex(grammar_regex.to_string())
|
||||||
}
|
}
|
||||||
GrammarType::JsonSchema(schema_config) => {
|
GrammarType::JsonSchema(json_schema) => {
|
||||||
// Extract the actual schema for validation
|
// Extract the actual schema for validation
|
||||||
let json = &schema_config.schema;
|
let json = json_schema.schema_value();
|
||||||
|
|
||||||
// Check if the json is a valid JSONSchema
|
// Check if the json is a valid JSONSchema
|
||||||
jsonschema::draft202012::meta::validate(json)
|
jsonschema::draft202012::meta::validate(json)
|
||||||
@ -402,7 +402,7 @@ impl Validation {
|
|||||||
|
|
||||||
ValidGrammar::Regex(grammar_regex.to_string())
|
ValidGrammar::Regex(grammar_regex.to_string())
|
||||||
}
|
}
|
||||||
GrammarType::Regex(regex) => ValidGrammar::Regex(regex),
|
GrammarType::Regex { value } => ValidGrammar::Regex(value),
|
||||||
};
|
};
|
||||||
Some(valid_grammar)
|
Some(valid_grammar)
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user