feat: improve chat_template to include tools

Commit cc67f47d6e (parent 9874b15fa8)
Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-11 12:24:53 +00:00
@@ -11,13 +11,12 @@
         "tool_calls": [
           {
             "function": {
-              "description": null,
-              "name": "tools",
-              "parameters": {
+              "arguments": {
                 "format": "celsius",
-                "location": "New York, NY",
-                "num_days": 14
-              }
+                "location": "Brooklyn"
+              },
+              "description": null,
+              "name": "get_current_weather"
             },
             "id": 0,
             "type": "function"
@@ -27,14 +26,14 @@
       "usage": null
     }
   ],
-  "created": 1710795556,
+  "created": 1712782670,
   "id": "",
   "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
   "object": "text_completion",
   "system_fingerprint": "2.0.0-native",
   "usage": {
-    "completion_tokens": 29,
-    "prompt_tokens": 316,
-    "total_tokens": 345
+    "completion_tokens": 37,
+    "prompt_tokens": 524,
+    "total_tokens": 561
   }
 }
@@ -11,13 +11,12 @@
         "tool_calls": [
           {
             "function": {
-              "description": null,
-              "name": "tools",
-              "parameters": {
+              "arguments": {
                 "format": "celsius",
-                "location": "New York, NY",
-                "num_days": 14
-              }
+                "location": "Brooklyn"
+              },
+              "description": null,
+              "name": "get_current_weather"
             },
             "id": 0,
             "type": "function"
@@ -27,14 +26,14 @@
       "usage": null
     }
   ],
-  "created": 1710795557,
+  "created": 1712787937,
   "id": "",
   "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
   "object": "text_completion",
   "system_fingerprint": "2.0.0-native",
   "usage": {
-    "completion_tokens": 29,
-    "prompt_tokens": 316,
-    "total_tokens": 345
+    "completion_tokens": 37,
+    "prompt_tokens": 524,
+    "total_tokens": 561
   }
 }
@@ -11,12 +11,12 @@
         "tool_calls": [
          {
             "function": {
-              "description": null,
-              "name": "tools",
-              "parameters": {
+              "arguments": {
                 "format": "celsius",
                 "location": "New York, NY"
-              }
+              },
+              "description": null,
+              "name": "get_current_weather"
             },
             "id": 0,
             "type": "function"
@@ -26,14 +26,14 @@
       "usage": null
     }
   ],
-  "created": 1710795557,
+  "created": 1712787725,
   "id": "",
   "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
   "object": "text_completion",
   "system_fingerprint": "2.0.0-native",
   "usage": {
-    "completion_tokens": 21,
-    "prompt_tokens": 187,
-    "total_tokens": 208
+    "completion_tokens": 48,
+    "prompt_tokens": 351,
+    "total_tokens": 399
   }
 }
@@ -0,0 +1,39 @@
+{
+  "choices": [
+    {
+      "finish_reason": "eos_token",
+      "index": 0,
+      "logprobs": null,
+      "message": {
+        "content": null,
+        "name": null,
+        "role": "assistant",
+        "tool_calls": [
+          {
+            "function": {
+              "arguments": {
+                "error": "One of the parameters (e.g. 'number_of_days') is not valid or is too few.",
+                "name": "notify_error"
+              },
+              "description": null,
+              "name": "default_function_name"
+            },
+            "id": 0,
+            "type": "function"
+          }
+        ]
+      },
+      "usage": null
+    }
+  ],
+  "created": 1712788322,
+  "id": "",
+  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+  "object": "text_completion",
+  "system_fingerprint": "1.4.5-native",
+  "usage": {
+    "completion_tokens": 60,
+    "prompt_tokens": 535,
+    "total_tokens": 595
+  }
+}
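For context, the new snapshot above shows the full shape of a chat completion whose message carries a tool call. A minimal sketch of reading that tool call back out in Python follows; the dict only mirrors the snapshot, trimmed to the fields used here.

    # Mirrors the snapshot structure above, reduced to the fields this sketch reads.
    response = {
        "choices": [
            {
                "message": {
                    "role": "assistant",
                    "tool_calls": [
                        {
                            "id": 0,
                            "type": "function",
                            "function": {
                                "name": "default_function_name",
                                "arguments": {
                                    "error": "One of the parameters (e.g. 'number_of_days') is not valid or is too few.",
                                    "name": "notify_error",
                                },
                            },
                        }
                    ],
                }
            }
        ]
    }

    call = response["choices"][0]["message"]["tool_calls"][0]
    print(call["function"]["name"])       # "default_function_name"
    print(call["function"]["arguments"])  # arguments object produced by the model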
@@ -19,7 +19,7 @@
       "logprobs": null
     }
   ],
-  "created": 1710795499,
+  "created": 1712788218,
   "id": "",
   "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
   "object": "text_completion",
@@ -71,7 +71,6 @@ tools = [
 ]


-@pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_grammar_no_tools(
     flash_llama_grammar_tools, response_snapshot
@@ -98,7 +97,6 @@ async def test_flash_llama_grammar_no_tools(
     assert response == response_snapshot


-@pytest.mark.skip
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot):
@@ -121,23 +119,18 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna
     assert response.choices[0].message.content == None
     assert response.choices[0].message.tool_calls == [
         {
-            "function": {
-                "description": None,
-                "name": "tools",
-                "parameters": {
-                    "format": "celsius",
-                    "location": "New York, NY",
-                    "num_days": 14,
-                },
-            },
             "id": 0,
             "type": "function",
+            "function": {
+                "description": None,
+                "name": "get_current_weather",
+                "arguments": {"format": "celsius", "location": "Brooklyn"},
+            },
         }
     ]
     assert response == response_snapshot


-@pytest.mark.skip
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_grammar_tools_auto(
@@ -163,23 +156,19 @@ async def test_flash_llama_grammar_tools_auto(
     assert response.choices[0].message.content == None
     assert response.choices[0].message.tool_calls == [
         {
-            "function": {
-                "description": None,
-                "name": "tools",
-                "parameters": {
-                    "format": "celsius",
-                    "location": "New York, NY",
-                    "num_days": 14,
-                },
-            },
             "id": 0,
             "type": "function",
+            "function": {
+                "description": None,
+                "name": "get_current_weather",
+                "arguments": {"format": "celsius", "location": "Brooklyn"},
+            },
         }
     ]
+
     assert response == response_snapshot


-@pytest.mark.skip
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_grammar_tools_choice(
@@ -209,15 +198,15 @@ async def test_flash_llama_grammar_tools_choice(
             "type": "function",
             "function": {
                 "description": None,
-                "name": "tools",
-                "parameters": {"format": "celsius", "location": "New York, NY"},
+                "name": "get_current_weather",
+                "arguments": {"format": "celsius", "location": "New York, NY"},
             },
         }
     ]
+
     assert response == response_snapshot


-@pytest.mark.skip
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_grammar_tools_stream(
@@ -246,5 +235,47 @@ async def test_flash_llama_grammar_tools_stream(
     async for response in responses:
         count += 1

-    assert count == 20
+    assert count == 38
     assert response == response_snapshot
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_llama_grammar_tools_insufficient_information(
+    flash_llama_grammar_tools, response_snapshot
+):
+    responses = await flash_llama_grammar_tools.chat(
+        max_tokens=100,
+        seed=26,
+        tools=tools,
+        tool_choice="auto",
+        messages=[
+            {
+                "role": "system",
+                "content": "ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION",
+            },
+            {
+                "role": "user",
+                "content": "Tell me a story about 3 sea creatures",
+            },
+        ],
+        stream=False,
+    )
+
+    assert responses.choices[0].message.content == None
+    assert responses.choices[0].message.tool_calls == [
+        {
+            "id": 0,
+            "type": "function",
+            "function": {
+                "description": None,
+                "name": "default_function_name",
+                "arguments": {
+                    "error": "One of the parameters (e.g. 'number_of_days') is not valid or is too few.",
+                    "name": "notify_error",
+                },
+            },
+        }
+    ]
+
+    assert responses == response_snapshot
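The new test above drives the integration-test client fixture directly; the same kind of request can also be sent to a running TGI server through its OpenAI-compatible chat endpoint. The following is only a usage sketch: the base URL and port, the placeholder model name, and the weather tool definition are illustrative assumptions, not part of this commit.

    from openai import OpenAI

    # Assumes a TGI server is already running locally; "tgi" is a placeholder model name.
    client = OpenAI(base_url="http://localhost:3000/v1", api_key="-")

    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {"type": "string"},
                        "format": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                    },
                    "required": ["location", "format"],
                },
            },
        }
    ]

    response = client.chat.completions.create(
        model="tgi",
        messages=[{"role": "user", "content": "What is the weather like in Brooklyn?"}],
        tools=tools,
        tool_choice="auto",
        max_tokens=100,
    )
    print(response.choices[0].message.tool_calls)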
@@ -79,7 +79,7 @@ impl HubTokenizerConfig {
     }
 }

-#[derive(Clone, Debug, Deserialize, ToSchema)]
+#[derive(Clone, Debug, Deserialize, ToSchema, Serialize)]
 #[serde(tag = "type", content = "value")]
 pub(crate) enum GrammarType {
     /// A string that represents a [JSON Schema](https://json-schema.org/).
@@ -682,7 +682,7 @@ pub(crate) struct ChatRequest {

 fn default_tool_prompt() -> Option<String> {
     Some(
-        "\nBased on the conversation, please choose the most appropriate tool to use: ".to_string(),
+        "\nYou will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n".to_string(),
     )
 }
 #[derive(Clone, Deserialize, ToSchema, Serialize)]
@@ -780,12 +780,14 @@ pub(crate) struct Tool {
     pub function: FunctionDefinition,
 }

-#[derive(Clone, Serialize, Deserialize)]
+#[derive(Clone, Serialize, Deserialize, Default)]
 pub(crate) struct ChatTemplateInputs<'a> {
     messages: Vec<Message>,
     bos_token: Option<&'a str>,
     eos_token: Option<&'a str>,
     add_generation_prompt: bool,
+    tools: Option<&'a str>,
+    tools_prompt: Option<&'a str>,
 }

 #[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)]
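To illustrate the new tools / tools_prompt inputs above: the router renders its chat template in Rust via minijinja, so the following is only a conceptual sketch in Python with jinja2, and the template string, message, and tool schema are invented for the example rather than taken from the repository.

    import json
    from jinja2 import Template

    # Hypothetical chat template that appends the tool prompt and the serialized tools
    # after the conversation, mirroring the idea of the new tools/tools_prompt inputs.
    template = Template(
        "{% for message in messages %}"
        "<|{{ message.role }}|>\n{{ message.content }}\n"
        "{% endfor %}"
        "{% if tools %}{{ tools_prompt }}{{ tools }}{% endif %}"
    )

    messages = [{"role": "user", "content": "What is the weather like in Brooklyn?"}]
    tools = json.dumps([{"type": "function", "function": {"name": "get_current_weather"}}])

    print(template.render(messages=messages, tools=tools, tools_prompt="\nJSON Schema:\n"))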
@@ -757,23 +757,17 @@ async fn chat_completions(
     metrics::increment_counter!("tgi_request_count");

     let ChatRequest {
-        frequency_penalty: _,
-        logit_bias: _,
         logprobs,
         max_tokens,
         messages,
-        model: _,
-        n: _,
         presence_penalty,
         seed,
         stop,
         stream,
-        temperature: _,
         tools,
         tool_choice,
         tool_prompt,
-        top_p: _,
-        top_logprobs: _,
+        ..
     } = req;

     let repetition_penalty = presence_penalty.map(|x| x + 2.0);
@@ -798,8 +792,16 @@
         }
     };

+    let grammar_with_prompt = tool_grammar
+        .as_ref()
+        .map(|t| (GrammarType::Json(serde_json::json!(t)), tool_prompt));
+
+    let typed_grammar = grammar_with_prompt
+        .as_ref()
+        .map(|(grammar, _)| grammar.clone());
+
     // apply chat template to flatten the request into a single input
-    let mut inputs = match infer.apply_chat_template(messages) {
+    let inputs = match infer.apply_chat_template(messages, grammar_with_prompt) {
         Ok(inputs) => inputs,
         Err(err) => {
             metrics::increment_counter!("tgi_request_failure", "err" => "validation");
@@ -814,22 +816,6 @@
         }
     };

-    let grammar = if let Some(tools) = &tool_grammar {
-        let tools_str = serde_json::to_string(&tools).map_err(|e| {
-            (
-                StatusCode::UNPROCESSABLE_ENTITY,
-                Json(ErrorResponse {
-                    error: e.to_string(),
-                    error_type: "Input validation error".to_string(),
-                }),
-            )
-        })?;
-        inputs = format!("{inputs}{tool_prompt}{tools_str}");
-        Some(GrammarType::Json(serde_json::json!(tools)))
-    } else {
-        None
-    };
-
     // build the request passing some parameters
     let generate_request = GenerateRequest {
         inputs: inputs.to_string(),
@@ -851,7 +837,7 @@
             decoder_input_details: !stream,
             seed,
             top_n_tokens: req.top_logprobs,
-            grammar,
+            grammar: typed_grammar,
         },
     };

@@ -934,7 +920,6 @@
                 }),
             )
         })?;
-
         let tool_calls = vec![ToolCall {
             id: 0,
             r#type: "function".to_string(),
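End-to-end, the router turns the supplied tools into a JSON grammar on the underlying generate request (GrammarType::Json above), so tool selection happens through constrained decoding rather than free-form text. Below is a rough sketch of what a grammar-constrained /generate call looks like; the URL and the schema are illustrative stand-ins, not the exact grammar the router builds from the tools.

    import requests

    # Stand-in JSON-schema grammar; the router derives its own schema from the tools.
    grammar = {
        "type": "json",
        "value": {
            "type": "object",
            "properties": {
                "location": {"type": "string"},
                "format": {"type": "string", "enum": ["celsius", "fahrenheit"]},
            },
            "required": ["location", "format"],
        },
    }

    resp = requests.post(
        "http://localhost:3000/generate",  # assumed local TGI endpoint
        json={
            "inputs": "<flattened chat template output goes here>",
            "parameters": {"max_new_tokens": 100, "grammar": grammar},
        },
    )
    print(resp.json()["generated_text"])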