mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 19:34:53 +00:00
fix: enable defs references in tool calls
This commit is contained in:
parent
fc2405c549
commit
71fbe88a30
@ -0,0 +1,36 @@
|
|||||||
|
{
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"finish_reason": "stop",
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null,
|
||||||
|
"message": {
|
||||||
|
"content": null,
|
||||||
|
"name": null,
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"function": {
|
||||||
|
"arguments": "{\"weather\":\"sunny\"}",
|
||||||
|
"description": null,
|
||||||
|
"name": "classify_weather"
|
||||||
|
},
|
||||||
|
"id": "0",
|
||||||
|
"type": "function"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"usage": null
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1751896977,
|
||||||
|
"id": "",
|
||||||
|
"model": "google/gemma-3-4b-it",
|
||||||
|
"object": "chat.completion",
|
||||||
|
"system_fingerprint": "3.3.4-dev0-native",
|
||||||
|
"usage": {
|
||||||
|
"completion_tokens": 20,
|
||||||
|
"prompt_tokens": 196,
|
||||||
|
"total_tokens": 216
|
||||||
|
}
|
||||||
|
}
|
51
integration-tests/models/test_tool_def.py
Normal file
51
integration-tests/models/test_tool_def.py
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def flash_gemma3_handle(launcher):
|
||||||
|
with launcher("google/gemma-3-4b-it", num_shard=2) as handle:
|
||||||
|
yield handle
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
async def flash_gemma3(flash_gemma3_handle):
|
||||||
|
await flash_gemma3_handle.health(300)
|
||||||
|
return flash_gemma3_handle.client
|
||||||
|
|
||||||
|
|
||||||
|
async def test_flash_gemma3_defs(flash_gemma3, response_snapshot):
|
||||||
|
response = await flash_gemma3.chat(
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"content": "Classify the weather: It's sunny outside with clear skies",
|
||||||
|
"role": "user",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
tools=[
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "classify_weather",
|
||||||
|
"description": "Classify weather conditions",
|
||||||
|
"parameters": {
|
||||||
|
"$defs": {
|
||||||
|
"WeatherType": {
|
||||||
|
"enum": ["sunny", "cloudy", "rainy", "snowy"],
|
||||||
|
"type": "string",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"properties": {"weather": {"$ref": "#/$defs/WeatherType"}},
|
||||||
|
"required": ["weather"],
|
||||||
|
"type": "object",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
tool_choice="auto",
|
||||||
|
max_tokens=100,
|
||||||
|
seed=42,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.choices[0].message.tool_calls[0]["function"]["name"] == "classify_weather"
|
||||||
|
assert response.choices[0].message.tool_calls[0]["function"]["arguments"] == '{"weather":"sunny"}'
|
||||||
|
assert response == response_snapshot
|
@ -72,6 +72,7 @@ impl ToolGrammar {
|
|||||||
Value::String(func.description.unwrap_or_default()),
|
Value::String(func.description.unwrap_or_default()),
|
||||||
);
|
);
|
||||||
|
|
||||||
|
let mut defs = Map::new();
|
||||||
let mut properties = Map::new();
|
let mut properties = Map::new();
|
||||||
let mut required = vec![Value::String("_name".to_string())];
|
let mut required = vec![Value::String("_name".to_string())];
|
||||||
|
|
||||||
@ -85,11 +86,35 @@ impl ToolGrammar {
|
|||||||
|
|
||||||
if let Value::Object(args) = func.arguments {
|
if let Value::Object(args) = func.arguments {
|
||||||
if let Some(Value::Object(props)) = args.get("properties") {
|
if let Some(Value::Object(props)) = args.get("properties") {
|
||||||
properties.extend(props.clone());
|
let mut updated_props = Map::new();
|
||||||
|
// Update $ref paths in properties by iterating through
|
||||||
|
for (key, value) in props.iter() {
|
||||||
|
let updated_value = match value {
|
||||||
|
Value::Object(obj) if obj.contains_key("$ref") => {
|
||||||
|
let mut new_obj = obj.clone();
|
||||||
|
if let Some(Value::String(ref_str)) = new_obj.get("$ref") {
|
||||||
|
if ref_str.starts_with("#/$defs/") {
|
||||||
|
// Replace $defs with $functions/{func.name}/$defs to handle
|
||||||
|
// function-specific definitions
|
||||||
|
new_obj.insert("$ref".to_string(), Value::String(
|
||||||
|
ref_str.replace("#/$defs/", &format!("#/$functions/{}/$defs/", func.name))
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Value::Object(new_obj)
|
||||||
|
}
|
||||||
|
_ => value.clone(),
|
||||||
|
};
|
||||||
|
updated_props.insert(key.clone(), updated_value);
|
||||||
|
}
|
||||||
|
properties.extend(updated_props);
|
||||||
}
|
}
|
||||||
if let Some(Value::Array(reqs)) = args.get("required") {
|
if let Some(Value::Array(reqs)) = args.get("required") {
|
||||||
required.extend(reqs.clone());
|
required.extend(reqs.clone());
|
||||||
}
|
}
|
||||||
|
if let Some(Value::Object(definitions)) = args.get("$defs") {
|
||||||
|
defs.extend(definitions.clone());
|
||||||
|
}
|
||||||
params.insert(
|
params.insert(
|
||||||
"additionalProperties".to_string(),
|
"additionalProperties".to_string(),
|
||||||
Value::Bool(
|
Value::Bool(
|
||||||
@ -101,6 +126,7 @@ impl ToolGrammar {
|
|||||||
|
|
||||||
params.insert("properties".to_string(), Value::Object(properties));
|
params.insert("properties".to_string(), Value::Object(properties));
|
||||||
params.insert("required".to_string(), Value::Array(required));
|
params.insert("required".to_string(), Value::Array(required));
|
||||||
|
params.insert("$defs".to_string(), Value::Object(defs));
|
||||||
|
|
||||||
(func.name, Value::Object(params))
|
(func.name, Value::Object(params))
|
||||||
})
|
})
|
||||||
|
Loading…
Reference in New Issue
Block a user