mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-08 19:04:52 +00:00
Merge b6540cea50
into 06d9d88b95
This commit is contained in:
commit
5cfe2bbae1
@ -0,0 +1,36 @@
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "stop",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"message": {
|
||||
"content": null,
|
||||
"name": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"arguments": "{\"weather\":\"sunny\"}",
|
||||
"description": null,
|
||||
"name": "classify_weather"
|
||||
},
|
||||
"id": "0",
|
||||
"type": "function"
|
||||
}
|
||||
]
|
||||
},
|
||||
"usage": null
|
||||
}
|
||||
],
|
||||
"created": 1751896977,
|
||||
"id": "",
|
||||
"model": "google/gemma-3-4b-it",
|
||||
"object": "chat.completion",
|
||||
"system_fingerprint": "3.3.4-dev0-native",
|
||||
"usage": {
|
||||
"completion_tokens": 20,
|
||||
"prompt_tokens": 196,
|
||||
"total_tokens": 216
|
||||
}
|
||||
}
|
57
integration-tests/models/test_tool_def.py
Normal file
57
integration-tests/models/test_tool_def.py
Normal file
@ -0,0 +1,57 @@
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def flash_gemma3_handle(launcher):
|
||||
with launcher("google/gemma-3-4b-it", num_shard=2) as handle:
|
||||
yield handle
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
async def flash_gemma3(flash_gemma3_handle):
|
||||
await flash_gemma3_handle.health(300)
|
||||
return flash_gemma3_handle.client
|
||||
|
||||
|
||||
async def test_flash_gemma3_defs(flash_gemma3, response_snapshot):
|
||||
response = await flash_gemma3.chat(
|
||||
messages=[
|
||||
{
|
||||
"content": "Classify the weather: It's sunny outside with clear skies",
|
||||
"role": "user",
|
||||
}
|
||||
],
|
||||
tools=[
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "classify_weather",
|
||||
"description": "Classify weather conditions",
|
||||
"parameters": {
|
||||
"$defs": {
|
||||
"WeatherType": {
|
||||
"enum": ["sunny", "cloudy", "rainy", "snowy"],
|
||||
"type": "string",
|
||||
}
|
||||
},
|
||||
"properties": {"weather": {"$ref": "#/$defs/WeatherType"}},
|
||||
"required": ["weather"],
|
||||
"type": "object",
|
||||
},
|
||||
},
|
||||
}
|
||||
],
|
||||
tool_choice="auto",
|
||||
max_tokens=100,
|
||||
seed=42,
|
||||
)
|
||||
|
||||
assert (
|
||||
response.choices[0].message.tool_calls[0]["function"]["name"]
|
||||
== "classify_weather"
|
||||
)
|
||||
assert (
|
||||
response.choices[0].message.tool_calls[0]["function"]["arguments"]
|
||||
== '{"weather":"sunny"}'
|
||||
)
|
||||
assert response == response_snapshot
|
@ -72,6 +72,7 @@ impl ToolGrammar {
|
||||
Value::String(func.description.unwrap_or_default()),
|
||||
);
|
||||
|
||||
let mut defs = Map::new();
|
||||
let mut properties = Map::new();
|
||||
let mut required = vec![Value::String("_name".to_string())];
|
||||
|
||||
@ -85,11 +86,39 @@ impl ToolGrammar {
|
||||
|
||||
if let Value::Object(args) = func.arguments {
|
||||
if let Some(Value::Object(props)) = args.get("properties") {
|
||||
properties.extend(props.clone());
|
||||
let mut updated_props = Map::new();
|
||||
// Update $ref paths in properties by iterating through
|
||||
for (key, value) in props.iter() {
|
||||
let updated_value = match value {
|
||||
Value::Object(obj) if obj.contains_key("$ref") => {
|
||||
let mut new_obj = obj.clone();
|
||||
if let Some(Value::String(ref_str)) = new_obj.get("$ref") {
|
||||
if ref_str.starts_with("#/$defs/") {
|
||||
// Replace $defs with $functions/{func.name}/$defs to handle
|
||||
// function-specific definitions
|
||||
new_obj.insert(
|
||||
"$ref".to_string(),
|
||||
Value::String(ref_str.replace(
|
||||
"#/$defs/",
|
||||
&format!("#/$functions/{}/$defs/", func.name),
|
||||
)),
|
||||
);
|
||||
}
|
||||
}
|
||||
Value::Object(new_obj)
|
||||
}
|
||||
_ => value.clone(),
|
||||
};
|
||||
updated_props.insert(key.clone(), updated_value);
|
||||
}
|
||||
properties.extend(updated_props);
|
||||
}
|
||||
if let Some(Value::Array(reqs)) = args.get("required") {
|
||||
required.extend(reqs.clone());
|
||||
}
|
||||
if let Some(Value::Object(definitions)) = args.get("$defs") {
|
||||
defs.extend(definitions.clone());
|
||||
}
|
||||
params.insert(
|
||||
"additionalProperties".to_string(),
|
||||
Value::Bool(
|
||||
@ -101,6 +130,7 @@ impl ToolGrammar {
|
||||
|
||||
params.insert("properties".to_string(), Value::Object(properties));
|
||||
params.insert("required".to_string(), Value::Array(required));
|
||||
params.insert("$defs".to_string(), Value::Object(defs));
|
||||
|
||||
(func.name, Value::Object(params))
|
||||
})
|
||||
|
Loading…
Reference in New Issue
Block a user