fix: enable defs references in tool calls

This commit is contained in:
drbh 2025-07-07 14:35:04 +00:00
parent fc2405c549
commit 71fbe88a30
3 changed files with 114 additions and 1 deletions

View File

@ -0,0 +1,36 @@
{
"choices": [
{
"finish_reason": "stop",
"index": 0,
"logprobs": null,
"message": {
"content": null,
"name": null,
"role": "assistant",
"tool_calls": [
{
"function": {
"arguments": "{\"weather\":\"sunny\"}",
"description": null,
"name": "classify_weather"
},
"id": "0",
"type": "function"
}
]
},
"usage": null
}
],
"created": 1751896977,
"id": "",
"model": "google/gemma-3-4b-it",
"object": "chat.completion",
"system_fingerprint": "3.3.4-dev0-native",
"usage": {
"completion_tokens": 20,
"prompt_tokens": 196,
"total_tokens": 216
}
}

View File

@ -0,0 +1,51 @@
import pytest
@pytest.fixture(scope="module")
def flash_gemma3_handle(launcher):
with launcher("google/gemma-3-4b-it", num_shard=2) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_gemma3(flash_gemma3_handle):
await flash_gemma3_handle.health(300)
return flash_gemma3_handle.client
async def test_flash_gemma3_defs(flash_gemma3, response_snapshot):
response = await flash_gemma3.chat(
messages=[
{
"content": "Classify the weather: It's sunny outside with clear skies",
"role": "user",
}
],
tools=[
{
"type": "function",
"function": {
"name": "classify_weather",
"description": "Classify weather conditions",
"parameters": {
"$defs": {
"WeatherType": {
"enum": ["sunny", "cloudy", "rainy", "snowy"],
"type": "string",
}
},
"properties": {"weather": {"$ref": "#/$defs/WeatherType"}},
"required": ["weather"],
"type": "object",
},
},
}
],
tool_choice="auto",
max_tokens=100,
seed=42,
)
assert response.choices[0].message.tool_calls[0]["function"]["name"] == "classify_weather"
assert response.choices[0].message.tool_calls[0]["function"]["arguments"] == '{"weather":"sunny"}'
assert response == response_snapshot

View File

@ -72,6 +72,7 @@ impl ToolGrammar {
Value::String(func.description.unwrap_or_default()),
);
let mut defs = Map::new();
let mut properties = Map::new();
let mut required = vec![Value::String("_name".to_string())];
@ -85,11 +86,35 @@ impl ToolGrammar {
if let Value::Object(args) = func.arguments {
if let Some(Value::Object(props)) = args.get("properties") {
properties.extend(props.clone());
let mut updated_props = Map::new();
// Update $ref paths in properties by iterating through
for (key, value) in props.iter() {
let updated_value = match value {
Value::Object(obj) if obj.contains_key("$ref") => {
let mut new_obj = obj.clone();
if let Some(Value::String(ref_str)) = new_obj.get("$ref") {
if ref_str.starts_with("#/$defs/") {
// Replace $defs with $functions/{func.name}/$defs to handle
// function-specific definitions
new_obj.insert("$ref".to_string(), Value::String(
ref_str.replace("#/$defs/", &format!("#/$functions/{}/$defs/", func.name))
));
}
}
Value::Object(new_obj)
}
_ => value.clone(),
};
updated_props.insert(key.clone(), updated_value);
}
properties.extend(updated_props);
}
if let Some(Value::Array(reqs)) = args.get("required") {
required.extend(reqs.clone());
}
if let Some(Value::Object(definitions)) = args.get("$defs") {
defs.extend(definitions.clone());
}
params.insert(
"additionalProperties".to_string(),
Value::Bool(
@ -101,6 +126,7 @@ impl ToolGrammar {
params.insert("properties".to_string(), Value::Object(properties));
params.insert("required".to_string(), Value::Array(required));
params.insert("$defs".to_string(), Value::Object(defs));
(func.name, Value::Object(params))
})