Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 12:24:53 +00:00)
feat: respect tool choice
This commit is contained in:
parent 3ec57acac1
commit 0e30e65822
@@ -78,6 +78,7 @@ class Client:
temperature: Optional[float] = None,
top_p: Optional[float] = None,
tools: Optional[List[Tool]] = None,
tool_choice: Optional[str] = None,
):
"""
Given a list of messages, generate a response asynchronously
@@ -112,6 +113,8 @@ class Client:
higher are kept for generation
tools (`List[Tool]`):
List of tools to use
tool_choice (`str`):
The tool to use

"""
request = ChatRequest(
@@ -129,6 +132,7 @@ class Client:
temperature=temperature,
top_p=top_p,
tools=tools,
tool_choice=tool_choice,
)

resp = requests.post(
@@ -412,6 +416,7 @@ class AsyncClient:
temperature: Optional[float] = None,
top_p: Optional[float] = None,
tools: Optional[List[Tool]] = None,
tool_choice: Optional[str] = None,
):
"""
Given a list of messages, generate a response asynchronously
@@ -446,6 +451,8 @@ class AsyncClient:
higher are kept for generation
tools (`List[Tool]`):
List of tools to use
tool_choice (`str`):
The tool to use

"""
request = ChatRequest(
@@ -463,6 +470,7 @@ class AsyncClient:
temperature=temperature,
top_p=top_p,
tools=tools,
tool_choice=tool_choice,
)
print(self.base_url)
async with ClientSession(
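For orientation only (not part of the diff): a minimal sketch of calling the updated Python client with the new `tool_choice` argument. The endpoint URL and the tool definition are assumptions modeled on the integration tests below, and passing plain dicts assumes pydantic coerces them into the client's `Message`/`Tool` types.

```python
# Hypothetical usage sketch of the new `tool_choice` argument; the URL and
# tool schema below are assumptions, not part of this commit.
from text_generation import Client

client = Client("http://127.0.0.1:8080")

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {"type": "string", "description": "City and state"},
                    "format": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                },
                "required": ["location", "format"],
            },
        },
    }
]

response = client.chat(
    messages=[{"role": "user", "content": "What is the weather like in Brooklyn, New York?"}],
    tools=tools,
    # New in this commit: restrict generation to this one function's schema.
    tool_choice="get_current_weather",
)
# Response shape follows the snapshots in this commit: a JSON object in the message content.
print(response.choices[0].message.content)
```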

@@ -86,6 +86,8 @@ class ChatRequest(BaseModel):
top_p: Optional[float] = None
# List of tools to be used
tools: Optional[List[Tool]] = None
# Choice of tool to be used
tool_choice: Optional[str] = None


class Parameters(BaseModel):

@@ -5,13 +5,13 @@
"index": 0,
"logprobs": null,
"message": {
"content": "As an up-to-date news station, our team has access to the latest information on weather conditions in Brooklyn, New York. Here is what we have learned so far:\n\n- Located in New York City, Brooklyn has a history of harsh weather patterns, especially in winter. The city's cold penchant makes it a popular winter destination, and meteorologists predict \"bomb cyclone\" conditions in the year 2021. - Due to",
"content": "As of today, there is a Update available for the Brooklyn, New York, area. According to the latest forecast, it's warm with high temperatures throughout the day. It's forecasted at 75°F for today and 77°F for tomorrow. However, in autumn, the weather typically changes drastically, becoming cooler and wetter. You can find the current weather forecast for the area through your local weather service. Additionally",
"name": null,
"role": "assistant"
}
}
],
"created": 1708623190,
"created": 1708626137,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",

@@ -5,20 +5,20 @@
"index": 0,
"logprobs": null,
"message": {
"content": "{\"function\":{\"format\": \"celsius\", \"location\": \"Brooklyn, NYC\", \"num_days\": 1255}}",
"content": "{\"function\":{\"format\": \"celsius\", \"location\": \"New York, NY\", \"num_days\": 14}}",
"name": null,
"role": "assistant"
}
}
],
"created": 1708623212,
"created": 1708626137,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "1.4.2-native",
"usage": {
"completion_tokens": 33,
"completion_tokens": 29,
"prompt_tokens": 318,
"total_tokens": 351
"total_tokens": 347
}
}

@@ -0,0 +1,24 @@
{
"choices": [
{
"finish_reason": "eos_token",
"index": 0,
"logprobs": null,
"message": {
"content": "{\"function\":{\"format\": \"celsius\", \"location\": \"New York, NY\"}}",
"name": null,
"role": "assistant"
}
}
],
"created": 1708626030,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "1.4.2-native",
"usage": {
"completion_tokens": 21,
"prompt_tokens": 189,
"total_tokens": 210
}
}
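An aside, not part of the diff: because the grammar constrains the assistant message to a single JSON object, callers can recover the chosen tool's arguments with a plain `json.loads`. A tiny sketch against the snapshot content above:

```python
# Sketch: parse the grammar-constrained message content shown in the snapshot above.
import json

content = '{"function":{"format": "celsius", "location": "New York, NY"}}'
arguments = json.loads(content)["function"]
print(arguments["location"], arguments["format"])  # New York, NY celsius
```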

@@ -73,12 +73,12 @@ tools = [

@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_no_tools_regex(
async def test_flash_llama_grammar_no_tools(
flash_llama_grammar_tools, response_snapshot
):
response = await flash_llama_grammar_tools.chat(
max_tokens=100,
seed=0,
seed=1,
messages=[
{
"role": "system",
@@ -93,19 +93,17 @@ async def test_flash_llama_grammar_no_tools_regex(

assert (
response.choices[0].message.content
== 'As an up-to-date news station, our team has access to the latest information on weather conditions in Brooklyn, New York. Here is what we have learned so far:\n\n- Located in New York City, Brooklyn has a history of harsh weather patterns, especially in winter. The city\'s cold penchant makes it a popular winter destination, and meteorologists predict "bomb cyclone" conditions in the year 2021. - Due to'
== "As of today, there is a Update available for the Brooklyn, New York, area. According to the latest forecast, it's warm with high temperatures throughout the day. It's forecasted at 75°F for today and 77°F for tomorrow. However, in autumn, the weather typically changes drastically, becoming cooler and wetter. You can find the current weather forecast for the area through your local weather service. Additionally"
)
assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_regex(
flash_llama_grammar_tools, response_snapshot
):
async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot):
response = await flash_llama_grammar_tools.chat(
max_tokens=100,
seed=0,
seed=1,
tools=tools,
presence_penalty=-1.1,
messages=[
@@ -119,9 +117,39 @@ async def test_flash_llama_grammar_tools_regex(
},
],
)
assert len(response.choices[0].message.content) == 81
assert len(response.choices[0].message.content) == 78
assert (
response.choices[0].message.content
== """{"function":{"format": "celsius", "location": "Brooklyn, NYC", "num_days": 1255}}"""
== """{"function":{"format": "celsius", "location": "New York, NY", "num_days": 14}}"""
)
assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_choice(
flash_llama_grammar_tools, response_snapshot
):
response = await flash_llama_grammar_tools.chat(
max_tokens=100,
seed=1,
tools=tools,
tool_choice="get_current_weather",
presence_penalty=-1.1,
messages=[
{
"role": "system",
"content": "Youre a helpful assistant! Answer the users question best you can.",
},
{
"role": "user",
"content": "What is the weather like in Brooklyn, New York?",
},
],
)
assert len(response.choices[0].message.content) == 62
assert (
response.choices[0].message.content
== """{"function":{"format": "celsius", "location": "New York, NY"}}"""
)
assert response == response_snapshot

@@ -526,6 +526,11 @@ pub(crate) struct ChatRequest {
#[serde(default)]
#[schema(nullable = true, example = "null")]
pub tools: Option<Vec<Tool>>,

/// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter.
#[serde(default)]
#[schema(nullable = true, example = "null")]
pub tool_choice: Option<String>,
}

#[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default)]
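For illustration (not part of the diff): a sketch of the request body this struct accepts once `tool_choice` is wired in. The `/v1/chat/completions` path and the fields other than `tools`/`tool_choice` are assumptions based on TGI's OpenAI-compatible chat schema.

```python
# Hypothetical request sketch; the endpoint path and fields other than
# `tools`/`tool_choice` are assumptions about the OpenAI-compatible schema.
import requests

payload = {
    "model": "tgi",
    "messages": [{"role": "user", "content": "What is the weather like in Brooklyn, New York?"}],
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather",
                "parameters": {
                    "type": "object",
                    "properties": {"location": {"type": "string"}},
                    "required": ["location"],
                },
            },
        }
    ],
    # Omitting tool_choice keeps the previous behavior: any of the provided tools.
    "tool_choice": "get_current_weather",
}
resp = requests.post("http://127.0.0.1:8080/v1/chat/completions", json=payload, timeout=60)
print(resp.json()["choices"][0]["message"]["content"])
```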

@@ -536,10 +541,9 @@ pub struct Tools {
pub any_of: Vec<FunctionRef>,
}

// add traut to convert to serde_json::Value for tools
// Allows Tools to be converted to a valid JSON schema object
impl From<Tools> for serde_json::Value {
fn from(tools: Tools) -> Self {
println!("tools: {:?}", tools);
let mut map = serde_json::Map::new();
let mut functions = serde_json::Map::new();
for (name, value) in tools.function {

@@ -601,10 +601,31 @@ async fn chat_completions(

// if theres a tools object, we need to decompose it and use the function name as the key
// and the parameters as the value in the "$functions" object.
let grammar = if let Some(req_tools) = &req.tools {
let grammar = if let Some(ref req_tools) = &req.tools {
// get the tool_choice if there is one
let tool_choice = &req.tool_choice;
let tools_to_use = if let Some(tool_choice) = tool_choice {
// get the tool based on the tool_choice
let tool = req_tools
.iter()
.find(|tool| tool.function.name == *tool_choice)
.ok_or_else(|| {
(
StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse {
error: "Input validation error".to_string(),
error_type: "Input validation error".to_string(),
}),
)
})?;
vec![tool.clone()]
} else {
req_tools.clone()
};

let functions: HashMap<String, Value> = {
let mut tools = HashMap::new();
for tool in req_tools {
for tool in &tools_to_use {
let func = tool.function.clone();
let name = func.name;
let parameters = match func.parameters.as_object() {
@@ -627,7 +648,7 @@ async fn chat_completions(

let tools = Tools {
function: functions,
any_of: req_tools
any_of: tools_to_use
.iter()
.map(|tool| FunctionRef::new(&tool.function.name))
.collect(),
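To summarize the router change in one place, here is a rough Python sketch of the selection logic above (illustrative only; the helper names are hypothetical and the real implementation is the Rust in this hunk): narrow the tool list to the named `tool_choice` if one was sent, reject unknown names with a validation error, then key each function's parameters by name for the "$functions" object of the grammar.

```python
# Illustrative sketch of the tool-choice handling above; helper names are hypothetical.
from typing import Any


def select_tools(tools: list[dict[str, Any]], tool_choice: str | None) -> list[dict[str, Any]]:
    """Keep only the chosen tool, or all tools when no choice was given."""
    if tool_choice is None:
        return tools
    chosen = [t for t in tools if t["function"]["name"] == tool_choice]
    if not chosen:
        # Mirrors the 422 "Input validation error" returned by the router.
        raise ValueError("Input validation error")
    return chosen


def build_functions(tools_to_use: list[dict[str, Any]]) -> dict[str, Any]:
    """Map each function name to its parameters, as collected for "$functions"."""
    return {t["function"]["name"]: t["function"]["parameters"] for t in tools_to_use}


# With tool_choice="get_current_weather" only that schema remains, so the grammar
# can only generate arguments for that single function.
```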