import json

import pytest
import requests

# Per-request socket timeout (seconds) so a stalled server fails the test
# instead of hanging the whole run. Matches the health-check budget below.
_REQUEST_TIMEOUT_SECONDS = 300


@pytest.fixture(scope="module")
def model_handle(launcher):
    """Launch the model server once per module and yield its handle."""
    with launcher(
        "google/gemma-3-4b-it",
        num_shard=2,
        disable_grammar_support=False,
    ) as handle:
        yield handle


@pytest.fixture(scope="module")
async def model_fixture(model_handle):
    """Wait for the launched server to pass its health check, then return its client."""
    await model_handle.health(300)
    return model_handle.client


# Sample JSON Schema for testing
person_schema = {
    "type": "object",
    "$id": "https://example.com/person.schema.json",
    "$schema": "https://json-schema.org/draft/2020-12/schema",
    "title": "Person",
    "properties": {
        "firstName": {
            "type": "string",
            "description": "The person's first name.",
            "minLength": 4,
        },
        "lastName": {
            "type": "string",
            "description": "The person's last name.",
            "minLength": 4,
        },
        "hobby": {
            "description": "The person's hobby.",
            "type": "string",
            "minLength": 4,
        },
        "numCats": {
            "description": "The number of cats the person has.",
            "type": "integer",
            "minimum": 0,
        },
    },
    "required": ["firstName", "lastName", "hobby", "numCats"],
}

# More complex schema for testing nested objects and arrays
complex_schema = {
    "type": "object",
    "properties": {
        "name": {"type": "string"},
        "age": {"type": "integer", "minimum": 0},
        "address": {
            "type": "object",
            "properties": {
                "street": {"type": "string"},
                "city": {"type": "string"},
                "postalCode": {"type": "string"},
            },
            "required": ["street", "city"],
        },
        "hobbies": {"type": "array", "items": {"type": "string"}, "minItems": 1},
    },
    "required": ["name", "age", "hobbies"],
}


def _chat_payload(prompt, schema_name, schema, *, stream=False):
    """Build the chat-completions request body constrained to ``schema``.

    Args:
        prompt: Text of the single user message.
        schema_name: Value sent as ``response_format.value.name``.
        schema: JSON Schema dict sent as ``response_format.value.schema``.
        stream: When true, add ``"stream": True`` to the body.

    Returns:
        The request body as a dict, ready to pass as ``json=``.
    """
    payload = {
        "model": "tgi",
        "messages": [
            {
                "role": "user",
                "content": prompt,
            },
        ],
        "seed": 42,
        "temperature": 0.0,
        "response_format": {
            "type": "json_schema",
            "value": {"name": schema_name, "strict": True, "schema": schema},
        },
    }
    if stream:
        payload["stream"] = True
    return payload


def _assert_person_fields(parsed_content):
    """Assert that ``parsed_content`` carries the fields required by ``person_schema``."""
    for field in ("firstName", "lastName", "hobby", "numCats"):
        assert field in parsed_content
    assert isinstance(parsed_content["numCats"], int)
    assert parsed_content["numCats"] >= 0


def _collect_stream(lines):
    """Parse streamed response lines into chunks and the assembled content.

    Args:
        lines: Iterable of raw ``bytes`` lines, as produced by
            ``requests.Response.iter_lines()``.

    Returns:
        A ``(chunks, content)`` tuple: every decoded JSON chunk in arrival
        order, and the concatenation of each chunk's first-choice
        ``delta.content``.
    """
    chunks = []
    pieces = []
    for raw_line in lines:
        if not raw_line:
            # Blank lines carry no payload.
            continue
        data = raw_line.decode("utf-8")
        # Remove the "data: " prefix and handle the special case of "[DONE]".
        if data.startswith("data: "):
            data = data[len("data: "):]
        if data == "[DONE]":
            continue
        chunk = json.loads(data)
        chunks.append(chunk)
        if "choices" not in chunk or len(chunk["choices"]) == 0:
            continue
        first_choice = chunk["choices"][0]
        delta = first_choice.get("delta") or {}
        piece = delta.get("content")
        # A delta may carry a null/empty content; skip it rather than
        # raising TypeError on concatenation.
        if piece:
            pieces.append(piece)
    return chunks, "".join(pieces)


@pytest.mark.asyncio
@pytest.mark.private
async def test_json_schema_basic(model_fixture, response_snapshot):
    """Test basic JSON schema validation with the person schema."""
    response = requests.post(
        f"{model_fixture.base_url}/v1/chat/completions",
        json=_chat_payload(
            "David is a person who likes trees and nature. He enjoys studying math and science. He has 2 cats.",
            "person",
            person_schema,
        ),
        timeout=_REQUEST_TIMEOUT_SECONDS,
    )
    # Fail with the HTTP error itself rather than a KeyError further down.
    response.raise_for_status()
    result = response.json()

    # Validate response format
    content = result["choices"][0]["message"]["content"]
    _assert_person_fields(json.loads(content))

    assert result == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_json_schema_complex(model_fixture, response_snapshot):
    """Test complex JSON schema with nested objects and arrays."""
    response = requests.post(
        f"{model_fixture.base_url}/v1/chat/completions",
        json=_chat_payload(
            "John Smith is 30 years old. He lives on Maple Street in Boston. He enjoys botany, astronomy, and solving mathematical puzzles.",
            "complex_person",
            complex_schema,
        ),
        timeout=_REQUEST_TIMEOUT_SECONDS,
    )
    # Fail with the HTTP error itself rather than a KeyError further down.
    response.raise_for_status()
    result = response.json()

    # Validate response format
    content = result["choices"][0]["message"]["content"]
    parsed_content = json.loads(content)

    for field in ("name", "age", "hobbies", "address"):
        assert field in parsed_content
    assert "street" in parsed_content["address"]
    assert "city" in parsed_content["address"]
    assert isinstance(parsed_content["hobbies"], list)
    assert len(parsed_content["hobbies"]) >= 1

    assert result == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_json_schema_stream(model_fixture, response_snapshot):
    """Test JSON schema validation with streaming."""
    # The context manager releases the connection even if parsing fails.
    with requests.post(
        f"{model_fixture.base_url}/v1/chat/completions",
        json=_chat_payload(
            "David is a person who likes to ride bicycles. He has 2 cats.",
            "person",
            person_schema,
            stream=True,
        ),
        stream=True,
        timeout=_REQUEST_TIMEOUT_SECONDS,
    ) as response:
        response.raise_for_status()
        chunks, content_generated = _collect_stream(response.iter_lines())

    # Validate the final assembled JSON
    _assert_person_fields(json.loads(content_generated))

    assert chunks == response_snapshot