mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-08 19:04:52 +00:00
Merge da47e5754b
into 06d9d88b95
This commit is contained in:
commit
cbcd2eebeb
@ -1738,9 +1738,10 @@
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "object",
|
||||
"description": "A string that represents a [JSON Schema](https://json-schema.org/).\n\nJSON Schema is a declarative language that allows to annotate JSON documents\nwith types and descriptions.",
|
||||
"required": [
|
||||
"type",
|
||||
"value"
|
||||
"value",
|
||||
"type"
|
||||
],
|
||||
"properties": {
|
||||
"type": {
|
||||
@ -1749,16 +1750,21 @@
|
||||
"json"
|
||||
]
|
||||
},
|
||||
"value": {
|
||||
"description": "A string that represents a [JSON Schema](https://json-schema.org/).\n\nJSON Schema is a declarative language that allows to annotate JSON documents\nwith types and descriptions."
|
||||
"value": {}
|
||||
},
|
||||
"example": {
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "object",
|
||||
"required": [
|
||||
"type",
|
||||
"value"
|
||||
"value",
|
||||
"type"
|
||||
],
|
||||
"properties": {
|
||||
"type": {
|
||||
@ -1775,18 +1781,16 @@
|
||||
{
|
||||
"type": "object",
|
||||
"required": [
|
||||
"type",
|
||||
"value"
|
||||
"json_schema",
|
||||
"type"
|
||||
],
|
||||
"properties": {
|
||||
"json_schema": {},
|
||||
"type": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"json_schema"
|
||||
]
|
||||
},
|
||||
"value": {
|
||||
"$ref": "#/components/schemas/JsonSchemaConfig"
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1882,22 +1886,6 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"JsonSchemaConfig": {
|
||||
"type": "object",
|
||||
"required": [
|
||||
"schema"
|
||||
],
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "Optional name identifier for the schema",
|
||||
"nullable": true
|
||||
},
|
||||
"schema": {
|
||||
"description": "The actual JSON schema definition"
|
||||
}
|
||||
}
|
||||
},
|
||||
"Message": {
|
||||
"allOf": [
|
||||
{
|
||||
|
@ -0,0 +1,23 @@
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "stop",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"message": {
|
||||
"content": "{\"status\":\".OK.\"}",
|
||||
"role": "assistant"
|
||||
}
|
||||
}
|
||||
],
|
||||
"created": 1750877897,
|
||||
"id": "",
|
||||
"model": "google/gemma-3-4b-it",
|
||||
"object": "chat.completion",
|
||||
"system_fingerprint": "3.3.4-dev0-native",
|
||||
"usage": {
|
||||
"completion_tokens": 8,
|
||||
"prompt_tokens": 36,
|
||||
"total_tokens": 44
|
||||
}
|
||||
}
|
@ -0,0 +1,23 @@
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "stop",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"message": {
|
||||
"content": "{\"status\":\".OK.\"}",
|
||||
"role": "assistant"
|
||||
}
|
||||
}
|
||||
],
|
||||
"created": 1750877897,
|
||||
"id": "",
|
||||
"model": "google/gemma-3-4b-it",
|
||||
"object": "chat.completion",
|
||||
"system_fingerprint": "3.3.4-dev0-native",
|
||||
"usage": {
|
||||
"completion_tokens": 8,
|
||||
"prompt_tokens": 36,
|
||||
"total_tokens": 44
|
||||
}
|
||||
}
|
@ -207,3 +207,87 @@ async def test_json_schema_stream(model_fixture, response_snapshot):
|
||||
assert isinstance(parsed_content["numCats"], int)
|
||||
assert parsed_content["numCats"] >= 0
|
||||
assert chunks == response_snapshot
|
||||
|
||||
|
||||
status_schema = {
|
||||
"type": "object",
|
||||
"properties": {"status": {"type": "string"}},
|
||||
"required": ["status"],
|
||||
"additionalProperties": False,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_json_schema_simple_status(model_fixture, response_snapshot):
|
||||
"""Test simple status JSON schema - TGI format."""
|
||||
response = requests.post(
|
||||
f"{model_fixture.base_url}/v1/chat/completions",
|
||||
json={
|
||||
"model": "tgi",
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant. You answer with a JSON output with a status string containing the value 'OK'",
|
||||
},
|
||||
{"role": "user", "content": "Please tell me OK"},
|
||||
],
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
"response_format": {
|
||||
"type": "json_schema",
|
||||
"value": {"name": "test", "schema": status_schema},
|
||||
},
|
||||
"max_completion_tokens": 8192,
|
||||
},
|
||||
)
|
||||
|
||||
result = response.json()
|
||||
|
||||
# Validate response format
|
||||
content = result["choices"][0]["message"]["content"]
|
||||
parsed_content = json.loads(content)
|
||||
|
||||
assert "status" in parsed_content
|
||||
assert isinstance(parsed_content["status"], str)
|
||||
assert result == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_json_schema_openai_style_format(model_fixture, response_snapshot):
|
||||
"""Test OpenAI-style JSON schema format (should also work now)."""
|
||||
response = requests.post(
|
||||
f"{model_fixture.base_url}/v1/chat/completions",
|
||||
json={
|
||||
"model": "tgi",
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant. You answer with a JSON output with a status string containing the value 'OK'",
|
||||
},
|
||||
{"role": "user", "content": "Please tell me OK"},
|
||||
],
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
"response_format": {
|
||||
"json_schema": {
|
||||
"name": "test",
|
||||
"strict": True,
|
||||
"schema": status_schema,
|
||||
},
|
||||
"type": "json_schema",
|
||||
},
|
||||
"max_completion_tokens": 8192,
|
||||
},
|
||||
)
|
||||
|
||||
result = response.json()
|
||||
|
||||
# Validate response format
|
||||
content = result["choices"][0]["message"]["content"]
|
||||
parsed_content = json.loads(content)
|
||||
|
||||
assert "status" in parsed_content
|
||||
assert isinstance(parsed_content["status"], str)
|
||||
assert result == response_snapshot
|
||||
|
@ -224,18 +224,7 @@ impl HubProcessorConfig {
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, ToSchema, Serialize)]
|
||||
#[cfg_attr(test, derive(PartialEq))]
|
||||
struct JsonSchemaConfig {
|
||||
/// Optional name identifier for the schema
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
name: Option<String>,
|
||||
|
||||
/// The actual JSON schema definition
|
||||
schema: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, ToSchema, Serialize)]
|
||||
#[cfg_attr(test, derive(PartialEq))]
|
||||
#[serde(tag = "type", content = "value")]
|
||||
#[serde(tag = "type")]
|
||||
pub(crate) enum GrammarType {
|
||||
/// A string that represents a [JSON Schema](https://json-schema.org/).
|
||||
///
|
||||
@ -244,17 +233,36 @@ pub(crate) enum GrammarType {
|
||||
#[serde(rename = "json")]
|
||||
#[serde(alias = "json_object")]
|
||||
#[schema(example = json ! ({"properties": {"location":{"type": "string"}}}))]
|
||||
Json(serde_json::Value),
|
||||
Json { value: serde_json::Value },
|
||||
|
||||
#[serde(rename = "regex")]
|
||||
Regex(String),
|
||||
Regex { value: String },
|
||||
|
||||
/// A JSON Schema specification with additional metadata.
|
||||
///
|
||||
/// Includes an optional name for the schema, an optional strict flag, and the required schema definition.
|
||||
#[serde(rename = "json_schema")]
|
||||
#[schema(example = json ! ({"schema": {"properties": {"name": {"type": "string"}, "age": {"type": "integer"}}}, "name": "person_info", "strict": true}))]
|
||||
JsonSchema(JsonSchemaConfig),
|
||||
JsonSchema {
|
||||
#[serde(alias = "value")]
|
||||
#[serde(deserialize_with = "custom_json_schema::deserialize_json_schema")]
|
||||
json_schema: serde_json::Value,
|
||||
},
|
||||
}
|
||||
|
||||
mod custom_json_schema {
|
||||
use serde::{Deserialize, Deserializer};
|
||||
use serde_json::Value;
|
||||
|
||||
pub fn deserialize_json_schema<'de, D>(deserializer: D) -> Result<Value, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let value: Value = Deserialize::deserialize(deserializer)?;
|
||||
value
|
||||
.get("schema")
|
||||
.cloned()
|
||||
.ok_or_else(|| serde::de::Error::custom("Expected a 'schema' field"))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, ToSchema)]
|
||||
@ -984,7 +992,9 @@ impl ChatRequest {
|
||||
if let Some(tools) = tools {
|
||||
match ToolGrammar::apply(tools, tool_choice)? {
|
||||
Some((updated_tools, tool_schema)) => {
|
||||
let grammar = GrammarType::Json(serde_json::json!(tool_schema));
|
||||
let grammar = GrammarType::Json {
|
||||
value: serde_json::json!(tool_schema),
|
||||
};
|
||||
let inputs: String = infer.apply_chat_template(
|
||||
messages,
|
||||
Some((updated_tools, tool_prompt)),
|
||||
@ -1836,3 +1846,80 @@ mod tests {
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod grammar_tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn parse_regex() {
|
||||
let raw = json!({
|
||||
"type": "regex",
|
||||
"value": "^\\d+$"
|
||||
});
|
||||
let parsed: GrammarType = serde_json::from_value(raw).unwrap();
|
||||
|
||||
match parsed {
|
||||
GrammarType::Regex { value } => assert_eq!(value, "^\\d+$"),
|
||||
_ => panic!("Expected Regex variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_json_value() {
|
||||
let raw = json!({
|
||||
"type": "json",
|
||||
"value": { "enum": ["a", "b"] }
|
||||
});
|
||||
let parsed: GrammarType = serde_json::from_value(raw).unwrap();
|
||||
|
||||
match parsed {
|
||||
GrammarType::Json { value } => assert_eq!(value, json!({"enum":["a","b"]})),
|
||||
_ => panic!("Expected Json variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_json_schema() {
|
||||
let raw = json!({
|
||||
"type": "json_schema",
|
||||
"json_schema": { "schema": {"type":"integer"} }
|
||||
});
|
||||
let parsed: GrammarType = serde_json::from_value(raw).unwrap();
|
||||
|
||||
match parsed {
|
||||
GrammarType::JsonSchema { json_schema } => {
|
||||
assert_eq!(json_schema, json!({"type":"integer"}));
|
||||
}
|
||||
_ => panic!("Expected JsonSchema variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_regex_ip_address() {
|
||||
let raw = json!({
|
||||
"type": "regex",
|
||||
"value": "((25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)"
|
||||
});
|
||||
let parsed: GrammarType = serde_json::from_value(raw).unwrap();
|
||||
|
||||
match parsed {
|
||||
GrammarType::Regex { value } => {
|
||||
assert!(value.contains("25[0-5]"));
|
||||
}
|
||||
_ => panic!("Expected Regex variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_invalid_type_should_fail() {
|
||||
let raw = json!({
|
||||
"type": "invalid_type",
|
||||
"value": "test"
|
||||
});
|
||||
|
||||
let result: Result<GrammarType, _> = serde_json::from_value(raw);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
}
|
||||
|
@ -14,6 +14,7 @@ use crate::sagemaker::{
|
||||
};
|
||||
use crate::validation::ValidationError;
|
||||
use crate::vertex::vertex_compatibility;
|
||||
use crate::ChatTokenizeResponse;
|
||||
use crate::{
|
||||
usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName,
|
||||
GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
|
||||
@ -28,7 +29,6 @@ use crate::{
|
||||
ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal,
|
||||
CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool,
|
||||
};
|
||||
use crate::{ChatTokenizeResponse, JsonSchemaConfig};
|
||||
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice};
|
||||
use crate::{MessageBody, ModelInfo, ModelsInfo};
|
||||
use async_stream::__private::AsyncStream;
|
||||
@ -1362,7 +1362,6 @@ CompatGenerateRequest,
|
||||
SagemakerRequest,
|
||||
GenerateRequest,
|
||||
GrammarType,
|
||||
JsonSchemaConfig,
|
||||
ChatRequest,
|
||||
Message,
|
||||
MessageContent,
|
||||
|
@ -350,13 +350,13 @@ impl Validation {
|
||||
return Err(ValidationError::Grammar);
|
||||
}
|
||||
let valid_grammar = match grammar {
|
||||
GrammarType::Json(json) => {
|
||||
let json = match json {
|
||||
GrammarType::Json { value } => {
|
||||
let json = match value {
|
||||
// if value is a string, we need to parse it again to make sure its
|
||||
// a valid json
|
||||
Value::String(s) => serde_json::from_str(&s)
|
||||
.map_err(|e| ValidationError::InvalidGrammar(e.to_string())),
|
||||
Value::Object(_) => Ok(json),
|
||||
Value::Object(_) => Ok(value),
|
||||
_ => Err(ValidationError::Grammar),
|
||||
}?;
|
||||
|
||||
@ -380,29 +380,28 @@ impl Validation {
|
||||
|
||||
ValidGrammar::Regex(grammar_regex.to_string())
|
||||
}
|
||||
GrammarType::JsonSchema(schema_config) => {
|
||||
GrammarType::JsonSchema { json_schema } => {
|
||||
// Extract the actual schema for validation
|
||||
let json = &schema_config.schema;
|
||||
|
||||
// Check if the json is a valid JSONSchema
|
||||
jsonschema::draft202012::meta::validate(json)
|
||||
jsonschema::draft202012::meta::validate(&json_schema)
|
||||
.map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;
|
||||
|
||||
// The schema can be valid but lack properties.
|
||||
// We need properties for the grammar to be successfully parsed in Python.
|
||||
// Therefore, we must check and throw an error if properties are missing.
|
||||
json.get("properties")
|
||||
json_schema
|
||||
.get("properties")
|
||||
.ok_or(ValidationError::InvalidGrammar(
|
||||
"Grammar must have a 'properties' field".to_string(),
|
||||
))?;
|
||||
|
||||
// Do compilation in the router for performance
|
||||
let grammar_regex = json_schema_to_regex(json, None, json)
|
||||
let grammar_regex = json_schema_to_regex(&json_schema, None, &json_schema)
|
||||
.map_err(ValidationError::RegexFromSchema)?;
|
||||
|
||||
ValidGrammar::Regex(grammar_regex.to_string())
|
||||
}
|
||||
GrammarType::Regex(regex) => ValidGrammar::Regex(regex),
|
||||
GrammarType::Regex { value } => ValidGrammar::Regex(value),
|
||||
};
|
||||
Some(valid_grammar)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user