From 67a696fad9a7ce292429d53ae68cf706cef3cd8c Mon Sep 17 00:00:00 2001 From: Alex Weston Date: Thu, 30 Jan 2025 14:03:54 -0500 Subject: [PATCH 1/2] Add json_schema alias for GrammarType --- router/src/lib.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/router/src/lib.rs b/router/src/lib.rs index 414d38ed6..1e5ff1701 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -209,7 +209,8 @@ pub(crate) enum GrammarType { /// /// JSON Schema is a declarative language that allows to annotate JSON documents /// with types and descriptions. - #[serde(rename = "json")] + #[serde(rename = "json_schema")] + #[serde(alias = "json")] #[serde(alias = "json_object")] #[schema(example = json ! ({"properties": {"location":{"type": "string"}}}))] Json(serde_json::Value), From b1a9dfff216aa0d8873d9b9778df3ae91c479b4f Mon Sep 17 00:00:00 2001 From: Alex Weston Date: Thu, 30 Jan 2025 14:11:05 -0500 Subject: [PATCH 2/2] Add tests for all aliases --- .../test_grammar_response_format_llama.py | 61 ++++++++++++++----- 1 file changed, 45 insertions(+), 16 deletions(-) diff --git a/integration-tests/models/test_grammar_response_format_llama.py b/integration-tests/models/test_grammar_response_format_llama.py index f2a8a96da..809dc3dd7 100644 --- a/integration-tests/models/test_grammar_response_format_llama.py +++ b/integration-tests/models/test_grammar_response_format_llama.py @@ -29,26 +29,55 @@ async def test_grammar_response_format_llama_json(llama_grammar, response_snapsh unit: str temperature: List[int] + json_payload = { + "model": "tgi", + "messages": [ + { + "role": "system", + "content": f"Respond to the users questions and answer them in the following format: {Weather.schema()}", + }, + { + "role": "user", + "content": "What's the weather like the next 3 days in San Francisco, CA?", + }, + ], + "seed": 42, + "max_tokens": 500, + "response_format": {"type": "json_object", "value": Weather.schema()}, + } # send the request response = requests.post( 
f"{llama_grammar.base_url}/v1/chat/completions", headers=llama_grammar.headers, - json={ - "model": "tgi", - "messages": [ - { - "role": "system", - "content": f"Respond to the users questions and answer them in the following format: {Weather.schema()}", - }, - { - "role": "user", - "content": "What's the weather like the next 3 days in San Francisco, CA?", - }, - ], - "seed": 42, - "max_tokens": 500, - "response_format": {"type": "json_object", "value": Weather.schema()}, - }, + json=json_payload, + ) + + chat_completion = response.json() + called = chat_completion["choices"][0]["message"]["content"] + + assert response.status_code == 200 + assert called == '{ "unit": "fahrenheit", "temperature": [ 72, 79, 88 ] }' + assert chat_completion == response_snapshot + + json_payload["response_format"]["type"] = "json" + response = requests.post( + f"{llama_grammar.base_url}/v1/chat/completions", + headers=llama_grammar.headers, + json=json_payload, + ) + + chat_completion = response.json() + called = chat_completion["choices"][0]["message"]["content"] + + assert response.status_code == 200 + assert called == '{ "unit": "fahrenheit", "temperature": [ 72, 79, 88 ] }' + assert chat_completion == response_snapshot + + json_payload["response_format"]["type"] = "json_schema" + response = requests.post( + f"{llama_grammar.base_url}/v1/chat/completions", + headers=llama_grammar.headers, + json=json_payload, ) chat_completion = response.json()