feat: improve chat_template to include tools

This commit is contained in:
drbh 2024-04-10 22:49:10 +00:00
parent 9874b15fa8
commit cc67f47d6e
8 changed files with 139 additions and 84 deletions

View File

@@ -11,13 +11,12 @@
"tool_calls": [ "tool_calls": [
{ {
"function": { "function": {
"description": null, "arguments": {
"name": "tools",
"parameters": {
"format": "celsius", "format": "celsius",
"location": "New York, NY", "location": "Brooklyn"
"num_days": 14 },
} "description": null,
"name": "get_current_weather"
}, },
"id": 0, "id": 0,
"type": "function" "type": "function"
@@ -27,14 +26,14 @@
"usage": null "usage": null
} }
], ],
"created": 1710795556, "created": 1712782670,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.0-native", "system_fingerprint": "2.0.0-native",
"usage": { "usage": {
"completion_tokens": 29, "completion_tokens": 37,
"prompt_tokens": 316, "prompt_tokens": 524,
"total_tokens": 345 "total_tokens": 561
} }
} }

View File

@@ -11,13 +11,12 @@
"tool_calls": [ "tool_calls": [
{ {
"function": { "function": {
"description": null, "arguments": {
"name": "tools",
"parameters": {
"format": "celsius", "format": "celsius",
"location": "New York, NY", "location": "Brooklyn"
"num_days": 14 },
} "description": null,
"name": "get_current_weather"
}, },
"id": 0, "id": 0,
"type": "function" "type": "function"
@@ -27,14 +26,14 @@
"usage": null "usage": null
} }
], ],
"created": 1710795557, "created": 1712787937,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.0-native", "system_fingerprint": "2.0.0-native",
"usage": { "usage": {
"completion_tokens": 29, "completion_tokens": 37,
"prompt_tokens": 316, "prompt_tokens": 524,
"total_tokens": 345 "total_tokens": 561
} }
} }

View File

@@ -11,12 +11,12 @@
"tool_calls": [ "tool_calls": [
{ {
"function": { "function": {
"description": null, "arguments": {
"name": "tools",
"parameters": {
"format": "celsius", "format": "celsius",
"location": "New York, NY" "location": "New York, NY"
} },
"description": null,
"name": "get_current_weather"
}, },
"id": 0, "id": 0,
"type": "function" "type": "function"
@@ -26,14 +26,14 @@
"usage": null "usage": null
} }
], ],
"created": 1710795557, "created": 1712787725,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.0-native", "system_fingerprint": "2.0.0-native",
"usage": { "usage": {
"completion_tokens": 21, "completion_tokens": 48,
"prompt_tokens": 187, "prompt_tokens": 351,
"total_tokens": 208 "total_tokens": 399
} }
} }

View File

@@ -0,0 +1,39 @@
{
"choices": [
{
"finish_reason": "eos_token",
"index": 0,
"logprobs": null,
"message": {
"content": null,
"name": null,
"role": "assistant",
"tool_calls": [
{
"function": {
"arguments": {
"error": "One of the parameters (e.g. 'number_of_days') is not valid or is too few.",
"name": "notify_error"
},
"description": null,
"name": "default_function_name"
},
"id": 0,
"type": "function"
}
]
},
"usage": null
}
],
"created": 1712788322,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "1.4.5-native",
"usage": {
"completion_tokens": 60,
"prompt_tokens": 535,
"total_tokens": 595
}
}

View File

@@ -19,7 +19,7 @@
"logprobs": null "logprobs": null
} }
], ],
"created": 1710795499, "created": 1712788218,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",

View File

@@ -71,7 +71,6 @@ tools = [
] ]
@pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_llama_grammar_no_tools( async def test_flash_llama_grammar_no_tools(
flash_llama_grammar_tools, response_snapshot flash_llama_grammar_tools, response_snapshot
@@ -98,7 +97,6 @@ async def test_flash_llama_grammar_no_tools(
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.skip
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot): async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot):
@@ -121,23 +119,18 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna
assert response.choices[0].message.content == None assert response.choices[0].message.content == None
assert response.choices[0].message.tool_calls == [ assert response.choices[0].message.tool_calls == [
{ {
"function": {
"description": None,
"name": "tools",
"parameters": {
"format": "celsius",
"location": "New York, NY",
"num_days": 14,
},
},
"id": 0, "id": 0,
"type": "function", "type": "function",
"function": {
"description": None,
"name": "get_current_weather",
"arguments": {"format": "celsius", "location": "Brooklyn"},
},
} }
] ]
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.skip
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_llama_grammar_tools_auto( async def test_flash_llama_grammar_tools_auto(
@@ -163,23 +156,19 @@ async def test_flash_llama_grammar_tools_auto(
assert response.choices[0].message.content == None assert response.choices[0].message.content == None
assert response.choices[0].message.tool_calls == [ assert response.choices[0].message.tool_calls == [
{ {
"function": {
"description": None,
"name": "tools",
"parameters": {
"format": "celsius",
"location": "New York, NY",
"num_days": 14,
},
},
"id": 0, "id": 0,
"type": "function", "type": "function",
"function": {
"description": None,
"name": "get_current_weather",
"arguments": {"format": "celsius", "location": "Brooklyn"},
},
} }
] ]
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.skip
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_llama_grammar_tools_choice( async def test_flash_llama_grammar_tools_choice(
@@ -209,15 +198,15 @@ async def test_flash_llama_grammar_tools_choice(
"type": "function", "type": "function",
"function": { "function": {
"description": None, "description": None,
"name": "tools", "name": "get_current_weather",
"parameters": {"format": "celsius", "location": "New York, NY"}, "arguments": {"format": "celsius", "location": "New York, NY"},
}, },
} }
] ]
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.skip
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_llama_grammar_tools_stream( async def test_flash_llama_grammar_tools_stream(
@@ -246,5 +235,47 @@ async def test_flash_llama_grammar_tools_stream(
async for response in responses: async for response in responses:
count += 1 count += 1
assert count == 20 assert count == 38
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_insufficient_information(
flash_llama_grammar_tools, response_snapshot
):
responses = await flash_llama_grammar_tools.chat(
max_tokens=100,
seed=26,
tools=tools,
tool_choice="auto",
messages=[
{
"role": "system",
"content": "ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION",
},
{
"role": "user",
"content": "Tell me a story about 3 sea creatures",
},
],
stream=False,
)
assert responses.choices[0].message.content == None
assert responses.choices[0].message.tool_calls == [
{
"id": 0,
"type": "function",
"function": {
"description": None,
"name": "default_function_name",
"arguments": {
"error": "One of the parameters (e.g. 'number_of_days') is not valid or is too few.",
"name": "notify_error",
},
},
}
]
assert responses == response_snapshot

View File

@@ -79,7 +79,7 @@ impl HubTokenizerConfig {
} }
} }
#[derive(Clone, Debug, Deserialize, ToSchema)] #[derive(Clone, Debug, Deserialize, ToSchema, Serialize)]
#[serde(tag = "type", content = "value")] #[serde(tag = "type", content = "value")]
pub(crate) enum GrammarType { pub(crate) enum GrammarType {
/// A string that represents a [JSON Schema](https://json-schema.org/). /// A string that represents a [JSON Schema](https://json-schema.org/).
@@ -682,7 +682,7 @@ pub(crate) struct ChatRequest {
fn default_tool_prompt() -> Option<String> { fn default_tool_prompt() -> Option<String> {
Some( Some(
"\nBased on the conversation, please choose the most appropriate tool to use: ".to_string(), "\nYou will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n".to_string(),
) )
} }
#[derive(Clone, Deserialize, ToSchema, Serialize)] #[derive(Clone, Deserialize, ToSchema, Serialize)]
@@ -780,12 +780,14 @@ pub(crate) struct Tool {
pub function: FunctionDefinition, pub function: FunctionDefinition,
} }
#[derive(Clone, Serialize, Deserialize)] #[derive(Clone, Serialize, Deserialize, Default)]
pub(crate) struct ChatTemplateInputs<'a> { pub(crate) struct ChatTemplateInputs<'a> {
messages: Vec<Message>, messages: Vec<Message>,
bos_token: Option<&'a str>, bos_token: Option<&'a str>,
eos_token: Option<&'a str>, eos_token: Option<&'a str>,
add_generation_prompt: bool, add_generation_prompt: bool,
tools: Option<&'a str>,
tools_prompt: Option<&'a str>,
} }
#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)] #[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)]

View File

@@ -757,23 +757,17 @@ async fn chat_completions(
metrics::increment_counter!("tgi_request_count"); metrics::increment_counter!("tgi_request_count");
let ChatRequest { let ChatRequest {
frequency_penalty: _,
logit_bias: _,
logprobs, logprobs,
max_tokens, max_tokens,
messages, messages,
model: _,
n: _,
presence_penalty, presence_penalty,
seed, seed,
stop, stop,
stream, stream,
temperature: _,
tools, tools,
tool_choice, tool_choice,
tool_prompt, tool_prompt,
top_p: _, ..
top_logprobs: _,
} = req; } = req;
let repetition_penalty = presence_penalty.map(|x| x + 2.0); let repetition_penalty = presence_penalty.map(|x| x + 2.0);
@@ -798,8 +792,16 @@ async fn chat_completions(
} }
}; };
let grammar_with_prompt = tool_grammar
.as_ref()
.map(|t| (GrammarType::Json(serde_json::json!(t)), tool_prompt));
let typed_grammar = grammar_with_prompt
.as_ref()
.map(|(grammar, _)| grammar.clone());
// apply chat template to flatten the request into a single input // apply chat template to flatten the request into a single input
let mut inputs = match infer.apply_chat_template(messages) { let inputs = match infer.apply_chat_template(messages, grammar_with_prompt) {
Ok(inputs) => inputs, Ok(inputs) => inputs,
Err(err) => { Err(err) => {
metrics::increment_counter!("tgi_request_failure", "err" => "validation"); metrics::increment_counter!("tgi_request_failure", "err" => "validation");
@@ -814,22 +816,6 @@ async fn chat_completions(
} }
}; };
let grammar = if let Some(tools) = &tool_grammar {
let tools_str = serde_json::to_string(&tools).map_err(|e| {
(
StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse {
error: e.to_string(),
error_type: "Input validation error".to_string(),
}),
)
})?;
inputs = format!("{inputs}{tool_prompt}{tools_str}");
Some(GrammarType::Json(serde_json::json!(tools)))
} else {
None
};
// build the request passing some parameters // build the request passing some parameters
let generate_request = GenerateRequest { let generate_request = GenerateRequest {
inputs: inputs.to_string(), inputs: inputs.to_string(),
@@ -851,7 +837,7 @@ async fn chat_completions(
decoder_input_details: !stream, decoder_input_details: !stream,
seed, seed,
top_n_tokens: req.top_logprobs, top_n_tokens: req.top_logprobs,
grammar, grammar: typed_grammar,
}, },
}; };
@@ -934,7 +920,6 @@ async fn chat_completions(
}), }),
) )
})?; })?;
let tool_calls = vec![ToolCall { let tool_calls = vec![ToolCall {
id: 0, id: 0,
r#type: "function".to_string(), r#type: "function".to_string(),