import pytest
import requests
from pydantic import BaseModel
from typing import List
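
# Integration tests for grammar-constrained generation through the OpenAI-style
# `response_format` field. The `launcher` and `response_snapshot` fixtures are
# assumed to be provided by the test suite's shared conftest.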


@pytest.fixture(scope="module")
def llama_grammar_handle(launcher):
    # Launch the model with grammar support enabled; the handle is shared across the module.
    with launcher(
        "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        num_shard=1,
        disable_grammar_support=False,
        use_flash_attention=False,
        max_batch_prefill_tokens=3000,
    ) as handle:
        yield handle


@pytest.fixture(scope="module")
async def llama_grammar(llama_grammar_handle):
    # Wait for the server to pass its health check before exposing the client.
    await llama_grammar_handle.health(300)
    return llama_grammar_handle.client


@pytest.mark.release
@pytest.mark.asyncio
async def test_grammar_response_format_llama_json(llama_grammar, response_snapshot):

    class Weather(BaseModel):
        unit: str
        temperature: List[int]

    # send the request
    response = requests.post(
        f"{llama_grammar.base_url}/v1/chat/completions",
        headers=llama_grammar.headers,
        json={
            "model": "tgi",
            "messages": [
                {
                    "role": "system",
                    "content": f"Respond to the users questions and answer them in the following format: {Weather.schema()}",
                },
                {
                    "role": "user",
                    "content": "What's the weather like the next 3 days in San Francisco, CA?",
                },
            ],
            "seed": 42,
            "max_tokens": 500,
            "response_format": {"type": "json_object", "value": Weather.schema()},
        },
    )

    chat_completion = response.json()
    called = chat_completion["choices"][0]["message"]["content"]

    assert response.status_code == 200
    assert (
        called
        == '{\n "temperature": [\n 35,\n 34,\n 36\n ],\n "unit": "°c"\n}'
    )
    assert chat_completion == response_snapshot
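
    # Note: the exact string comparison above pins the generated text; a looser,
    # schema-level check such as `Weather.parse_raw(called)` (pydantic v1) would
    # also confirm that the grammar constrained the output to the Weather schema.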


@pytest.mark.release
@pytest.mark.asyncio
async def test_grammar_response_format_llama_error_if_tools_not_installed(
    llama_grammar,
):
    class Weather(BaseModel):
        unit: str
        temperature: List[int]

    # send the request
    response = requests.post(
        f"{llama_grammar.base_url}/v1/chat/completions",
        headers=llama_grammar.headers,
        json={
            "model": "tgi",
            "messages": [
                {
                    "role": "system",
                    "content": f"Respond to the users questions and answer them in the following format: {Weather.schema()}",
                },
                {
                    "role": "user",
                    "content": "What's the weather like the next 3 days in San Francisco, CA?",
                },
            ],
            "seed": 42,
            "max_tokens": 500,
            "tools": [],
            "response_format": {"type": "json_object", "value": Weather.schema()},
        },
    )
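
    # Supplying both `tools` and a grammar-style `response_format` is rejected:
    # per the error below, the server treats them as mutually exclusive ways to
    # constrain decoding.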
    # 422 means the server was unable to process the request because it contains invalid data.
    assert response.status_code == 422
    assert response.json() == {
        "error": "Tool error: Grammar and tools are mutually exclusive",
        "error_type": "tool_error",
    }