mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
feat: improve client for tools and fix default choice
This commit is contained in:
parent
af7ebc2639
commit
7a37655d8e
@ -472,7 +472,6 @@ class AsyncClient:
|
|||||||
tools=tools,
|
tools=tools,
|
||||||
tool_choice=tool_choice,
|
tool_choice=tool_choice,
|
||||||
)
|
)
|
||||||
print(self.base_url)
|
|
||||||
async with ClientSession(
|
async with ClientSession(
|
||||||
headers=self.headers, cookies=self.cookies, timeout=self.timeout
|
headers=self.headers, cookies=self.cookies, timeout=self.timeout
|
||||||
) as session:
|
) as session:
|
||||||
|
@ -19,13 +19,24 @@ class Grammar(BaseModel):
|
|||||||
value: Union[str, dict]
|
value: Union[str, dict]
|
||||||
|
|
||||||
|
|
||||||
|
class ToolCall(BaseModel):
|
||||||
|
# Id of the tool call
|
||||||
|
id: int
|
||||||
|
# Type of the tool call
|
||||||
|
type: str
|
||||||
|
# Function details of the tool call
|
||||||
|
function: dict
|
||||||
|
|
||||||
|
|
||||||
class Message(BaseModel):
|
class Message(BaseModel):
|
||||||
# Role of the message sender
|
# Role of the message sender
|
||||||
role: str
|
role: str
|
||||||
# Content of the message
|
# Content of the message
|
||||||
content: str
|
content: Optional[str]
|
||||||
# Optional name of the message sender
|
# Optional name of the message sender
|
||||||
name: Optional[str] = None
|
name: Optional[str] = None
|
||||||
|
# Tool calls associated with the chat completion
|
||||||
|
tool_calls: Optional[Any] = None
|
||||||
|
|
||||||
|
|
||||||
class Tool(BaseModel):
|
class Tool(BaseModel):
|
||||||
@ -44,6 +55,8 @@ class ChatCompletionComplete(BaseModel):
|
|||||||
logprobs: Optional[Any]
|
logprobs: Optional[Any]
|
||||||
# Reason for completion
|
# Reason for completion
|
||||||
finish_reason: str
|
finish_reason: str
|
||||||
|
# Usage details of the chat completion
|
||||||
|
usage: Any
|
||||||
|
|
||||||
|
|
||||||
class ChatComplete(BaseModel):
|
class ChatComplete(BaseModel):
|
||||||
|
@ -7,11 +7,13 @@
|
|||||||
"message": {
|
"message": {
|
||||||
"content": "As of today, there is a Update available for the Brooklyn, New York, area. According to the latest forecast, it's warm with high temperatures throughout the day. It's forecasted at 75°F for today and 77°F for tomorrow. However, in autumn, the weather typically changes drastically, becoming cooler and wetter. You can find the current weather forecast for the area through your local weather service. Additionally",
|
"content": "As of today, there is a Update available for the Brooklyn, New York, area. According to the latest forecast, it's warm with high temperatures throughout the day. It's forecasted at 75°F for today and 77°F for tomorrow. However, in autumn, the weather typically changes drastically, becoming cooler and wetter. You can find the current weather forecast for the area through your local weather service. Additionally",
|
||||||
"name": null,
|
"name": null,
|
||||||
"role": "assistant"
|
"role": "assistant",
|
||||||
}
|
"tool_calls": null
|
||||||
|
},
|
||||||
|
"usage": null
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"created": 1708626137,
|
"created": 1708957015,
|
||||||
"id": "",
|
"id": "",
|
||||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||||
"object": "text_completion",
|
"object": "text_completion",
|
||||||
|
@ -5,20 +5,34 @@
|
|||||||
"index": 0,
|
"index": 0,
|
||||||
"logprobs": null,
|
"logprobs": null,
|
||||||
"message": {
|
"message": {
|
||||||
"content": "{\"function\":{\"format\": \"celsius\", \"location\": \"New York, NY\", \"num_days\": 14}}",
|
"content": null,
|
||||||
"name": null,
|
"name": null,
|
||||||
"role": "assistant"
|
"role": "assistant",
|
||||||
|
"tool_calls": {
|
||||||
|
"function": {
|
||||||
|
"description": null,
|
||||||
|
"name": "tools",
|
||||||
|
"parameters": {
|
||||||
|
"format": "celsius",
|
||||||
|
"location": "San Francisco",
|
||||||
|
"num_days": 2
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
"id": 0,
|
||||||
|
"type": "function"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"usage": null
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"created": 1708626137,
|
"created": 1708957016,
|
||||||
"id": "",
|
"id": "",
|
||||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||||
"object": "text_completion",
|
"object": "text_completion",
|
||||||
"system_fingerprint": "1.4.2-native",
|
"system_fingerprint": "1.4.2-native",
|
||||||
"usage": {
|
"usage": {
|
||||||
"completion_tokens": 29,
|
"completion_tokens": 36,
|
||||||
"prompt_tokens": 318,
|
"prompt_tokens": 313,
|
||||||
"total_tokens": 347
|
"total_tokens": 349
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,38 @@
|
|||||||
|
{
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"finish_reason": "eos_token",
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null,
|
||||||
|
"message": {
|
||||||
|
"content": null,
|
||||||
|
"name": null,
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": {
|
||||||
|
"function": {
|
||||||
|
"description": null,
|
||||||
|
"name": "tools",
|
||||||
|
"parameters": {
|
||||||
|
"format": "celsius",
|
||||||
|
"location": "San Francisco",
|
||||||
|
"num_days": 2
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": 0,
|
||||||
|
"type": "function"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"usage": null
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1708957016,
|
||||||
|
"id": "",
|
||||||
|
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||||
|
"object": "text_completion",
|
||||||
|
"system_fingerprint": "1.4.2-native",
|
||||||
|
"usage": {
|
||||||
|
"completion_tokens": 36,
|
||||||
|
"prompt_tokens": 313,
|
||||||
|
"total_tokens": 349
|
||||||
|
}
|
||||||
|
}
|
@ -5,20 +5,33 @@
|
|||||||
"index": 0,
|
"index": 0,
|
||||||
"logprobs": null,
|
"logprobs": null,
|
||||||
"message": {
|
"message": {
|
||||||
"content": "{\"function\":{\"format\": \"celsius\", \"location\": \"New York, NY\"}}",
|
"content": null,
|
||||||
"name": null,
|
"name": null,
|
||||||
"role": "assistant"
|
"role": "assistant",
|
||||||
|
"tool_calls": {
|
||||||
|
"function": {
|
||||||
|
"description": null,
|
||||||
|
"name": "tools",
|
||||||
|
"parameters": {
|
||||||
|
"format": "celsius",
|
||||||
|
"location": "New York, NY"
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
"id": 0,
|
||||||
|
"type": "function"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"usage": null
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"created": 1708626030,
|
"created": 1708957017,
|
||||||
"id": "",
|
"id": "",
|
||||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||||
"object": "text_completion",
|
"object": "text_completion",
|
||||||
"system_fingerprint": "1.4.2-native",
|
"system_fingerprint": "1.4.2-native",
|
||||||
"usage": {
|
"usage": {
|
||||||
"completion_tokens": 21,
|
"completion_tokens": 21,
|
||||||
"prompt_tokens": 189,
|
"prompt_tokens": 184,
|
||||||
"total_tokens": 210
|
"total_tokens": 205
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -71,7 +71,6 @@ tools = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
@pytest.mark.private
|
||||||
async def test_flash_llama_grammar_no_tools(
|
async def test_flash_llama_grammar_no_tools(
|
||||||
@ -99,7 +98,6 @@ async def test_flash_llama_grammar_no_tools(
|
|||||||
assert response == response_snapshot
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
@pytest.mark.private
|
||||||
async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot):
|
async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot):
|
||||||
@ -119,11 +117,59 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna
|
|||||||
},
|
},
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
assert len(response.choices[0].message.content) == 78
|
assert response.choices[0].message.content == None
|
||||||
assert (
|
assert response.choices[0].message.tool_calls == {
|
||||||
response.choices[0].message.content
|
"function": {
|
||||||
== """{"function":{"format": "celsius", "location": "New York, NY", "num_days": 14}}"""
|
"description": None,
|
||||||
|
"name": "tools",
|
||||||
|
"parameters": {
|
||||||
|
"format": "celsius",
|
||||||
|
"location": "San Francisco",
|
||||||
|
"num_days": 2,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"id": 0,
|
||||||
|
"type": "function",
|
||||||
|
}
|
||||||
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.private
|
||||||
|
async def test_flash_llama_grammar_tools_auto(
|
||||||
|
flash_llama_grammar_tools, response_snapshot
|
||||||
|
):
|
||||||
|
response = await flash_llama_grammar_tools.chat(
|
||||||
|
max_tokens=100,
|
||||||
|
seed=1,
|
||||||
|
tools=tools,
|
||||||
|
tool_choice="auto",
|
||||||
|
presence_penalty=-1.1,
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "Youre a helpful assistant! Answer the users question best you can.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "What is the weather like in Brooklyn, New York?",
|
||||||
|
},
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
assert response.choices[0].message.content == None
|
||||||
|
assert response.choices[0].message.tool_calls == {
|
||||||
|
"function": {
|
||||||
|
"description": None,
|
||||||
|
"name": "tools",
|
||||||
|
"parameters": {
|
||||||
|
"format": "celsius",
|
||||||
|
"location": "San Francisco",
|
||||||
|
"num_days": 2,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"id": 0,
|
||||||
|
"type": "function",
|
||||||
|
}
|
||||||
assert response == response_snapshot
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
@ -149,9 +195,14 @@ async def test_flash_llama_grammar_tools_choice(
|
|||||||
},
|
},
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
assert len(response.choices[0].message.content) == 62
|
assert response.choices[0].message.content == None
|
||||||
assert (
|
assert response.choices[0].message.tool_calls == {
|
||||||
response.choices[0].message.content
|
"id": 0,
|
||||||
== """{"function":{"format": "celsius", "location": "New York, NY"}}"""
|
"type": "function",
|
||||||
)
|
"function": {
|
||||||
|
"description": None,
|
||||||
|
"name": "tools",
|
||||||
|
"parameters": {"format": "celsius", "location": "New York, NY"},
|
||||||
|
},
|
||||||
|
}
|
||||||
assert response == response_snapshot
|
assert response == response_snapshot
|
||||||
|
@ -581,6 +581,7 @@ mod deserialize_tool_choice {
|
|||||||
Err(de::Error::custom("function key not found in tool choice"))
|
Err(de::Error::custom("function key not found in tool choice"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Value::Null => Ok(Some(ToolType::OneOf)),
|
||||||
_ => Err(de::Error::custom("invalid token format")),
|
_ => Err(de::Error::custom("invalid token format")),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user