# Mirror of https://github.com/huggingface/text-generation-inference.git
# (synced 2025-04-21 14:52:20 +00:00) — 210 lines, 6.5 KiB, Python.
import pytest
|
|
import json
|
|
import requests
|
|
|
|
|
|
@pytest.fixture(scope="module")
def model_handle(launcher):
    """Launch ``google/gemma-3-4b-it`` once per test module and yield its handle.

    The server is started through the ``launcher`` fixture with two shards and
    with grammar support explicitly left enabled (``disable_grammar_support=False``),
    since the tests in this module send ``response_format`` JSON-schema requests.
    The handle is yielded from inside the ``with`` block, so the launcher's context
    manager stays open for the whole module and exits after the last test.
    """
    launch_options = {"num_shard": 2, "disable_grammar_support": False}
    with launcher("google/gemma-3-4b-it", **launch_options) as handle:
        yield handle
|
|
|
|
|
|
@pytest.fixture(scope="module")
async def model_fixture(model_handle):
    """Wait for the launched server to report healthy, then return its client.

    Module-scoped like ``model_handle``. Awaits ``model_handle.health(300)``
    (the 300 is presumably a timeout in seconds — confirm against the launcher)
    before handing out ``model_handle.client``.
    """
    await model_handle.health(300)
    client = model_handle.client
    return client
|
|
|
|
|
|
# Sample JSON Schema for testing: a flat "Person" object whose four properties
# are all required. The schema is sent verbatim in request bodies whose
# responses are compared against snapshots, so dict insertion order (including
# the per-property key order, which differs for "hobby") is kept as-is —
# presumably the constrained decoder emits fields in this order; verify
# before reordering.
_person_properties = {
    "firstName": {
        "type": "string",
        "description": "The person's first name.",
        "minLength": 4,
    },
    "lastName": {
        "type": "string",
        "description": "The person's last name.",
        "minLength": 4,
    },
    "hobby": {
        "description": "The person's hobby.",
        "type": "string",
        "minLength": 4,
    },
    "numCats": {
        "description": "The number of cats the person has.",
        "type": "integer",
        "minimum": 0,
    },
}

person_schema = {
    "type": "object",
    "$id": "https://example.com/person.schema.json",
    "$schema": "https://json-schema.org/draft/2020-12/schema",
    "title": "Person",
    "properties": _person_properties,
    # Every declared property is required, in declaration order.
    "required": list(_person_properties),
}
|
|
|
|
# More complex schema for testing nested objects and arrays: a person with a
# nested ``address`` object (street and city required) and a non-empty
# ``hobbies`` string array. Only name, age and hobbies are required at the top
# level. Insertion order matches the request body used by the snapshot tests.
_address_schema = {
    "type": "object",
    "properties": {
        "street": {"type": "string"},
        "city": {"type": "string"},
        "postalCode": {"type": "string"},
    },
    "required": ["street", "city"],
}

complex_schema = {
    "type": "object",
    "properties": {
        "name": {"type": "string"},
        "age": {"type": "integer", "minimum": 0},
        "address": _address_schema,
        "hobbies": {"type": "array", "items": {"type": "string"}, "minItems": 1},
    },
    "required": ["name", "age", "hobbies"],
}
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.private
async def test_json_schema_basic(model_fixture, response_snapshot):
    """Test basic JSON schema validation with the person schema.

    Sends a non-streaming chat completion constrained by ``person_schema``.
    Checks that the returned message content:
    - parses as JSON and contains every required field;
    - has ``numCats`` as a non-negative integer.
    Also checks that the whole response body matches the stored snapshot.
    """
    response = requests.post(
        f"{model_fixture.base_url}/v1/chat/completions",
        json={
            "model": "tgi",
            "messages": [
                {
                    "role": "user",
                    "content": "David is a person who likes trees and nature. He enjoys studying math and science. He has 2 cats.",
                },
            ],
            "seed": 42,
            "temperature": 0.0,
            "response_format": {
                "type": "json_schema",
                "value": {"name": "person", "strict": True, "schema": person_schema},
            },
        },
        # Generous bound so a wedged server fails the test instead of hanging
        # the whole run (requests has no default timeout).
        timeout=300,
    )
    # Surface the server's error body instead of an opaque KeyError below.
    assert response.status_code == 200, response.text

    result = response.json()

    # Validate response format
    content = result["choices"][0]["message"]["content"]
    parsed_content = json.loads(content)

    assert "firstName" in parsed_content
    assert "lastName" in parsed_content
    assert "hobby" in parsed_content
    assert "numCats" in parsed_content
    assert isinstance(parsed_content["numCats"], int)
    assert parsed_content["numCats"] >= 0
    assert result == response_snapshot
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.private
async def test_json_schema_complex(model_fixture, response_snapshot):
    """Test complex JSON schema with nested objects and arrays.

    Sends a non-streaming chat completion constrained by ``complex_schema``.
    Checks the nested ``address`` object and the non-empty ``hobbies`` list,
    then compares the full response body against the stored snapshot.
    """
    response = requests.post(
        f"{model_fixture.base_url}/v1/chat/completions",
        json={
            "model": "tgi",
            "messages": [
                {
                    "role": "user",
                    "content": "John Smith is 30 years old. He lives on Maple Street in Boston. He enjoys botany, astronomy, and solving mathematical puzzles.",
                },
            ],
            "seed": 42,
            "temperature": 0.0,
            "response_format": {
                "type": "json_schema",
                "value": {
                    "name": "complex_person",
                    "strict": True,
                    "schema": complex_schema,
                },
            },
        },
        # Generous bound so a wedged server fails the test instead of hanging
        # the whole run (requests has no default timeout).
        timeout=300,
    )
    # Surface the server's error body instead of an opaque KeyError below.
    assert response.status_code == 200, response.text

    result = response.json()

    # Validate response format
    content = result["choices"][0]["message"]["content"]
    parsed_content = json.loads(content)

    assert "name" in parsed_content
    assert "age" in parsed_content
    assert "hobbies" in parsed_content
    # NOTE: "address" is optional in complex_schema (not in its "required"
    # list). This check relies on the model choosing to emit it because the
    # prompt mentions a street and city; the grammar alone does not force it.
    assert "address" in parsed_content
    assert "street" in parsed_content["address"]
    assert "city" in parsed_content["address"]
    assert isinstance(parsed_content["hobbies"], list)
    assert len(parsed_content["hobbies"]) >= 1
    assert result == response_snapshot
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.private
async def test_json_schema_stream(model_fixture, response_snapshot):
    """Test JSON schema validation with streaming.

    Requests a streamed (server-sent events) chat completion constrained by
    ``person_schema``. Every JSON ``data:`` payload is collected for the
    snapshot, and the ``delta.content`` pieces are concatenated. The assembled
    text must parse as JSON that satisfies the same field checks as the
    non-streaming test.
    """
    chunks = []
    content_generated = ""

    # ``stream=True`` defers reading the body; the context manager guarantees
    # the connection is released even if an assertion or decode error fires
    # mid-stream.
    with requests.post(
        f"{model_fixture.base_url}/v1/chat/completions",
        json={
            "model": "tgi",
            "messages": [
                {
                    "role": "user",
                    "content": "David is a person who likes to ride bicycles. He has 2 cats.",
                },
            ],
            "seed": 42,
            "temperature": 0.0,
            "response_format": {
                "type": "json_schema",
                "value": {"name": "person", "strict": True, "schema": person_schema},
            },
            "stream": True,
        },
        stream=True,
        # Bounds connect time and the gap between reads, so a stalled stream
        # fails instead of hanging the run.
        timeout=300,
    ) as response:
        # The message is only evaluated on failure, so a successful stream is
        # not consumed here.
        assert response.status_code == 200, response.text

        for raw_line in response.iter_lines():
            if not raw_line:
                continue  # blank lines separate SSE events
            line = raw_line.decode("utf-8")
            # Only "data:" fields carry payloads. Skip SSE comments (lines
            # starting with ":", e.g. keep-alive pings) and any other fields
            # rather than feeding them to json.loads.
            if not line.startswith("data:"):
                continue
            data = line[len("data:"):]
            if data.startswith(" "):
                data = data[1:]  # SSE allows a single optional space after ':'
            if data == "[DONE]":
                break
            chunk = json.loads(data)
            chunks.append(chunk)
            choices = chunk.get("choices") or []
            if choices:
                delta = choices[0].get("delta") or {}
                # Guard against a null/absent content field in a delta.
                content_generated += delta.get("content") or ""

    # Validate the final assembled JSON
    parsed_content = json.loads(content_generated)
    assert "firstName" in parsed_content
    assert "lastName" in parsed_content
    assert "hobby" in parsed_content
    assert "numCats" in parsed_content
    assert isinstance(parsed_content["numCats"], int)
    assert parsed_content["numCats"] >= 0
    assert chunks == response_snapshot
|