fix: adjust tool choice none logic, add test and small refactors

This commit is contained in:
David Holtz 2024-10-18 15:36:40 +00:00 committed by drbh
parent b5bf5b32ad
commit 407531708e
4 changed files with 81 additions and 42 deletions

View File

@ -1015,7 +1015,7 @@
"$ref": "#/components/schemas/GrammarType" "$ref": "#/components/schemas/GrammarType"
} }
], ],
"default": "auto", "default": "null",
"nullable": true "nullable": true
}, },
"seed": { "seed": {
@ -1058,7 +1058,7 @@
"$ref": "#/components/schemas/ToolChoice" "$ref": "#/components/schemas/ToolChoice"
} }
], ],
"default": "null", "default": "auto",
"nullable": true "nullable": true
}, },
"tool_prompt": { "tool_prompt": {

View File

@ -371,6 +371,47 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream_required(
assert last_response == response_snapshot assert last_response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_sea_creatures_stream_none(
    flash_llama_grammar_tools, response_snapshot
):
    """Stream a chat completion with ``tool_choice="none"`` and expect plain text.

    Tools are still passed in the request, but with ``tool_choice="none"`` the
    streamed chunks are expected to carry message content only. The test checks
    that no chunk reports tool calls, that exactly 100 chunks arrive (the request
    sets ``max_tokens=100``), that the concatenated content matches the expected
    story, and that the final chunk matches the stored snapshot.
    """
    conversation = [
        {
            "role": "system",
            "content": "You're a helpful assistant! Answer the users question best you can. If the question is not answerable by the tools, just generate a response.",
        },
        {
            "role": "user",
            "content": "Tell me a story about 3 sea creatures",
        },
    ]
    expected_story = "Once upon a time, in a vibrant ocean filled with coral reefs and schools of shimmering fish, lived three dear friends: Luna the sea turtle, Finley the friendly fish, and Crusty the wise crab.\n\nLuna was the oldest of the three. She had traveled the world, exploring hidden caves and shipwrecks, and collecting sparkling shells and shiny pebbles. Her shell was a beautiful mosaic of blues and greens, and her gentle eyes twinkled with the secrets of the deep"

    stream = await flash_llama_grammar_tools.chat(
        max_tokens=100,
        seed=24,
        tools=tools,
        tool_choice="none",
        messages=conversation,
        stream=True,
    )

    chunk_total = 0
    story = ""
    final_chunk = None
    async for chunk in stream:
        delta = chunk.choices[0].delta
        chunk_total += 1
        story += delta.content
        final_chunk = chunk
        # NOTE(review): assumed to be a per-chunk check; confirm against the
        # original file's indentation.
        assert delta.tool_calls is None

    assert chunk_total == 100
    print(story)
    assert story == expected_story
    assert final_chunk == response_snapshot
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_llama_grammar_tools_sea_creatures_stream_function_object( async def test_flash_llama_grammar_tools_sea_creatures_stream_function_object(

View File

@ -20,12 +20,7 @@ impl ToolGrammar {
pub fn apply( pub fn apply(
tools: Vec<Tool>, tools: Vec<Tool>,
tool_choice: ToolChoice, tool_choice: ToolChoice,
) -> Result<(Vec<Tool>, Option<JsonSchemaTool>), InferError> { ) -> Result<Option<(Vec<Tool>, JsonSchemaTool)>, InferError> {
// if no tools are provided, we return None and an empty vec
if tools.is_empty() {
return Ok((Vec::with_capacity(0), None));
}
let tools_to_use = match tool_choice { let tools_to_use = match tool_choice {
ToolChoice::Function(function) => { ToolChoice::Function(function) => {
vec![Self::find_tool_by_name(&tools, &function.name)?] vec![Self::find_tool_by_name(&tools, &function.name)?]
@ -57,9 +52,14 @@ impl ToolGrammar {
})) }))
.collect::<Vec<_>>() .collect::<Vec<_>>()
} }
ToolChoice::NoTool => Vec::with_capacity(0), ToolChoice::NoTool => vec![],
}; };
// if no tools are provided or if the user has selected the no_tool option, return None
if tools_to_use.is_empty() {
return Ok(None);
}
let functions: HashMap<String, serde_json::Value> = tools_to_use let functions: HashMap<String, serde_json::Value> = tools_to_use
.iter() .iter()
.map(|tool| { .map(|tool| {
@ -106,22 +106,18 @@ impl ToolGrammar {
}) })
.collect(); .collect();
let tool_schema = if tools_to_use.is_empty() { let tool_schema = JsonSchemaTool {
None functions_map: FunctionsMap { functions },
} else { properties: Properties {
Some(JsonSchemaTool { function: tools_to_use
functions_map: FunctionsMap { functions }, .iter()
properties: Properties { .map(|tool| FunctionRef {
function: tools_to_use ref_path: format!("#/$functions/{}", tool.function.name.clone()),
.iter() })
.map(|tool| FunctionRef { .collect(),
ref_path: format!("#/$functions/{}", tool.function.name.clone()), },
})
.collect(),
},
})
}; };
Ok((tools_to_use, tool_schema)) Ok(Some((tools_to_use, tool_schema)))
} }
} }

View File

@ -892,14 +892,14 @@ pub(crate) struct ChatRequest {
/// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter. /// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter.
#[serde(default)] #[serde(default)]
#[schema(nullable = true, default = "null", example = "null")] #[schema(nullable = true, default = "auto", example = "auto")]
pub tool_choice: Option<ToolChoice>, pub tool_choice: ToolChoice,
/// Response format constraints for the generation. /// Response format constraints for the generation.
/// ///
/// NOTE: A request can use `response_format` OR `tools` but not both. /// NOTE: A request can use `response_format` OR `tools` but not both.
#[serde(default)] #[serde(default)]
#[schema(nullable = true, default = "auto", example = "auto")] #[schema(nullable = true, default = "null", example = "null")]
pub response_format: Option<GrammarType>, pub response_format: Option<GrammarType>,
/// A guideline to be used in the chat_template /// A guideline to be used in the chat_template
@ -946,8 +946,6 @@ impl ChatRequest {
Some(temperature) if temperature == 0.0 => (false, None), Some(temperature) if temperature == 0.0 => (false, None),
other => (true, other), other => (true, other),
}; };
// if no tool_choice is set, set default (Auto)
let tool_choice = tool_choice.unwrap_or_default();
if response_format.is_some() && tools.is_some() { if response_format.is_some() && tools.is_some() {
return Err(InferError::ToolError( return Err(InferError::ToolError(
@ -962,18 +960,22 @@ impl ChatRequest {
} }
None => { None => {
if let Some(tools) = tools { if let Some(tools) = tools {
let (updated_tools, tool_schema) = ToolGrammar::apply(tools, tool_choice)?; match ToolGrammar::apply(tools, tool_choice)? {
Some((updated_tools, tool_schema)) => {
let grammar = tool_schema let grammar = GrammarType::Json(serde_json::json!(tool_schema));
.as_ref() let inputs: String = infer.apply_chat_template(
.map(|t| GrammarType::Json(serde_json::json!(t))); guideline,
messages,
let inputs: String = infer.apply_chat_template( Some((updated_tools, tool_prompt)),
guideline, )?;
messages, (inputs, Some(grammar), true)
Some((updated_tools, tool_prompt)), }
)?; None => {
(inputs, grammar, tool_schema.is_some()) // same as if no response_format or tools are set
let inputs = infer.apply_chat_template(guideline, messages, None)?;
(inputs, None, false)
}
}
} else { } else {
// if no response_format or tools are set simply apply the chat template to generate inputs // if no response_format or tools are set simply apply the chat template to generate inputs
let inputs = infer.apply_chat_template(guideline, messages, None)?; let inputs = infer.apply_chat_template(guideline, messages, None)?;
@ -1046,7 +1048,7 @@ pub enum ToolChoice {
#[default] #[default]
Auto, Auto,
/// Means the model will not call any tool and instead generates a message. /// Means the model will not call any tool and instead generates a message.
#[schema(rename = "none")] #[serde(rename = "none")]
NoTool, NoTool,
/// Means the model must call one or more tools. /// Means the model must call one or more tools.
Required, Required,