Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 20:34:54 +00:00)
Create test_structured_output_response_format_llama.py
analogous to integration-tests/models/test_grammar_response_format_llama.py
parent a1803bb780
commit 9a18b75971
@@ -0,0 +1,102 @@
import pytest
import requests
from pydantic import BaseModel
from typing import List


@pytest.fixture(scope="module")
def llama_grammar_handle(launcher):
    with launcher(
        "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        num_shard=1,
        disable_grammar_support=False,
        use_flash_attention=False,
        max_batch_prefill_tokens=3000,
    ) as handle:
        yield handle


@pytest.fixture(scope="module")
async def llama_grammar(llama_grammar_handle):
    await llama_grammar_handle.health(300)
    return llama_grammar_handle.client


@pytest.mark.release
@pytest.mark.asyncio
async def test_structured_output_response_format_llama_json(llama_grammar, response_snapshot):
    class Weather(BaseModel):
        unit: str
        temperature: List[int]

    # send the request
    response = requests.post(
        f"{llama_grammar.base_url}/v1/chat/completions",
        headers=llama_grammar.headers,
        json={
            "model": "tgi",
            "messages": [
                {
                    "role": "system",
                    "content": f"Respond to the users questions and answer them in the following format: {Weather.schema()}",
                },
                {
                    "role": "user",
                    "content": "What's the weather like the next 3 days in San Francisco, CA?",
                },
            ],
            "seed": 42,
            "max_tokens": 500,
            "response_format": {"type": "json_schema", "value": Weather.schema()},
        },
    )

    chat_completion = response.json()
    called = chat_completion["choices"][0]["message"]["content"]

    assert response.status_code == 200
    assert (
        called
        == '{\n "temperature": [\n 35,\n 34,\n 36\n ],\n "unit": "°c"\n}'
    )
    assert chat_completion == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
async def test_structured_output_response_format_llama_error_if_tools_not_installed(
    llama_grammar,
):
    class Weather(BaseModel):
        unit: str
        temperature: List[int]

    # send the request
    response = requests.post(
        f"{llama_grammar.base_url}/v1/chat/completions",
        headers=llama_grammar.headers,
        json={
            "model": "tgi",
            "messages": [
                {
                    "role": "system",
                    "content": f"Respond to the users questions and answer them in the following format: {Weather.schema()}",
                },
                {
                    "role": "user",
                    "content": "What's the weather like the next 3 days in San Francisco, CA?",
                },
            ],
            "seed": 42,
            "max_tokens": 500,
            "tools": [],
            "response_format": {"type": "json_schema", "value": Weather.schema()},
        },
    )

    # 422 means the server was unable to process the request because it contains invalid data.
    assert response.status_code == 422
    assert response.json() == {
        "error": "Tool error: Grammar and tools are mutually exclusive",
        "error_type": "tool_error",
    }
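For reference, the request these tests send can be reproduced by hand against any running TGI server started with grammar support enabled. A minimal sketch, assuming a server listening at http://localhost:8080 (the base URL is an assumption for illustration; the tests above resolve it from the launcher fixture):

import requests
from pydantic import BaseModel
from typing import List

class Weather(BaseModel):
    unit: str
    temperature: List[int]

# Assumption: a TGI instance is already serving at this address;
# point it at your own deployment instead.
base_url = "http://localhost:8080"

response = requests.post(
    f"{base_url}/v1/chat/completions",
    json={
        "model": "tgi",
        "messages": [
            {
                "role": "user",
                "content": "What's the weather like the next 3 days in San Francisco, CA?",
            },
        ],
        "seed": 42,
        "max_tokens": 500,
        # Constrain generation to the Pydantic model's JSON schema.
        "response_format": {"type": "json_schema", "value": Weather.schema()},
    },
)
print(response.json()["choices"][0]["message"]["content"])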