mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
fix: adjust tool choice none logic, add test and small refactors
This commit is contained in:
parent
b5bf5b32ad
commit
407531708e
@ -1015,7 +1015,7 @@
|
||||
"$ref": "#/components/schemas/GrammarType"
|
||||
}
|
||||
],
|
||||
"default": "auto",
|
||||
"default": "null",
|
||||
"nullable": true
|
||||
},
|
||||
"seed": {
|
||||
@ -1058,7 +1058,7 @@
|
||||
"$ref": "#/components/schemas/ToolChoice"
|
||||
}
|
||||
],
|
||||
"default": "null",
|
||||
"default": "auto",
|
||||
"nullable": true
|
||||
},
|
||||
"tool_prompt": {
|
||||
|
@ -371,6 +371,47 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream_required(
|
||||
assert last_response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llama_grammar_tools_sea_creatures_stream_none(
|
||||
flash_llama_grammar_tools, response_snapshot
|
||||
):
|
||||
responses = await flash_llama_grammar_tools.chat(
|
||||
max_tokens=100,
|
||||
seed=24,
|
||||
tools=tools,
|
||||
tool_choice="none",
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You're a helpful assistant! Answer the users question best you can. If the question is not answerable by the tools, just generate a response.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Tell me a story about 3 sea creatures",
|
||||
},
|
||||
],
|
||||
stream=True,
|
||||
)
|
||||
|
||||
count = 0
|
||||
content_generated = ""
|
||||
last_response = None
|
||||
async for response in responses:
|
||||
count += 1
|
||||
content_generated += response.choices[0].delta.content
|
||||
last_response = response
|
||||
assert response.choices[0].delta.tool_calls is None
|
||||
|
||||
assert count == 100
|
||||
print(content_generated)
|
||||
assert (
|
||||
content_generated
|
||||
== "Once upon a time, in a vibrant ocean filled with coral reefs and schools of shimmering fish, lived three dear friends: Luna the sea turtle, Finley the friendly fish, and Crusty the wise crab.\n\nLuna was the oldest of the three. She had traveled the world, exploring hidden caves and shipwrecks, and collecting sparkling shells and shiny pebbles. Her shell was a beautiful mosaic of blues and greens, and her gentle eyes twinkled with the secrets of the deep"
|
||||
)
|
||||
assert last_response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llama_grammar_tools_sea_creatures_stream_function_object(
|
||||
|
@ -20,12 +20,7 @@ impl ToolGrammar {
|
||||
pub fn apply(
|
||||
tools: Vec<Tool>,
|
||||
tool_choice: ToolChoice,
|
||||
) -> Result<(Vec<Tool>, Option<JsonSchemaTool>), InferError> {
|
||||
// if no tools are provided, we return None and an empty vec
|
||||
if tools.is_empty() {
|
||||
return Ok((Vec::with_capacity(0), None));
|
||||
}
|
||||
|
||||
) -> Result<Option<(Vec<Tool>, JsonSchemaTool)>, InferError> {
|
||||
let tools_to_use = match tool_choice {
|
||||
ToolChoice::Function(function) => {
|
||||
vec![Self::find_tool_by_name(&tools, &function.name)?]
|
||||
@ -57,9 +52,14 @@ impl ToolGrammar {
|
||||
}))
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
ToolChoice::NoTool => Vec::with_capacity(0),
|
||||
ToolChoice::NoTool => vec![],
|
||||
};
|
||||
|
||||
// if no tools are provided or if the user has selected the no_tool option, return None
|
||||
if tools_to_use.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let functions: HashMap<String, serde_json::Value> = tools_to_use
|
||||
.iter()
|
||||
.map(|tool| {
|
||||
@ -106,22 +106,18 @@ impl ToolGrammar {
|
||||
})
|
||||
.collect();
|
||||
|
||||
let tool_schema = if tools_to_use.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(JsonSchemaTool {
|
||||
functions_map: FunctionsMap { functions },
|
||||
properties: Properties {
|
||||
function: tools_to_use
|
||||
.iter()
|
||||
.map(|tool| FunctionRef {
|
||||
ref_path: format!("#/$functions/{}", tool.function.name.clone()),
|
||||
})
|
||||
.collect(),
|
||||
},
|
||||
})
|
||||
let tool_schema = JsonSchemaTool {
|
||||
functions_map: FunctionsMap { functions },
|
||||
properties: Properties {
|
||||
function: tools_to_use
|
||||
.iter()
|
||||
.map(|tool| FunctionRef {
|
||||
ref_path: format!("#/$functions/{}", tool.function.name.clone()),
|
||||
})
|
||||
.collect(),
|
||||
},
|
||||
};
|
||||
|
||||
Ok((tools_to_use, tool_schema))
|
||||
Ok(Some((tools_to_use, tool_schema)))
|
||||
}
|
||||
}
|
||||
|
@ -892,14 +892,14 @@ pub(crate) struct ChatRequest {
|
||||
|
||||
/// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter.
|
||||
#[serde(default)]
|
||||
#[schema(nullable = true, default = "null", example = "null")]
|
||||
pub tool_choice: Option<ToolChoice>,
|
||||
#[schema(nullable = true, default = "auto", example = "auto")]
|
||||
pub tool_choice: ToolChoice,
|
||||
|
||||
/// Response format constraints for the generation.
|
||||
///
|
||||
/// NOTE: A request can use `response_format` OR `tools` but not both.
|
||||
#[serde(default)]
|
||||
#[schema(nullable = true, default = "auto", example = "auto")]
|
||||
#[schema(nullable = true, default = "null", example = "null")]
|
||||
pub response_format: Option<GrammarType>,
|
||||
|
||||
/// A guideline to be used in the chat_template
|
||||
@ -946,8 +946,6 @@ impl ChatRequest {
|
||||
Some(temperature) if temperature == 0.0 => (false, None),
|
||||
other => (true, other),
|
||||
};
|
||||
// if no tool_choice is set, set default (Auto)
|
||||
let tool_choice = tool_choice.unwrap_or_default();
|
||||
|
||||
if response_format.is_some() && tools.is_some() {
|
||||
return Err(InferError::ToolError(
|
||||
@ -962,18 +960,22 @@ impl ChatRequest {
|
||||
}
|
||||
None => {
|
||||
if let Some(tools) = tools {
|
||||
let (updated_tools, tool_schema) = ToolGrammar::apply(tools, tool_choice)?;
|
||||
|
||||
let grammar = tool_schema
|
||||
.as_ref()
|
||||
.map(|t| GrammarType::Json(serde_json::json!(t)));
|
||||
|
||||
let inputs: String = infer.apply_chat_template(
|
||||
guideline,
|
||||
messages,
|
||||
Some((updated_tools, tool_prompt)),
|
||||
)?;
|
||||
(inputs, grammar, tool_schema.is_some())
|
||||
match ToolGrammar::apply(tools, tool_choice)? {
|
||||
Some((updated_tools, tool_schema)) => {
|
||||
let grammar = GrammarType::Json(serde_json::json!(tool_schema));
|
||||
let inputs: String = infer.apply_chat_template(
|
||||
guideline,
|
||||
messages,
|
||||
Some((updated_tools, tool_prompt)),
|
||||
)?;
|
||||
(inputs, Some(grammar), true)
|
||||
}
|
||||
None => {
|
||||
// same as if no response_format or tools are set
|
||||
let inputs = infer.apply_chat_template(guideline, messages, None)?;
|
||||
(inputs, None, false)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// if no response_format or tools are set simply apply the chat template to generate inputs
|
||||
let inputs = infer.apply_chat_template(guideline, messages, None)?;
|
||||
@ -1046,7 +1048,7 @@ pub enum ToolChoice {
|
||||
#[default]
|
||||
Auto,
|
||||
/// Means the model will not call any tool and instead generates a message.
|
||||
#[schema(rename = "none")]
|
||||
#[serde(rename = "none")]
|
||||
NoTool,
|
||||
/// Means the model must call one or more tools.
|
||||
Required,
|
||||
|
Loading…
Reference in New Issue
Block a user