Add tests for all aliases

This commit is contained in:
Alex Weston 2025-01-30 14:11:05 -05:00
parent 67a696fad9
commit b1a9dfff21

View File

@ -29,26 +29,55 @@ async def test_grammar_response_format_llama_json(llama_grammar, response_snapsh
unit: str unit: str
temperature: List[int] temperature: List[int]
json_payload={
"model": "tgi",
"messages": [
{
"role": "system",
"content": f"Respond to the users questions and answer them in the following format: {Weather.schema()}",
},
{
"role": "user",
"content": "What's the weather like the next 3 days in San Francisco, CA?",
},
],
"seed": 42,
"max_tokens": 500,
"response_format": {"type": "json_object", "value": Weather.schema()},
}
# send the request # send the request
response = requests.post( response = requests.post(
f"{llama_grammar.base_url}/v1/chat/completions", f"{llama_grammar.base_url}/v1/chat/completions",
headers=llama_grammar.headers, headers=llama_grammar.headers,
json={ json=json_payload,
"model": "tgi", )
"messages": [
{ chat_completion = response.json()
"role": "system", called = chat_completion["choices"][0]["message"]["content"]
"content": f"Respond to the users questions and answer them in the following format: {Weather.schema()}",
}, assert response.status_code == 200
{ assert called == '{ "unit": "fahrenheit", "temperature": [ 72, 79, 88 ] }'
"role": "user", assert chat_completion == response_snapshot
"content": "What's the weather like the next 3 days in San Francisco, CA?",
}, json_payload["response_format"]["type"] = "json"
], response = requests.post(
"seed": 42, f"{llama_grammar.base_url}/v1/chat/completions",
"max_tokens": 500, headers=llama_grammar.headers,
"response_format": {"type": "json_object", "value": Weather.schema()}, json=json_payload,
}, )
chat_completion = response.json()
called = chat_completion["choices"][0]["message"]["content"]
assert response.status_code == 200
assert called == '{ "unit": "fahrenheit", "temperature": [ 72, 79, 88 ] }'
assert chat_completion == response_snapshot
json_payload["response_format"]["type"] = "json_schema"
response = requests.post(
f"{llama_grammar.base_url}/v1/chat/completions",
headers=llama_grammar.headers,
json=json_payload,
) )
chat_completion = response.json() chat_completion = response.json()