diff --git a/docs/openapi.json b/docs/openapi.json index 63572257..9aaffaee 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -1738,9 +1738,10 @@ "oneOf": [ { "type": "object", + "description": "A string that represents a [JSON Schema](https://json-schema.org/).\n\nJSON Schema is a declarative language that allows to annotate JSON documents\nwith types and descriptions.", "required": [ - "type", - "value" + "value", + "type" ], "properties": { "type": { @@ -1749,16 +1750,21 @@ "json" ] }, - "value": { - "description": "A string that represents a [JSON Schema](https://json-schema.org/).\n\nJSON Schema is a declarative language that allows to annotate JSON documents\nwith types and descriptions." + "value": {} + }, + "example": { + "properties": { + "location": { + "type": "string" + } } } }, { "type": "object", "required": [ - "type", - "value" + "value", + "type" ], "properties": { "type": { @@ -1773,22 +1779,25 @@ } }, { - "type": "object", - "required": [ - "type", - "value" - ], - "properties": { - "type": { - "type": "string", - "enum": [ - "json_schema" - ] + "allOf": [ + { + "$ref": "#/components/schemas/JsonSchemaFormat" }, - "value": { - "$ref": "#/components/schemas/JsonSchemaConfig" + { + "type": "object", + "required": [ + "type" + ], + "properties": { + "type": { + "type": "string", + "enum": [ + "json_schema" + ] + } + } } - } + ] } ], "discriminator": { @@ -1898,6 +1907,40 @@ } } }, + "JsonSchemaFormat": { + "oneOf": [ + { + "type": "object", + "required": [ + "json_schema" + ], + "properties": { + "json_schema": { + "$ref": "#/components/schemas/JsonSchemaOrConfig" + } + } + }, + { + "type": "object", + "required": [ + "value" + ], + "properties": { + "value": { + "$ref": "#/components/schemas/JsonSchemaOrConfig" + } + } + } + ] + }, + "JsonSchemaOrConfig": { + "oneOf": [ + { + "$ref": "#/components/schemas/JsonSchemaConfig" + }, + {} + ] + }, "Message": { "allOf": [ { diff --git a/integration-tests/models/__snapshots__/test_json_schema_constrain/test_json_schema_openai_style_format.json b/integration-tests/models/__snapshots__/test_json_schema_constrain/test_json_schema_openai_style_format.json new file mode 100644 index 00000000..be6bd4f9 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_json_schema_constrain/test_json_schema_openai_style_format.json @@ -0,0 +1,23 @@ +{ + "choices": [ + { + "finish_reason": "stop", + "index": 0, + "logprobs": null, + "message": { + "content": "{\"status\":\".OK.\"}", + "role": "assistant" + } + } + ], + "created": 1750877897, + "id": "", + "model": "google/gemma-3-4b-it", + "object": "chat.completion", + "system_fingerprint": "3.3.4-dev0-native", + "usage": { + "completion_tokens": 8, + "prompt_tokens": 36, + "total_tokens": 44 + } +} diff --git a/integration-tests/models/__snapshots__/test_json_schema_constrain/test_json_schema_simple_status.json b/integration-tests/models/__snapshots__/test_json_schema_constrain/test_json_schema_simple_status.json new file mode 100644 index 00000000..be6bd4f9 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_json_schema_constrain/test_json_schema_simple_status.json @@ -0,0 +1,23 @@ +{ + "choices": [ + { + "finish_reason": "stop", + "index": 0, + "logprobs": null, + "message": { + "content": "{\"status\":\".OK.\"}", + "role": "assistant" + } + } + ], + "created": 1750877897, + "id": "", + "model": "google/gemma-3-4b-it", + "object": "chat.completion", + "system_fingerprint": "3.3.4-dev0-native", + "usage": { + "completion_tokens": 8, + "prompt_tokens": 36, + "total_tokens": 44 + } +} diff --git a/integration-tests/models/test_json_schema_constrain.py b/integration-tests/models/test_json_schema_constrain.py index 65b4a7b8..0aa91de0 100644 --- a/integration-tests/models/test_json_schema_constrain.py +++ b/integration-tests/models/test_json_schema_constrain.py @@ -207,3 +207,87 @@ async def test_json_schema_stream(model_fixture, response_snapshot): assert isinstance(parsed_content["numCats"], int) assert parsed_content["numCats"] >= 0 assert chunks == response_snapshot + + +status_schema = { + "type": "object", + "properties": {"status": {"type": "string"}}, + "required": ["status"], + "additionalProperties": False, +} + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_json_schema_simple_status(model_fixture, response_snapshot): + """Test simple status JSON schema - TGI format.""" + response = requests.post( + f"{model_fixture.base_url}/v1/chat/completions", + json={ + "model": "tgi", + "messages": [ + { + "role": "system", + "content": "You are a helpful assistant. You answer with a JSON output with a status string containing the value 'OK'", + }, + {"role": "user", "content": "Please tell me OK"}, + ], + "seed": 42, + "temperature": 0.0, + "response_format": { + "type": "json_schema", + "value": {"name": "test", "schema": status_schema}, + }, + "max_completion_tokens": 8192, + }, + ) + + result = response.json() + + # Validate response format + content = result["choices"][0]["message"]["content"] + parsed_content = json.loads(content) + + assert "status" in parsed_content + assert isinstance(parsed_content["status"], str) + assert result == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_json_schema_openai_style_format(model_fixture, response_snapshot): + """Test OpenAI-style JSON schema format (should also work now).""" + response = requests.post( + f"{model_fixture.base_url}/v1/chat/completions", + json={ + "model": "tgi", + "messages": [ + { + "role": "system", + "content": "You are a helpful assistant. You answer with a JSON output with a status string containing the value 'OK'", + }, + {"role": "user", "content": "Please tell me OK"}, + ], + "seed": 42, + "temperature": 0.0, + "response_format": { + "json_schema": { + "name": "test", + "strict": True, + "schema": status_schema, + }, + "type": "json_schema", + }, + "max_completion_tokens": 8192, + }, + ) + + result = response.json() + + # Validate response format + content = result["choices"][0]["message"]["content"] + parsed_content = json.loads(content) + + assert "status" in parsed_content + assert isinstance(parsed_content["status"], str) + assert result == response_snapshot diff --git a/router/src/lib.rs b/router/src/lib.rs index e5622fc2..3ad7a3d3 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -224,7 +224,7 @@ impl HubProcessorConfig { #[derive(Clone, Debug, Deserialize, ToSchema, Serialize)] #[cfg_attr(test, derive(PartialEq))] -struct JsonSchemaConfig { +pub struct JsonSchemaConfig { /// Optional name identifier for the schema #[serde(skip_serializing_if = "Option::is_none")] name: Option, @@ -235,7 +235,7 @@ struct JsonSchemaConfig { #[derive(Clone, Debug, Deserialize, ToSchema, Serialize)] #[cfg_attr(test, derive(PartialEq))] -#[serde(tag = "type", content = "value")] +#[serde(tag = "type")] pub(crate) enum GrammarType { /// A string that represents a [JSON Schema](https://json-schema.org/). /// @@ -244,17 +244,53 @@ pub(crate) enum GrammarType { #[serde(rename = "json")] #[serde(alias = "json_object")] #[schema(example = json ! ({"properties": {"location":{"type": "string"}}}))] - Json(serde_json::Value), + Json { value: serde_json::Value }, #[serde(rename = "regex")] - Regex(String), + Regex { value: String }, /// A JSON Schema specification with additional metadata. /// /// Includes an optional name for the schema, an optional strict flag, and the required schema definition. #[serde(rename = "json_schema")] - #[schema(example = json ! ({"schema": {"properties": {"name": {"type": "string"}, "age": {"type": "integer"}}}, "name": "person_info", "strict": true}))] - JsonSchema(JsonSchemaConfig), + JsonSchema(JsonSchemaFormat), +} + +#[derive(Clone, Debug, Deserialize, ToSchema, Serialize)] +#[cfg_attr(test, derive(PartialEq))] +#[serde(untagged)] +pub enum JsonSchemaFormat { + JsonSchema { json_schema: JsonSchemaOrConfig }, + Value { value: JsonSchemaOrConfig }, +} + +#[derive(Clone, Debug, Deserialize, ToSchema, Serialize)] +#[cfg_attr(test, derive(PartialEq))] +#[serde(untagged)] +pub enum JsonSchemaOrConfig { + Config(JsonSchemaConfig), + Value(serde_json::Value), +} + +impl JsonSchemaOrConfig { + pub fn schema_value(&self) -> &serde_json::Value { + match self { + JsonSchemaOrConfig::Config(config) => &config.schema, + JsonSchemaOrConfig::Value(value) => value, + } + } +} + +impl JsonSchemaFormat { + pub fn schema_value(&self) -> &serde_json::Value { + let config = match self { + Self::JsonSchema { json_schema } | Self::Value { value: json_schema } => json_schema, + }; + match config { + JsonSchemaOrConfig::Config(config) => &config.schema, + JsonSchemaOrConfig::Value(value) => value, + } + } } #[derive(Clone, Debug, Serialize, ToSchema)] @@ -984,7 +1020,9 @@ impl ChatRequest { if let Some(tools) = tools { match ToolGrammar::apply(tools, tool_choice)? { Some((updated_tools, tool_schema)) => { - let grammar = GrammarType::Json(serde_json::json!(tool_schema)); + let grammar = GrammarType::Json { + value: serde_json::json!(tool_schema), + }; let inputs: String = infer.apply_chat_template( messages, Some((updated_tools, tool_prompt)), diff --git a/router/src/server.rs b/router/src/server.rs index 5fbe0403..2d152ff1 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -28,7 +28,7 @@ use crate::{ ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal, CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, }; -use crate::{ChatTokenizeResponse, JsonSchemaConfig}; +use crate::{ChatTokenizeResponse, JsonSchemaConfig, JsonSchemaFormat, JsonSchemaOrConfig}; use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice}; use crate::{MessageBody, ModelInfo, ModelsInfo}; use async_stream::__private::AsyncStream; @@ -1363,6 +1363,8 @@ SagemakerRequest, GenerateRequest, GrammarType, JsonSchemaConfig, +JsonSchemaOrConfig, +JsonSchemaFormat, ChatRequest, Message, MessageContent, diff --git a/router/src/validation.rs b/router/src/validation.rs index 28c7f2f8..b784138c 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -350,13 +350,13 @@ impl Validation { return Err(ValidationError::Grammar); } let valid_grammar = match grammar { - GrammarType::Json(json) => { - let json = match json { + GrammarType::Json { value } => { + let json = match value { // if value is a string, we need to parse it again to make sure its // a valid json Value::String(s) => serde_json::from_str(&s) .map_err(|e| ValidationError::InvalidGrammar(e.to_string())), - Value::Object(_) => Ok(json), + Value::Object(_) => Ok(value), _ => Err(ValidationError::Grammar), }?; @@ -380,9 +380,9 @@ impl Validation { ValidGrammar::Regex(grammar_regex.to_string()) } - GrammarType::JsonSchema(schema_config) => { + GrammarType::JsonSchema(json_schema) => { // Extract the actual schema for validation - let json = &schema_config.schema; + let json = json_schema.schema_value(); // Check if the json is a valid JSONSchema jsonschema::draft202012::meta::validate(json) @@ -402,7 +402,7 @@ impl Validation { ValidGrammar::Regex(grammar_regex.to_string()) } - GrammarType::Regex(regex) => ValidGrammar::Regex(regex), + GrammarType::Regex { value } => ValidGrammar::Regex(value), }; Some(valid_grammar) }