diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools.json index 56920b3e..45f8ca99 100644 --- a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools.json +++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools.json @@ -11,13 +11,12 @@ "tool_calls": [ { "function": { - "description": null, - "name": "tools", - "parameters": { + "arguments": { "format": "celsius", - "location": "New York, NY", - "num_days": 14 - } + "location": "Brooklyn" + }, + "description": null, + "name": "get_current_weather" }, "id": 0, "type": "function" @@ -27,14 +26,14 @@ "usage": null } ], - "created": 1710795556, + "created": 1712782670, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", "system_fingerprint": "2.0.0-native", "usage": { - "completion_tokens": 29, - "prompt_tokens": 316, - "total_tokens": 345 + "completion_tokens": 37, + "prompt_tokens": 524, + "total_tokens": 561 } } diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_auto.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_auto.json index fe679362..e0ed0947 100644 --- a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_auto.json +++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_auto.json @@ -11,13 +11,12 @@ "tool_calls": [ { "function": { - "description": null, - "name": "tools", - "parameters": { + "arguments": { "format": "celsius", - "location": "New York, NY", - "num_days": 14 - } + "location": "Brooklyn" + }, + "description": null, + "name": "get_current_weather" }, "id": 0, "type": "function" @@ -27,14 +26,14 @@ "usage": null } ], - "created": 1710795557, + "created": 1712787937, "id": "", "model": 
"TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", "system_fingerprint": "2.0.0-native", "usage": { - "completion_tokens": 29, - "prompt_tokens": 316, - "total_tokens": 345 + "completion_tokens": 37, + "prompt_tokens": 524, + "total_tokens": 561 } } diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_choice.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_choice.json index e48a1e7d..b99343b5 100644 --- a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_choice.json +++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_choice.json @@ -11,12 +11,12 @@ "tool_calls": [ { "function": { - "description": null, - "name": "tools", - "parameters": { + "arguments": { "format": "celsius", "location": "New York, NY" - } + }, + "description": null, + "name": "get_current_weather" }, "id": 0, "type": "function" @@ -26,14 +26,14 @@ "usage": null } ], - "created": 1710795557, + "created": 1712787725, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", "system_fingerprint": "2.0.0-native", "usage": { - "completion_tokens": 21, - "prompt_tokens": 187, - "total_tokens": 208 + "completion_tokens": 48, + "prompt_tokens": 351, + "total_tokens": 399 } } diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_insufficient_information.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_insufficient_information.json new file mode 100644 index 00000000..394dc852 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_insufficient_information.json @@ -0,0 +1,39 @@ +{ + "choices": [ + { + "finish_reason": "eos_token", + "index": 0, + "logprobs": null, + "message": { + "content": null, + "name": null, + "role": "assistant", + "tool_calls": [ + { 
+ "function": { + "arguments": { + "error": "One of the parameters (e.g. 'number_of_days') is not valid or is too few.", + "name": "notify_error" + }, + "description": null, + "name": "default_function_name" + }, + "id": 0, + "type": "function" + } + ] + }, + "usage": null + } + ], + "created": 1712788322, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "1.4.5-native", + "usage": { + "completion_tokens": 60, + "prompt_tokens": 535, + "total_tokens": 595 + } +} diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_stream.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_stream.json index cfebc05f..6787b39b 100644 --- a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_stream.json +++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_stream.json @@ -19,7 +19,7 @@ "logprobs": null } ], - "created": 1710795499, + "created": 1712788218, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", diff --git a/integration-tests/models/test_tools_llama.py b/integration-tests/models/test_tools_llama.py index d0ae331f..770acce2 100644 --- a/integration-tests/models/test_tools_llama.py +++ b/integration-tests/models/test_tools_llama.py @@ -71,7 +71,6 @@ tools = [ ] -@pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_grammar_no_tools( flash_llama_grammar_tools, response_snapshot @@ -98,7 +97,6 @@ async def test_flash_llama_grammar_no_tools( assert response == response_snapshot -@pytest.mark.skip @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot): @@ -121,23 +119,18 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna assert response.choices[0].message.content == None assert 
response.choices[0].message.tool_calls == [ { - "function": { - "description": None, - "name": "tools", - "parameters": { - "format": "celsius", - "location": "New York, NY", - "num_days": 14, - }, - }, "id": 0, "type": "function", + "function": { + "description": None, + "name": "get_current_weather", + "arguments": {"format": "celsius", "location": "Brooklyn"}, + }, } ] assert response == response_snapshot -@pytest.mark.skip @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_grammar_tools_auto( @@ -163,23 +156,19 @@ async def test_flash_llama_grammar_tools_auto( assert response.choices[0].message.content == None assert response.choices[0].message.tool_calls == [ { - "function": { - "description": None, - "name": "tools", - "parameters": { - "format": "celsius", - "location": "New York, NY", - "num_days": 14, - }, - }, "id": 0, "type": "function", + "function": { + "description": None, + "name": "get_current_weather", + "arguments": {"format": "celsius", "location": "Brooklyn"}, + }, } ] + assert response == response_snapshot -@pytest.mark.skip @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_grammar_tools_choice( @@ -209,15 +198,15 @@ async def test_flash_llama_grammar_tools_choice( "type": "function", "function": { "description": None, - "name": "tools", - "parameters": {"format": "celsius", "location": "New York, NY"}, + "name": "get_current_weather", + "arguments": {"format": "celsius", "location": "New York, NY"}, }, } ] + assert response == response_snapshot -@pytest.mark.skip @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_grammar_tools_stream( @@ -246,5 +235,47 @@ async def test_flash_llama_grammar_tools_stream( async for response in responses: count += 1 - assert count == 20 + assert count == 38 assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_grammar_tools_insufficient_information( + flash_llama_grammar_tools, response_snapshot +): + 
responses = await flash_llama_grammar_tools.chat( + max_tokens=100, + seed=26, + tools=tools, + tool_choice="auto", + messages=[ + { + "role": "system", + "content": "ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION", + }, + { + "role": "user", + "content": "Tell me a story about 3 sea creatures", + }, + ], + stream=False, + ) + + assert responses.choices[0].message.content == None + assert responses.choices[0].message.tool_calls == [ + { + "id": 0, + "type": "function", + "function": { + "description": None, + "name": "default_function_name", + "arguments": { + "error": "One of the parameters (e.g. 'number_of_days') is not valid or is too few.", + "name": "notify_error", + }, + }, + } + ] + + assert responses == response_snapshot diff --git a/router/src/lib.rs b/router/src/lib.rs index 56bb0ba4..ddb28848 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -79,7 +79,7 @@ impl HubTokenizerConfig { } } -#[derive(Clone, Debug, Deserialize, ToSchema)] +#[derive(Clone, Debug, Deserialize, ToSchema, Serialize)] #[serde(tag = "type", content = "value")] pub(crate) enum GrammarType { /// A string that represents a [JSON Schema](https://json-schema.org/). 
@@ -682,7 +682,7 @@ pub(crate) struct ChatRequest { fn default_tool_prompt() -> Option<String> { Some( - "\nBased on the conversation, please choose the most appropriate tool to use: ".to_string(), + "\nYou will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n".to_string(), ) } #[derive(Clone, Deserialize, ToSchema, Serialize)] @@ -780,12 +780,14 @@ pub(crate) struct Tool { pub function: FunctionDefinition, } -#[derive(Clone, Serialize, Deserialize)] +#[derive(Clone, Serialize, Deserialize, Default)] pub(crate) struct ChatTemplateInputs<'a> { messages: Vec<Message>, bos_token: Option<&'a str>, eos_token: Option<&'a str>, add_generation_prompt: bool, + tools: Option<&'a str>, + tools_prompt: Option<&'a str>, } #[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)] diff --git a/router/src/server.rs b/router/src/server.rs index c7c42d77..c1648f9e 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -757,23 +757,17 @@ async fn chat_completions( metrics::increment_counter!("tgi_request_count"); let ChatRequest { - frequency_penalty: _, - logit_bias: _, logprobs, max_tokens, messages, - model: _, - n: _, presence_penalty, seed, stop, stream, - temperature: _, tools, tool_choice, tool_prompt, - top_p: _, - top_logprobs: _, + ..
} = req; let repetition_penalty = presence_penalty.map(|x| x + 2.0); @@ -798,8 +792,16 @@ async fn chat_completions( } }; + let grammar_with_prompt = tool_grammar + .as_ref() + .map(|t| (GrammarType::Json(serde_json::json!(t)), tool_prompt)); + + let typed_grammar = grammar_with_prompt + .as_ref() + .map(|(grammar, _)| grammar.clone()); + // apply chat template to flatten the request into a single input - let mut inputs = match infer.apply_chat_template(messages) { + let inputs = match infer.apply_chat_template(messages, grammar_with_prompt) { Ok(inputs) => inputs, Err(err) => { metrics::increment_counter!("tgi_request_failure", "err" => "validation"); @@ -814,22 +816,6 @@ async fn chat_completions( } }; - let grammar = if let Some(tools) = &tool_grammar { - let tools_str = serde_json::to_string(&tools).map_err(|e| { - ( - StatusCode::UNPROCESSABLE_ENTITY, - Json(ErrorResponse { - error: e.to_string(), - error_type: "Input validation error".to_string(), - }), - ) - })?; - inputs = format!("{inputs}{tool_prompt}{tools_str}"); - Some(GrammarType::Json(serde_json::json!(tools))) - } else { - None - }; - // build the request passing some parameters let generate_request = GenerateRequest { inputs: inputs.to_string(), @@ -851,7 +837,7 @@ async fn chat_completions( decoder_input_details: !stream, seed, top_n_tokens: req.top_logprobs, - grammar, + grammar: typed_grammar, }, }; @@ -934,7 +920,6 @@ async fn chat_completions( }), ) })?; - let tool_calls = vec![ToolCall { id: 0, r#type: "function".to_string(),