mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
feat: respect tool choice
This commit is contained in:
parent
3ec57acac1
commit
0e30e65822
@ -78,6 +78,7 @@ class Client:
|
|||||||
temperature: Optional[float] = None,
|
temperature: Optional[float] = None,
|
||||||
top_p: Optional[float] = None,
|
top_p: Optional[float] = None,
|
||||||
tools: Optional[List[Tool]] = None,
|
tools: Optional[List[Tool]] = None,
|
||||||
|
tool_choice: Optional[str] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Given a list of messages, generate a response
|
Given a list of messages, generate a response
|
||||||
@ -112,6 +113,8 @@ class Client:
|
|||||||
higher are kept for generation
|
higher are kept for generation
|
||||||
tools (`List[Tool]`):
|
tools (`List[Tool]`):
|
||||||
List of tools to use
|
List of tools to use
|
||||||
|
tool_choice (`str`):
|
||||||
|
The tool to use
|
||||||
|
|
||||||
"""
|
"""
|
||||||
request = ChatRequest(
|
request = ChatRequest(
|
||||||
@ -129,6 +132,7 @@ class Client:
|
|||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
|
tool_choice=tool_choice,
|
||||||
)
|
)
|
||||||
|
|
||||||
resp = requests.post(
|
resp = requests.post(
|
||||||
@ -412,6 +416,7 @@ class AsyncClient:
|
|||||||
temperature: Optional[float] = None,
|
temperature: Optional[float] = None,
|
||||||
top_p: Optional[float] = None,
|
top_p: Optional[float] = None,
|
||||||
tools: Optional[List[Tool]] = None,
|
tools: Optional[List[Tool]] = None,
|
||||||
|
tool_choice: Optional[str] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Given a list of messages, generate a response asynchronously
|
Given a list of messages, generate a response asynchronously
|
||||||
@ -446,6 +451,8 @@ class AsyncClient:
|
|||||||
higher are kept for generation
|
higher are kept for generation
|
||||||
tools (`List[Tool]`):
|
tools (`List[Tool]`):
|
||||||
List of tools to use
|
List of tools to use
|
||||||
|
tool_choice (`str`):
|
||||||
|
The tool to use
|
||||||
|
|
||||||
"""
|
"""
|
||||||
request = ChatRequest(
|
request = ChatRequest(
|
||||||
@ -463,6 +470,7 @@ class AsyncClient:
|
|||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
|
tool_choice=tool_choice,
|
||||||
)
|
)
|
||||||
print(self.base_url)
|
print(self.base_url)
|
||||||
async with ClientSession(
|
async with ClientSession(
|
||||||
|
@ -86,6 +86,8 @@ class ChatRequest(BaseModel):
|
|||||||
top_p: Optional[float] = None
|
top_p: Optional[float] = None
|
||||||
# List of tools to be used
|
# List of tools to be used
|
||||||
tools: Optional[List[Tool]] = None
|
tools: Optional[List[Tool]] = None
|
||||||
|
# Choice of tool to be used
|
||||||
|
tool_choice: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class Parameters(BaseModel):
|
class Parameters(BaseModel):
|
||||||
|
@ -5,13 +5,13 @@
|
|||||||
"index": 0,
|
"index": 0,
|
||||||
"logprobs": null,
|
"logprobs": null,
|
||||||
"message": {
|
"message": {
|
||||||
"content": "As an up-to-date news station, our team has access to the latest information on weather conditions in Brooklyn, New York. Here is what we have learned so far:\n\n- Located in New York City, Brooklyn has a history of harsh weather patterns, especially in winter. The city's cold penchant makes it a popular winter destination, and meteorologists predict \"bomb cyclone\" conditions in the year 2021. - Due to",
|
"content": "As of today, there is a Update available for the Brooklyn, New York, area. According to the latest forecast, it's warm with high temperatures throughout the day. It's forecasted at 75°F for today and 77°F for tomorrow. However, in autumn, the weather typically changes drastically, becoming cooler and wetter. You can find the current weather forecast for the area through your local weather service. Additionally",
|
||||||
"name": null,
|
"name": null,
|
||||||
"role": "assistant"
|
"role": "assistant"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"created": 1708623190,
|
"created": 1708626137,
|
||||||
"id": "",
|
"id": "",
|
||||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||||
"object": "text_completion",
|
"object": "text_completion",
|
||||||
|
@ -5,20 +5,20 @@
|
|||||||
"index": 0,
|
"index": 0,
|
||||||
"logprobs": null,
|
"logprobs": null,
|
||||||
"message": {
|
"message": {
|
||||||
"content": "{\"function\":{\"format\": \"celsius\", \"location\": \"Brooklyn, NYC\", \"num_days\": 1255}}",
|
"content": "{\"function\":{\"format\": \"celsius\", \"location\": \"New York, NY\", \"num_days\": 14}}",
|
||||||
"name": null,
|
"name": null,
|
||||||
"role": "assistant"
|
"role": "assistant"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"created": 1708623212,
|
"created": 1708626137,
|
||||||
"id": "",
|
"id": "",
|
||||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||||
"object": "text_completion",
|
"object": "text_completion",
|
||||||
"system_fingerprint": "1.4.2-native",
|
"system_fingerprint": "1.4.2-native",
|
||||||
"usage": {
|
"usage": {
|
||||||
"completion_tokens": 33,
|
"completion_tokens": 29,
|
||||||
"prompt_tokens": 318,
|
"prompt_tokens": 318,
|
||||||
"total_tokens": 351
|
"total_tokens": 347
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,24 @@
|
|||||||
|
{
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"finish_reason": "eos_token",
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null,
|
||||||
|
"message": {
|
||||||
|
"content": "{\"function\":{\"format\": \"celsius\", \"location\": \"New York, NY\"}}",
|
||||||
|
"name": null,
|
||||||
|
"role": "assistant"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1708626030,
|
||||||
|
"id": "",
|
||||||
|
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||||
|
"object": "text_completion",
|
||||||
|
"system_fingerprint": "1.4.2-native",
|
||||||
|
"usage": {
|
||||||
|
"completion_tokens": 21,
|
||||||
|
"prompt_tokens": 189,
|
||||||
|
"total_tokens": 210
|
||||||
|
}
|
||||||
|
}
|
@ -73,12 +73,12 @@ tools = [
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
@pytest.mark.private
|
||||||
async def test_flash_llama_grammar_no_tools_regex(
|
async def test_flash_llama_grammar_no_tools(
|
||||||
flash_llama_grammar_tools, response_snapshot
|
flash_llama_grammar_tools, response_snapshot
|
||||||
):
|
):
|
||||||
response = await flash_llama_grammar_tools.chat(
|
response = await flash_llama_grammar_tools.chat(
|
||||||
max_tokens=100,
|
max_tokens=100,
|
||||||
seed=0,
|
seed=1,
|
||||||
messages=[
|
messages=[
|
||||||
{
|
{
|
||||||
"role": "system",
|
"role": "system",
|
||||||
@ -93,19 +93,17 @@ async def test_flash_llama_grammar_no_tools_regex(
|
|||||||
|
|
||||||
assert (
|
assert (
|
||||||
response.choices[0].message.content
|
response.choices[0].message.content
|
||||||
== 'As an up-to-date news station, our team has access to the latest information on weather conditions in Brooklyn, New York. Here is what we have learned so far:\n\n- Located in New York City, Brooklyn has a history of harsh weather patterns, especially in winter. The city\'s cold penchant makes it a popular winter destination, and meteorologists predict "bomb cyclone" conditions in the year 2021. - Due to'
|
== "As of today, there is a Update available for the Brooklyn, New York, area. According to the latest forecast, it's warm with high temperatures throughout the day. It's forecasted at 75°F for today and 77°F for tomorrow. However, in autumn, the weather typically changes drastically, becoming cooler and wetter. You can find the current weather forecast for the area through your local weather service. Additionally"
|
||||||
)
|
)
|
||||||
assert response == response_snapshot
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
@pytest.mark.private
|
||||||
async def test_flash_llama_grammar_tools_regex(
|
async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot):
|
||||||
flash_llama_grammar_tools, response_snapshot
|
|
||||||
):
|
|
||||||
response = await flash_llama_grammar_tools.chat(
|
response = await flash_llama_grammar_tools.chat(
|
||||||
max_tokens=100,
|
max_tokens=100,
|
||||||
seed=0,
|
seed=1,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
presence_penalty=-1.1,
|
presence_penalty=-1.1,
|
||||||
messages=[
|
messages=[
|
||||||
@ -119,9 +117,39 @@ async def test_flash_llama_grammar_tools_regex(
|
|||||||
},
|
},
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
assert len(response.choices[0].message.content) == 81
|
assert len(response.choices[0].message.content) == 78
|
||||||
assert (
|
assert (
|
||||||
response.choices[0].message.content
|
response.choices[0].message.content
|
||||||
== """{"function":{"format": "celsius", "location": "Brooklyn, NYC", "num_days": 1255}}"""
|
== """{"function":{"format": "celsius", "location": "New York, NY", "num_days": 14}}"""
|
||||||
|
)
|
||||||
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.private
|
||||||
|
async def test_flash_llama_grammar_tools_choice(
|
||||||
|
flash_llama_grammar_tools, response_snapshot
|
||||||
|
):
|
||||||
|
response = await flash_llama_grammar_tools.chat(
|
||||||
|
max_tokens=100,
|
||||||
|
seed=1,
|
||||||
|
tools=tools,
|
||||||
|
tool_choice="get_current_weather",
|
||||||
|
presence_penalty=-1.1,
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "Youre a helpful assistant! Answer the users question best you can.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "What is the weather like in Brooklyn, New York?",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
assert len(response.choices[0].message.content) == 62
|
||||||
|
assert (
|
||||||
|
response.choices[0].message.content
|
||||||
|
== """{"function":{"format": "celsius", "location": "New York, NY"}}"""
|
||||||
)
|
)
|
||||||
assert response == response_snapshot
|
assert response == response_snapshot
|
||||||
|
@ -526,6 +526,11 @@ pub(crate) struct ChatRequest {
|
|||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
#[schema(nullable = true, example = "null")]
|
#[schema(nullable = true, example = "null")]
|
||||||
pub tools: Option<Vec<Tool>>,
|
pub tools: Option<Vec<Tool>>,
|
||||||
|
|
||||||
|
/// The name of a specific tool to use. If not provided, the model may use any of the tools provided in the `tools` parameter.
|
||||||
|
#[serde(default)]
|
||||||
|
#[schema(nullable = true, example = "null")]
|
||||||
|
pub tool_choice: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default)]
|
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default)]
|
||||||
@ -536,10 +541,9 @@ pub struct Tools {
|
|||||||
pub any_of: Vec<FunctionRef>,
|
pub any_of: Vec<FunctionRef>,
|
||||||
}
|
}
|
||||||
|
|
||||||
// add traut to convert to serde_json::Value for tools
|
// Allows Tools to be converted to a valid JSON schema object
|
||||||
impl From<Tools> for serde_json::Value {
|
impl From<Tools> for serde_json::Value {
|
||||||
fn from(tools: Tools) -> Self {
|
fn from(tools: Tools) -> Self {
|
||||||
println!("tools: {:?}", tools);
|
|
||||||
let mut map = serde_json::Map::new();
|
let mut map = serde_json::Map::new();
|
||||||
let mut functions = serde_json::Map::new();
|
let mut functions = serde_json::Map::new();
|
||||||
for (name, value) in tools.function {
|
for (name, value) in tools.function {
|
||||||
|
@ -601,10 +601,31 @@ async fn chat_completions(
|
|||||||
|
|
||||||
// if there's a tools object, we need to decompose it and use the function name as the key
|
// if there's a tools object, we need to decompose it and use the function name as the key
|
||||||
// and the parameters as the value in the "$functions" object.
|
// and the parameters as the value in the "$functions" object.
|
||||||
let grammar = if let Some(req_tools) = &req.tools {
|
let grammar = if let Some(ref req_tools) = &req.tools {
|
||||||
|
// get the tool_choice if there is one
|
||||||
|
let tool_choice = &req.tool_choice;
|
||||||
|
let tools_to_use = if let Some(tool_choice) = tool_choice {
|
||||||
|
// get the tool based on the tool_choice
|
||||||
|
let tool = req_tools
|
||||||
|
.iter()
|
||||||
|
.find(|tool| tool.function.name == *tool_choice)
|
||||||
|
.ok_or_else(|| {
|
||||||
|
(
|
||||||
|
StatusCode::UNPROCESSABLE_ENTITY,
|
||||||
|
Json(ErrorResponse {
|
||||||
|
error: "Input validation error".to_string(),
|
||||||
|
error_type: "Input validation error".to_string(),
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
vec![tool.clone()]
|
||||||
|
} else {
|
||||||
|
req_tools.clone()
|
||||||
|
};
|
||||||
|
|
||||||
let functions: HashMap<String, Value> = {
|
let functions: HashMap<String, Value> = {
|
||||||
let mut tools = HashMap::new();
|
let mut tools = HashMap::new();
|
||||||
for tool in req_tools {
|
for tool in &tools_to_use {
|
||||||
let func = tool.function.clone();
|
let func = tool.function.clone();
|
||||||
let name = func.name;
|
let name = func.name;
|
||||||
let parameters = match func.parameters.as_object() {
|
let parameters = match func.parameters.as_object() {
|
||||||
@ -627,7 +648,7 @@ async fn chat_completions(
|
|||||||
|
|
||||||
let tools = Tools {
|
let tools = Tools {
|
||||||
function: functions,
|
function: functions,
|
||||||
any_of: req_tools
|
any_of: tools_to_use
|
||||||
.iter()
|
.iter()
|
||||||
.map(|tool| FunctionRef::new(&tool.function.name))
|
.map(|tool| FunctionRef::new(&tool.function.name))
|
||||||
.collect(),
|
.collect(),
|
||||||
|
Loading…
Reference in New Issue
Block a user