fix: refactor and simplify structs and openapi

This commit is contained in:
drbh 2025-07-07 17:53:34 +00:00
parent 5f70fbdc2a
commit 43fd3bd7f4
4 changed files with 112 additions and 123 deletions

View File

@ -1779,25 +1779,20 @@
}
},
{
"allOf": [
{
"$ref": "#/components/schemas/JsonSchemaFormat"
},
{
"type": "object",
"required": [
"type"
],
"properties": {
"type": {
"type": "string",
"enum": [
"json_schema"
]
}
}
"type": "object",
"required": [
"json_schema",
"type"
],
"properties": {
"json_schema": {},
"type": {
"type": "string",
"enum": [
"json_schema"
]
}
]
}
}
],
"discriminator": {
@ -1891,56 +1886,6 @@
}
}
},
"JsonSchemaConfig": {
"type": "object",
"required": [
"schema"
],
"properties": {
"name": {
"type": "string",
"description": "Optional name identifier for the schema",
"nullable": true
},
"schema": {
"description": "The actual JSON schema definition"
}
}
},
"JsonSchemaFormat": {
"oneOf": [
{
"type": "object",
"required": [
"json_schema"
],
"properties": {
"json_schema": {
"$ref": "#/components/schemas/JsonSchemaOrConfig"
}
}
},
{
"type": "object",
"required": [
"value"
],
"properties": {
"value": {
"$ref": "#/components/schemas/JsonSchemaOrConfig"
}
}
}
]
},
"JsonSchemaOrConfig": {
"oneOf": [
{
"$ref": "#/components/schemas/JsonSchemaConfig"
},
{}
]
},
"Message": {
"allOf": [
{

View File

@ -222,17 +222,6 @@ impl HubProcessorConfig {
}
}
#[derive(Clone, Debug, Deserialize, ToSchema, Serialize)]
#[cfg_attr(test, derive(PartialEq))]
pub struct JsonSchemaConfig {
/// Optional name identifier for the schema
#[serde(skip_serializing_if = "Option::is_none")]
name: Option<String>,
/// The actual JSON schema definition
schema: serde_json::Value,
}
#[derive(Clone, Debug, Deserialize, ToSchema, Serialize)]
#[cfg_attr(test, derive(PartialEq))]
#[serde(tag = "type")]
@ -253,43 +242,26 @@ pub(crate) enum GrammarType {
///
/// Includes an optional name for the schema, an optional strict flag, and the required schema definition.
#[serde(rename = "json_schema")]
JsonSchema(JsonSchemaFormat),
JsonSchema {
#[serde(alias = "value")]
#[serde(deserialize_with = "custom_json_schema::deserialize_json_schema")]
json_schema: serde_json::Value,
},
}
#[derive(Clone, Debug, Deserialize, ToSchema, Serialize)]
#[cfg_attr(test, derive(PartialEq))]
#[serde(untagged)]
pub enum JsonSchemaFormat {
JsonSchema { json_schema: JsonSchemaOrConfig },
Value { value: JsonSchemaOrConfig },
}
mod custom_json_schema {
use serde::{Deserialize, Deserializer};
use serde_json::Value;
#[derive(Clone, Debug, Deserialize, ToSchema, Serialize)]
#[cfg_attr(test, derive(PartialEq))]
#[serde(untagged)]
pub enum JsonSchemaOrConfig {
Config(JsonSchemaConfig),
Value(serde_json::Value),
}
impl JsonSchemaOrConfig {
pub fn schema_value(&self) -> &serde_json::Value {
match self {
JsonSchemaOrConfig::Config(config) => &config.schema,
JsonSchemaOrConfig::Value(value) => value,
}
}
}
impl JsonSchemaFormat {
pub fn schema_value(&self) -> &serde_json::Value {
let config = match self {
Self::JsonSchema { json_schema } | Self::Value { value: json_schema } => json_schema,
};
match config {
JsonSchemaOrConfig::Config(config) => &config.schema,
JsonSchemaOrConfig::Value(value) => value,
}
pub fn deserialize_json_schema<'de, D>(deserializer: D) -> Result<Value, D::Error>
where
D: Deserializer<'de>,
{
let value: Value = Deserialize::deserialize(deserializer)?;
value
.get("schema")
.cloned()
.ok_or_else(|| serde::de::Error::custom("Expected a 'schema' field"))
}
}
@ -1874,3 +1846,79 @@ mod tests {
);
}
}
#[cfg(test)]
mod grammar_tests {
use super::*;
use serde_json::json;
#[test]
fn parse_plain_schema() {
let raw = json!({
"type": "json_schema",
"format": "plain",
"value": { "type": "integer" }
});
let parsed: GrammarType = serde_json::from_value(raw).unwrap();
println!("Parsed: {:#?}", parsed);
// match parsed {
// GrammarType::JsonSchema(JsonSchemaPayload::Plain { schema }) => {
// assert_eq!(schema, &json!({"type":"integer"}));
// }
// _ => panic!("wrong variant"),
// }
assert!(false);
}
// #[test]
// fn parse_config_schema() {
// let raw = json!({
// "type": "json_schema",
// "format": "configated",
// "name": "User",
// "strict": false,
// "schema": { "type": "object" }
// });
// let parsed: GrammarType = serde_json::from_value(raw).unwrap();
// match parsed {
// GrammarType::JsonSchema(JsonSchemaPayload::Configated(cfg)) => {
// assert_eq!(cfg.name.as_deref(), Some("User"));
// assert!(!cfg.strict);
// assert_eq!(cfg.schema, json!({"type":"object"}));
// }
// _ => panic!("wrong variant"),
// }
// }
// #[test]
// fn parse_regex() {
// let raw = json!({
// "type": "regex",
// "value": "^\\d+$"
// });
// let parsed: GrammarType = serde_json::from_value(raw).unwrap();
// match parsed {
// GrammarType::Regex { value } => assert_eq!(value, "^\\d+$"),
// _ => panic!("wrong variant"),
// }
// }
// #[test]
// fn parse_json_value() {
// let raw = json!({
// "type": "json",
// "value": { "enum": ["a", "b"] }
// });
// let parsed: GrammarType = serde_json::from_value(raw).unwrap();
// match parsed {
// GrammarType::Json { value } => assert_eq!(value, json!({"enum":["a","b"]})),
// _ => panic!("wrong variant"),
// }
// }
}

View File

@ -14,6 +14,7 @@ use crate::sagemaker::{
};
use crate::validation::ValidationError;
use crate::vertex::vertex_compatibility;
use crate::ChatTokenizeResponse;
use crate::{
usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName,
GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
@ -28,7 +29,6 @@ use crate::{
ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal,
CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool,
};
use crate::{ChatTokenizeResponse, JsonSchemaConfig, JsonSchemaFormat, JsonSchemaOrConfig};
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice};
use crate::{MessageBody, ModelInfo, ModelsInfo};
use async_stream::__private::AsyncStream;
@ -1362,9 +1362,6 @@ CompatGenerateRequest,
SagemakerRequest,
GenerateRequest,
GrammarType,
JsonSchemaConfig,
JsonSchemaOrConfig,
JsonSchemaFormat,
ChatRequest,
Message,
MessageContent,

View File

@ -380,24 +380,23 @@ impl Validation {
ValidGrammar::Regex(grammar_regex.to_string())
}
GrammarType::JsonSchema(json_schema) => {
GrammarType::JsonSchema { json_schema } => {
// Extract the actual schema for validation
let json = json_schema.schema_value();
// Check if the json is a valid JSONSchema
jsonschema::draft202012::meta::validate(json)
jsonschema::draft202012::meta::validate(&json_schema)
.map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;
// The schema can be valid but lack properties.
// We need properties for the grammar to be successfully parsed in Python.
// Therefore, we must check and throw an error if properties are missing.
json.get("properties")
json_schema
.get("properties")
.ok_or(ValidationError::InvalidGrammar(
"Grammar must have a 'properties' field".to_string(),
))?;
// Do compilation in the router for performance
let grammar_regex = json_schema_to_regex(json, None, json)
let grammar_regex = json_schema_to_regex(&json_schema, None, &json_schema)
.map_err(ValidationError::RegexFromSchema)?;
ValidGrammar::Regex(grammar_regex.to_string())