feat: respect tool choice

This commit is contained in:
drbh 2024-02-22 18:26:49 +00:00
parent 3ec57acac1
commit 0e30e65822
8 changed files with 107 additions and 20 deletions

View File

@ -78,6 +78,7 @@ class Client:
temperature: Optional[float] = None,
top_p: Optional[float] = None,
tools: Optional[List[Tool]] = None,
tool_choice: Optional[str] = None,
):
"""
Given a list of messages, generate a response asynchronously
@ -112,6 +113,8 @@ class Client:
higher are kept for generation
tools (`List[Tool]`):
List of tools to use
tool_choice (`str`):
The tool to use
"""
request = ChatRequest(
@ -129,6 +132,7 @@ class Client:
temperature=temperature,
top_p=top_p,
tools=tools,
tool_choice=tool_choice,
)
resp = requests.post(
@ -412,6 +416,7 @@ class AsyncClient:
temperature: Optional[float] = None,
top_p: Optional[float] = None,
tools: Optional[List[Tool]] = None,
tool_choice: Optional[str] = None,
):
"""
Given a list of messages, generate a response asynchronously
@ -446,6 +451,8 @@ class AsyncClient:
higher are kept for generation
tools (`List[Tool]`):
List of tools to use
tool_choice (`str`):
The tool to use
"""
request = ChatRequest(
@ -463,6 +470,7 @@ class AsyncClient:
temperature=temperature,
top_p=top_p,
tools=tools,
tool_choice=tool_choice,
)
print(self.base_url)
async with ClientSession(

View File

@ -86,6 +86,8 @@ class ChatRequest(BaseModel):
top_p: Optional[float] = None
# List of tools to be used
tools: Optional[List[Tool]] = None
# Choice of tool to be used
tool_choice: Optional[str] = None
class Parameters(BaseModel):

View File

@ -5,13 +5,13 @@
"index": 0,
"logprobs": null,
"message": {
"content": "As an up-to-date news station, our team has access to the latest information on weather conditions in Brooklyn, New York. Here is what we have learned so far:\n\n- Located in New York City, Brooklyn has a history of harsh weather patterns, especially in winter. The city's cold penchant makes it a popular winter destination, and meteorologists predict \"bomb cyclone\" conditions in the year 2021. - Due to",
"content": "As of today, there is a Update available for the Brooklyn, New York, area. According to the latest forecast, it's warm with high temperatures throughout the day. It's forecasted at 75°F for today and 77°F for tomorrow. However, in autumn, the weather typically changes drastically, becoming cooler and wetter. You can find the current weather forecast for the area through your local weather service. Additionally",
"name": null,
"role": "assistant"
}
}
],
"created": 1708623190,
"created": 1708626137,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",

View File

@ -5,20 +5,20 @@
"index": 0,
"logprobs": null,
"message": {
"content": "{\"function\":{\"format\": \"celsius\", \"location\": \"Brooklyn, NYC\", \"num_days\": 1255}}",
"content": "{\"function\":{\"format\": \"celsius\", \"location\": \"New York, NY\", \"num_days\": 14}}",
"name": null,
"role": "assistant"
}
}
],
"created": 1708623212,
"created": 1708626137,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "1.4.2-native",
"usage": {
"completion_tokens": 33,
"completion_tokens": 29,
"prompt_tokens": 318,
"total_tokens": 351
"total_tokens": 347
}
}

View File

@ -0,0 +1,24 @@
{
"choices": [
{
"finish_reason": "eos_token",
"index": 0,
"logprobs": null,
"message": {
"content": "{\"function\":{\"format\": \"celsius\", \"location\": \"New York, NY\"}}",
"name": null,
"role": "assistant"
}
}
],
"created": 1708626030,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "1.4.2-native",
"usage": {
"completion_tokens": 21,
"prompt_tokens": 189,
"total_tokens": 210
}
}

View File

@ -73,12 +73,12 @@ tools = [
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_no_tools_regex(
async def test_flash_llama_grammar_no_tools(
flash_llama_grammar_tools, response_snapshot
):
response = await flash_llama_grammar_tools.chat(
max_tokens=100,
seed=0,
seed=1,
messages=[
{
"role": "system",
@ -93,19 +93,17 @@ async def test_flash_llama_grammar_no_tools_regex(
assert (
response.choices[0].message.content
== 'As an up-to-date news station, our team has access to the latest information on weather conditions in Brooklyn, New York. Here is what we have learned so far:\n\n- Located in New York City, Brooklyn has a history of harsh weather patterns, especially in winter. The city\'s cold penchant makes it a popular winter destination, and meteorologists predict "bomb cyclone" conditions in the year 2021. - Due to'
== "As of today, there is a Update available for the Brooklyn, New York, area. According to the latest forecast, it's warm with high temperatures throughout the day. It's forecasted at 75°F for today and 77°F for tomorrow. However, in autumn, the weather typically changes drastically, becoming cooler and wetter. You can find the current weather forecast for the area through your local weather service. Additionally"
)
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_regex(
flash_llama_grammar_tools, response_snapshot
):
async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot):
response = await flash_llama_grammar_tools.chat(
max_tokens=100,
seed=0,
seed=1,
tools=tools,
presence_penalty=-1.1,
messages=[
@ -119,9 +117,39 @@ async def test_flash_llama_grammar_tools_regex(
},
],
)
assert len(response.choices[0].message.content) == 81
assert len(response.choices[0].message.content) == 78
assert (
response.choices[0].message.content
== """{"function":{"format": "celsius", "location": "Brooklyn, NYC", "num_days": 1255}}"""
== """{"function":{"format": "celsius", "location": "New York, NY", "num_days": 14}}"""
)
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_choice(
flash_llama_grammar_tools, response_snapshot
):
response = await flash_llama_grammar_tools.chat(
max_tokens=100,
seed=1,
tools=tools,
tool_choice="get_current_weather",
presence_penalty=-1.1,
messages=[
{
"role": "system",
"content": "Youre a helpful assistant! Answer the users question best you can.",
},
{
"role": "user",
"content": "What is the weather like in Brooklyn, New York?",
},
],
)
assert len(response.choices[0].message.content) == 62
assert (
response.choices[0].message.content
== """{"function":{"format": "celsius", "location": "New York, NY"}}"""
)
assert response == response_snapshot

View File

@ -526,6 +526,11 @@ pub(crate) struct ChatRequest {
#[serde(default)]
#[schema(nullable = true, example = "null")]
pub tools: Option<Vec<Tool>>,
/// A specific tool to use. If not provided, the model will default to using any of the tools provided in the tools parameter.
#[serde(default)]
#[schema(nullable = true, example = "null")]
pub tool_choice: Option<String>,
}
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default)]
@ -536,10 +541,9 @@ pub struct Tools {
pub any_of: Vec<FunctionRef>,
}
// add trait to convert to serde_json::Value for tools
// Allows Tools to be converted to a valid JSON schema object
impl From<Tools> for serde_json::Value {
fn from(tools: Tools) -> Self {
println!("tools: {:?}", tools);
let mut map = serde_json::Map::new();
let mut functions = serde_json::Map::new();
for (name, value) in tools.function {

View File

@ -601,10 +601,31 @@ async fn chat_completions(
// if there's a tools object, we need to decompose it and use the function name as the key
// and the parameters as the value in the "$functions" object.
let grammar = if let Some(req_tools) = &req.tools {
let grammar = if let Some(ref req_tools) = &req.tools {
// get the tool_choice if there is one
let tool_choice = &req.tool_choice;
let tools_to_use = if let Some(tool_choice) = tool_choice {
// get the tool based on the tool_choice
let tool = req_tools
.iter()
.find(|tool| tool.function.name == *tool_choice)
.ok_or_else(|| {
(
StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse {
error: "Input validation error".to_string(),
error_type: "Input validation error".to_string(),
}),
)
})?;
vec![tool.clone()]
} else {
req_tools.clone()
};
let functions: HashMap<String, Value> = {
let mut tools = HashMap::new();
for tool in req_tools {
for tool in &tools_to_use {
let func = tool.function.clone();
let name = func.name;
let parameters = match func.parameters.as_object() {
@ -627,7 +648,7 @@ async fn chat_completions(
let tools = Tools {
function: functions,
any_of: req_tools
any_of: tools_to_use
.iter()
.map(|tool| FunctionRef::new(&tool.function.name))
.collect(),