mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-08 19:04:52 +00:00
fix: refactor and simplify structs and openapi
This commit is contained in:
parent
5f70fbdc2a
commit
43fd3bd7f4
@ -1779,25 +1779,20 @@
|
||||
}
|
||||
},
|
||||
{
|
||||
"allOf": [
|
||||
{
|
||||
"$ref": "#/components/schemas/JsonSchemaFormat"
|
||||
},
|
||||
{
|
||||
"type": "object",
|
||||
"required": [
|
||||
"type"
|
||||
],
|
||||
"properties": {
|
||||
"type": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"json_schema"
|
||||
]
|
||||
}
|
||||
}
|
||||
"type": "object",
|
||||
"required": [
|
||||
"json_schema",
|
||||
"type"
|
||||
],
|
||||
"properties": {
|
||||
"json_schema": {},
|
||||
"type": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"json_schema"
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
],
|
||||
"discriminator": {
|
||||
@ -1891,56 +1886,6 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"JsonSchemaConfig": {
|
||||
"type": "object",
|
||||
"required": [
|
||||
"schema"
|
||||
],
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "Optional name identifier for the schema",
|
||||
"nullable": true
|
||||
},
|
||||
"schema": {
|
||||
"description": "The actual JSON schema definition"
|
||||
}
|
||||
}
|
||||
},
|
||||
"JsonSchemaFormat": {
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "object",
|
||||
"required": [
|
||||
"json_schema"
|
||||
],
|
||||
"properties": {
|
||||
"json_schema": {
|
||||
"$ref": "#/components/schemas/JsonSchemaOrConfig"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "object",
|
||||
"required": [
|
||||
"value"
|
||||
],
|
||||
"properties": {
|
||||
"value": {
|
||||
"$ref": "#/components/schemas/JsonSchemaOrConfig"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
"JsonSchemaOrConfig": {
|
||||
"oneOf": [
|
||||
{
|
||||
"$ref": "#/components/schemas/JsonSchemaConfig"
|
||||
},
|
||||
{}
|
||||
]
|
||||
},
|
||||
"Message": {
|
||||
"allOf": [
|
||||
{
|
||||
|
@ -222,17 +222,6 @@ impl HubProcessorConfig {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, ToSchema, Serialize)]
|
||||
#[cfg_attr(test, derive(PartialEq))]
|
||||
pub struct JsonSchemaConfig {
|
||||
/// Optional name identifier for the schema
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
name: Option<String>,
|
||||
|
||||
/// The actual JSON schema definition
|
||||
schema: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, ToSchema, Serialize)]
|
||||
#[cfg_attr(test, derive(PartialEq))]
|
||||
#[serde(tag = "type")]
|
||||
@ -253,43 +242,26 @@ pub(crate) enum GrammarType {
|
||||
///
|
||||
/// Includes an optional name for the schema, an optional strict flag, and the required schema definition.
|
||||
#[serde(rename = "json_schema")]
|
||||
JsonSchema(JsonSchemaFormat),
|
||||
JsonSchema {
|
||||
#[serde(alias = "value")]
|
||||
#[serde(deserialize_with = "custom_json_schema::deserialize_json_schema")]
|
||||
json_schema: serde_json::Value,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, ToSchema, Serialize)]
|
||||
#[cfg_attr(test, derive(PartialEq))]
|
||||
#[serde(untagged)]
|
||||
pub enum JsonSchemaFormat {
|
||||
JsonSchema { json_schema: JsonSchemaOrConfig },
|
||||
Value { value: JsonSchemaOrConfig },
|
||||
}
|
||||
mod custom_json_schema {
|
||||
use serde::{Deserialize, Deserializer};
|
||||
use serde_json::Value;
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, ToSchema, Serialize)]
|
||||
#[cfg_attr(test, derive(PartialEq))]
|
||||
#[serde(untagged)]
|
||||
pub enum JsonSchemaOrConfig {
|
||||
Config(JsonSchemaConfig),
|
||||
Value(serde_json::Value),
|
||||
}
|
||||
|
||||
impl JsonSchemaOrConfig {
|
||||
pub fn schema_value(&self) -> &serde_json::Value {
|
||||
match self {
|
||||
JsonSchemaOrConfig::Config(config) => &config.schema,
|
||||
JsonSchemaOrConfig::Value(value) => value,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl JsonSchemaFormat {
|
||||
pub fn schema_value(&self) -> &serde_json::Value {
|
||||
let config = match self {
|
||||
Self::JsonSchema { json_schema } | Self::Value { value: json_schema } => json_schema,
|
||||
};
|
||||
match config {
|
||||
JsonSchemaOrConfig::Config(config) => &config.schema,
|
||||
JsonSchemaOrConfig::Value(value) => value,
|
||||
}
|
||||
pub fn deserialize_json_schema<'de, D>(deserializer: D) -> Result<Value, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let value: Value = Deserialize::deserialize(deserializer)?;
|
||||
value
|
||||
.get("schema")
|
||||
.cloned()
|
||||
.ok_or_else(|| serde::de::Error::custom("Expected a 'schema' field"))
|
||||
}
|
||||
}
|
||||
|
||||
@ -1874,3 +1846,79 @@ mod tests {
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod grammar_tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn parse_plain_schema() {
|
||||
let raw = json!({
|
||||
"type": "json_schema",
|
||||
"format": "plain",
|
||||
"value": { "type": "integer" }
|
||||
});
|
||||
let parsed: GrammarType = serde_json::from_value(raw).unwrap();
|
||||
|
||||
println!("Parsed: {:#?}", parsed);
|
||||
|
||||
// match parsed {
|
||||
// GrammarType::JsonSchema(JsonSchemaPayload::Plain { schema }) => {
|
||||
// assert_eq!(schema, &json!({"type":"integer"}));
|
||||
// }
|
||||
// _ => panic!("wrong variant"),
|
||||
// }
|
||||
|
||||
assert!(false);
|
||||
}
|
||||
|
||||
// #[test]
|
||||
// fn parse_config_schema() {
|
||||
// let raw = json!({
|
||||
// "type": "json_schema",
|
||||
// "format": "configated",
|
||||
// "name": "User",
|
||||
// "strict": false,
|
||||
// "schema": { "type": "object" }
|
||||
// });
|
||||
// let parsed: GrammarType = serde_json::from_value(raw).unwrap();
|
||||
|
||||
// match parsed {
|
||||
// GrammarType::JsonSchema(JsonSchemaPayload::Configated(cfg)) => {
|
||||
// assert_eq!(cfg.name.as_deref(), Some("User"));
|
||||
// assert!(!cfg.strict);
|
||||
// assert_eq!(cfg.schema, json!({"type":"object"}));
|
||||
// }
|
||||
// _ => panic!("wrong variant"),
|
||||
// }
|
||||
// }
|
||||
|
||||
// #[test]
|
||||
// fn parse_regex() {
|
||||
// let raw = json!({
|
||||
// "type": "regex",
|
||||
// "value": "^\\d+$"
|
||||
// });
|
||||
// let parsed: GrammarType = serde_json::from_value(raw).unwrap();
|
||||
|
||||
// match parsed {
|
||||
// GrammarType::Regex { value } => assert_eq!(value, "^\\d+$"),
|
||||
// _ => panic!("wrong variant"),
|
||||
// }
|
||||
// }
|
||||
|
||||
// #[test]
|
||||
// fn parse_json_value() {
|
||||
// let raw = json!({
|
||||
// "type": "json",
|
||||
// "value": { "enum": ["a", "b"] }
|
||||
// });
|
||||
// let parsed: GrammarType = serde_json::from_value(raw).unwrap();
|
||||
|
||||
// match parsed {
|
||||
// GrammarType::Json { value } => assert_eq!(value, json!({"enum":["a","b"]})),
|
||||
// _ => panic!("wrong variant"),
|
||||
// }
|
||||
// }
|
||||
}
|
||||
|
@ -14,6 +14,7 @@ use crate::sagemaker::{
|
||||
};
|
||||
use crate::validation::ValidationError;
|
||||
use crate::vertex::vertex_compatibility;
|
||||
use crate::ChatTokenizeResponse;
|
||||
use crate::{
|
||||
usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName,
|
||||
GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
|
||||
@ -28,7 +29,6 @@ use crate::{
|
||||
ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal,
|
||||
CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool,
|
||||
};
|
||||
use crate::{ChatTokenizeResponse, JsonSchemaConfig, JsonSchemaFormat, JsonSchemaOrConfig};
|
||||
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice};
|
||||
use crate::{MessageBody, ModelInfo, ModelsInfo};
|
||||
use async_stream::__private::AsyncStream;
|
||||
@ -1362,9 +1362,6 @@ CompatGenerateRequest,
|
||||
SagemakerRequest,
|
||||
GenerateRequest,
|
||||
GrammarType,
|
||||
JsonSchemaConfig,
|
||||
JsonSchemaOrConfig,
|
||||
JsonSchemaFormat,
|
||||
ChatRequest,
|
||||
Message,
|
||||
MessageContent,
|
||||
|
@ -380,24 +380,23 @@ impl Validation {
|
||||
|
||||
ValidGrammar::Regex(grammar_regex.to_string())
|
||||
}
|
||||
GrammarType::JsonSchema(json_schema) => {
|
||||
GrammarType::JsonSchema { json_schema } => {
|
||||
// Extract the actual schema for validation
|
||||
let json = json_schema.schema_value();
|
||||
|
||||
// Check if the json is a valid JSONSchema
|
||||
jsonschema::draft202012::meta::validate(json)
|
||||
jsonschema::draft202012::meta::validate(&json_schema)
|
||||
.map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;
|
||||
|
||||
// The schema can be valid but lack properties.
|
||||
// We need properties for the grammar to be successfully parsed in Python.
|
||||
// Therefore, we must check and throw an error if properties are missing.
|
||||
json.get("properties")
|
||||
json_schema
|
||||
.get("properties")
|
||||
.ok_or(ValidationError::InvalidGrammar(
|
||||
"Grammar must have a 'properties' field".to_string(),
|
||||
))?;
|
||||
|
||||
// Do compilation in the router for performance
|
||||
let grammar_regex = json_schema_to_regex(json, None, json)
|
||||
let grammar_regex = json_schema_to_regex(&json_schema, None, &json_schema)
|
||||
.map_err(ValidationError::RegexFromSchema)?;
|
||||
|
||||
ValidGrammar::Regex(grammar_regex.to_string())
|
||||
|
Loading…
Reference in New Issue
Block a user