text-generation-inference/integration-tests/models/test_grammar_response_format_llama.py
drbh 12ea8d74c7
Pr 2982 ci branch (#3046)
* Add json_schema alias for GrammarType

* Add tests for all aliases

* fix: various linter adjustments

* fix: end-of-file-fixer lint

* fix: add test snapshots and avoid docs change

* fix: another end-of-file-fixer lint

* feat: support json_schema grammar constraining and add tests

* fix: bump openapi doc with new grammar option

* fix: adjust test payload

* fix: bump test snaps

---------

Co-authored-by: Alex Weston <alexw@alkymi.io>
2025-05-01 10:17:16 -04:00


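# Integration tests for grammar-constrained generation through the OpenAI-style
# `response_format` field. PR #3046 added a `json_schema` alias for the
# GrammarType, so all three of these shapes are expected to be accepted:
#
#   {"type": "json_object", "value": <schema>}
#   {"type": "json", "value": <schema>}  # short alias
#   {"type": "json_schema", "value": {"name": ..., "strict": True, "schema": <schema>}}
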
import pytest
import requests
from pydantic import BaseModel
from typing import List
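

# Launch TinyLlama-1.1B-Chat on a single shard with grammar support enabled,
# flash attention disabled, and prefill batches capped at 3000 tokens.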
@pytest.fixture(scope="module")
def llama_grammar_handle(launcher):
with launcher(
"TinyLlama/TinyLlama-1.1B-Chat-v1.0",
num_shard=1,
disable_grammar_support=False,
use_flash_attention=False,
max_batch_prefill_tokens=3000,
) as handle:
yield handle
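

# Wait up to 300 seconds for the server to report healthy, then return its client.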
@pytest.fixture(scope="module")
async def llama_grammar(llama_grammar_handle):
await llama_grammar_handle.health(300)
return llama_grammar_handle.client
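

# Walk through every accepted spelling of a JSON grammar in `response_format`:
# "json_object", its "json" alias, and the OpenAI-style "json_schema" payload.
# With a fixed seed, each variant must produce the identical constrained output.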
@pytest.mark.release
@pytest.mark.asyncio
async def test_grammar_response_format_llama_json(llama_grammar, response_snapshot):
class Weather(BaseModel):
unit: str
temperature: List[int]
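
    # Ask for a forecast and constrain the reply to the Weather schema.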
json_payload = {
"model": "tgi",
"messages": [
{
"role": "system",
"content": f"Respond to the users questions and answer them in the following format: {Weather.schema()}",
},
{
"role": "user",
"content": "What's the weather like the next 3 days in San Francisco, CA?",
},
],
"seed": 42,
"max_tokens": 500,
"response_format": {"type": "json_object", "value": Weather.schema()},
}
# send the request
response = requests.post(
f"{llama_grammar.base_url}/v1/chat/completions",
headers=llama_grammar.headers,
json=json_payload,
)
chat_completion = response.json()
called = chat_completion["choices"][0]["message"]["content"]
assert response.status_code == 200
assert called == '{ "unit": "fahrenheit", "temperature": [ 72, 79, 88 ] }'
assert chat_completion == response_snapshot
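
    # "json" is accepted as an alias of "json_object" and must behave identically.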
json_payload["response_format"]["type"] = "json"
response = requests.post(
f"{llama_grammar.base_url}/v1/chat/completions",
headers=llama_grammar.headers,
json=json_payload,
)
chat_completion = response.json()
called = chat_completion["choices"][0]["message"]["content"]
assert response.status_code == 200
assert called == '{ "unit": "fahrenheit", "temperature": [ 72, 79, 88 ] }'
assert chat_completion == response_snapshot
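
    # The OpenAI-compatible "json_schema" type wraps the schema in
    # {"name": ..., "strict": ..., "schema": ...}.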
json_payload["response_format"] = {
"type": "json_schema",
"value": {"name": "weather", "strict": True, "schema": Weather.schema()},
}
response = requests.post(
f"{llama_grammar.base_url}/v1/chat/completions",
headers=llama_grammar.headers,
json=json_payload,
)
chat_completion = response.json()
called = chat_completion["choices"][0]["message"]["content"]
assert response.status_code == 200
assert called == '{ "unit": "fahrenheit", "temperature": [ 72, 79, 88 ] }'
assert chat_completion == response_snapshot
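

# Grammars and tools are mutually exclusive: a request carrying both a
# `response_format` grammar and a `tools` list must be rejected with a 422.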
@pytest.mark.release
@pytest.mark.asyncio
async def test_grammar_response_format_llama_error_if_tools_not_installed(
llama_grammar,
):
class Weather(BaseModel):
unit: str
temperature: List[int]

    # Send a request that combines a grammar with a (deliberately empty) tools list.
response = requests.post(
f"{llama_grammar.base_url}/v1/chat/completions",
headers=llama_grammar.headers,
json={
"model": "tgi",
"messages": [
{
"role": "system",
"content": f"Respond to the users questions and answer them in the following format: {Weather.schema()}",
},
{
"role": "user",
"content": "What's the weather like the next 3 days in San Francisco, CA?",
},
],
"seed": 42,
"max_tokens": 500,
"tools": [],
"response_format": {"type": "json_object", "value": Weather.schema()},
},
)
    # 422 Unprocessable Entity: the payload is well-formed JSON, but grammars
    # and tools cannot be combined in one request.
assert response.status_code == 422
assert response.json() == {
"error": "Tool error: Grammar and tools are mutually exclusive",
"error_type": "tool_error",
}