mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-19 22:02:06 +00:00
* Move JSON grammar -> regex grammar conversion to the router This change moves the JSON grammar -> regex grammar conversion to the router by adding a dependency on the `outlines-core` Rust crate. In contrast to the Python implementation, the conversions are not LRU-cached since they seem to be fast enough: simple schema time: [5.8293 µs 5.8307 µs 5.8320 µs] change: [-13.166% -12.884% -12.641%] (p = 0.00 < 0.05) Performance has improved. complex schema time: [14.875 µs 14.881 µs 14.887 µs] change: [-2.1637% -1.9914% -1.7852%] (p = 0.00 < 0.05) Performance has improved. Using the schemas from: https://github.com/dottxt-ai/outlines-core/blob/main/benchmarks/bench_json_schema.py
100 lines
3.0 KiB
Python
import pytest
import requests
from pydantic import BaseModel
from typing import List
@pytest.fixture(scope="module")
def llama_grammar_handle(launcher):
    """Module-scoped fixture: launch TinyLlama with grammar support enabled and
    yield the running server handle for the duration of the test module."""
    launcher_kwargs = {
        "num_shard": 1,
        "disable_grammar_support": False,
        "use_flash_attention": False,
        "max_batch_prefill_tokens": 3000,
    }
    with launcher("TinyLlama/TinyLlama-1.1B-Chat-v1.0", **launcher_kwargs) as handle:
        yield handle
@pytest.fixture(scope="module")
async def llama_grammar(llama_grammar_handle):
    """Wait (up to 300 s) for the launched server to report healthy, then
    return its client object for use in the tests."""
    handle = llama_grammar_handle
    await handle.health(300)
    return handle.client
@pytest.mark.release
@pytest.mark.asyncio
async def test_grammar_response_format_llama_json(llama_grammar, response_snapshot):
    """A chat completion with a JSON-schema `response_format` must return exactly
    the grammar-constrained JSON object and match the stored snapshot."""

    class Weather(BaseModel):
        # Expected shape of the model's constrained output.
        unit: str
        temperature: List[int]

    # send the request
    response = requests.post(
        f"{llama_grammar.base_url}/v1/chat/completions",
        headers=llama_grammar.headers,
        json={
            "model": "tgi",
            "messages": [
                {
                    "role": "system",
                    "content": f"Respond to the users questions and answer them in the following format: {Weather.schema()}",
                },
                {
                    "role": "user",
                    "content": "What's the weather like the next 3 days in San Francisco, CA?",
                },
            ],
            "seed": 42,
            "max_tokens": 500,
            "response_format": {"type": "json_object", "value": Weather.schema()},
        },
    )

    # Fix: check the status code BEFORE parsing the body, so an HTTP error
    # surfaces as a clear status assertion instead of an opaque JSON-decode
    # or KeyError failure.
    assert response.status_code == 200

    chat_completion = response.json()
    called = chat_completion["choices"][0]["message"]["content"]

    assert called == '{ "unit": "fahrenheit", "temperature": [ 72, 79, 88 ] }'
    assert chat_completion == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
async def test_grammar_response_format_llama_error_if_tools_not_installed(
    llama_grammar,
):
    """Supplying both `tools` and a grammar `response_format` must be rejected
    with HTTP 422 and a `tool_error` payload: the two are mutually exclusive."""

    class Weather(BaseModel):
        unit: str
        temperature: List[int]

    schema = Weather.schema()
    payload = {
        "model": "tgi",
        "messages": [
            {
                "role": "system",
                "content": f"Respond to the users questions and answer them in the following format: {schema}",
            },
            {
                "role": "user",
                "content": "What's the weather like the next 3 days in San Francisco, CA?",
            },
        ],
        "seed": 42,
        "max_tokens": 500,
        "tools": [],
        "response_format": {"type": "json_object", "value": schema},
    }

    # send the request
    response = requests.post(
        f"{llama_grammar.base_url}/v1/chat/completions",
        headers=llama_grammar.headers,
        json=payload,
    )

    # 422 means the server was unable to process the request because it contains invalid data.
    assert response.status_code == 422
    assert response.json() == {
        "error": "Tool error: Grammar and tools are mutually exclusive",
        "error_type": "tool_error",
    }