feat: improve client for tools and fix default choice

This commit is contained in:
drbh 2024-02-26 14:18:09 +00:00
parent af7ebc2639
commit 7a37655d8e
8 changed files with 160 additions and 29 deletions

View File

@ -472,7 +472,6 @@ class AsyncClient:
tools=tools, tools=tools,
tool_choice=tool_choice, tool_choice=tool_choice,
) )
print(self.base_url)
async with ClientSession( async with ClientSession(
headers=self.headers, cookies=self.cookies, timeout=self.timeout headers=self.headers, cookies=self.cookies, timeout=self.timeout
) as session: ) as session:

View File

@ -19,13 +19,24 @@ class Grammar(BaseModel):
value: Union[str, dict] value: Union[str, dict]
class ToolCall(BaseModel):
# Id of the tool call
id: int
# Type of the tool call
type: str
# Function details of the tool call
function: dict
class Message(BaseModel): class Message(BaseModel):
# Role of the message sender # Role of the message sender
role: str role: str
# Content of the message # Content of the message
content: str content: Optional[str]
# Optional name of the message sender # Optional name of the message sender
name: Optional[str] = None name: Optional[str] = None
# Tool calls associated with the chat completion
tool_calls: Optional[Any] = None
class Tool(BaseModel): class Tool(BaseModel):
@ -44,6 +55,8 @@ class ChatCompletionComplete(BaseModel):
logprobs: Optional[Any] logprobs: Optional[Any]
# Reason for completion # Reason for completion
finish_reason: str finish_reason: str
# Usage details of the chat completion
usage: Any
class ChatComplete(BaseModel): class ChatComplete(BaseModel):

View File

@ -7,11 +7,13 @@
"message": { "message": {
"content": "As of today, there is a Update available for the Brooklyn, New York, area. According to the latest forecast, it's warm with high temperatures throughout the day. It's forecasted at 75°F for today and 77°F for tomorrow. However, in autumn, the weather typically changes drastically, becoming cooler and wetter. You can find the current weather forecast for the area through your local weather service. Additionally", "content": "As of today, there is a Update available for the Brooklyn, New York, area. According to the latest forecast, it's warm with high temperatures throughout the day. It's forecasted at 75°F for today and 77°F for tomorrow. However, in autumn, the weather typically changes drastically, becoming cooler and wetter. You can find the current weather forecast for the area through your local weather service. Additionally",
"name": null, "name": null,
"role": "assistant" "role": "assistant",
} "tool_calls": null
},
"usage": null
} }
], ],
"created": 1708626137, "created": 1708957015,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",

View File

@ -5,20 +5,34 @@
"index": 0, "index": 0,
"logprobs": null, "logprobs": null,
"message": { "message": {
"content": "{\"function\":{\"format\": \"celsius\", \"location\": \"New York, NY\", \"num_days\": 14}}", "content": null,
"name": null, "name": null,
"role": "assistant" "role": "assistant",
} "tool_calls": {
"function": {
"description": null,
"name": "tools",
"parameters": {
"format": "celsius",
"location": "San Francisco",
"num_days": 2
}
},
"id": 0,
"type": "function"
}
},
"usage": null
} }
], ],
"created": 1708626137, "created": 1708957016,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "1.4.2-native", "system_fingerprint": "1.4.2-native",
"usage": { "usage": {
"completion_tokens": 29, "completion_tokens": 36,
"prompt_tokens": 318, "prompt_tokens": 313,
"total_tokens": 347 "total_tokens": 349
} }
} }

View File

@ -0,0 +1,38 @@
{
"choices": [
{
"finish_reason": "eos_token",
"index": 0,
"logprobs": null,
"message": {
"content": null,
"name": null,
"role": "assistant",
"tool_calls": {
"function": {
"description": null,
"name": "tools",
"parameters": {
"format": "celsius",
"location": "San Francisco",
"num_days": 2
}
},
"id": 0,
"type": "function"
}
},
"usage": null
}
],
"created": 1708957016,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "1.4.2-native",
"usage": {
"completion_tokens": 36,
"prompt_tokens": 313,
"total_tokens": 349
}
}

View File

@ -5,20 +5,33 @@
"index": 0, "index": 0,
"logprobs": null, "logprobs": null,
"message": { "message": {
"content": "{\"function\":{\"format\": \"celsius\", \"location\": \"New York, NY\"}}", "content": null,
"name": null, "name": null,
"role": "assistant" "role": "assistant",
} "tool_calls": {
"function": {
"description": null,
"name": "tools",
"parameters": {
"format": "celsius",
"location": "New York, NY"
}
},
"id": 0,
"type": "function"
}
},
"usage": null
} }
], ],
"created": 1708626030, "created": 1708957017,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "1.4.2-native", "system_fingerprint": "1.4.2-native",
"usage": { "usage": {
"completion_tokens": 21, "completion_tokens": 21,
"prompt_tokens": 189, "prompt_tokens": 184,
"total_tokens": 210 "total_tokens": 205
} }
} }

View File

@ -71,7 +71,6 @@ tools = [
] ]
@pytest.mark.skip
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_llama_grammar_no_tools( async def test_flash_llama_grammar_no_tools(
@ -99,7 +98,6 @@ async def test_flash_llama_grammar_no_tools(
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.skip
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot): async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot):
@ -119,11 +117,59 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna
}, },
], ],
) )
assert len(response.choices[0].message.content) == 78 assert response.choices[0].message.content == None
assert ( assert response.choices[0].message.tool_calls == {
response.choices[0].message.content "function": {
== """{"function":{"format": "celsius", "location": "New York, NY", "num_days": 14}}""" "description": None,
"name": "tools",
"parameters": {
"format": "celsius",
"location": "San Francisco",
"num_days": 2,
},
},
"id": 0,
"type": "function",
}
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_auto(
flash_llama_grammar_tools, response_snapshot
):
response = await flash_llama_grammar_tools.chat(
max_tokens=100,
seed=1,
tools=tools,
tool_choice="auto",
presence_penalty=-1.1,
messages=[
{
"role": "system",
"content": "Youre a helpful assistant! Answer the users question best you can.",
},
{
"role": "user",
"content": "What is the weather like in Brooklyn, New York?",
},
],
) )
assert response.choices[0].message.content == None
assert response.choices[0].message.tool_calls == {
"function": {
"description": None,
"name": "tools",
"parameters": {
"format": "celsius",
"location": "San Francisco",
"num_days": 2,
},
},
"id": 0,
"type": "function",
}
assert response == response_snapshot assert response == response_snapshot
@ -149,9 +195,14 @@ async def test_flash_llama_grammar_tools_choice(
}, },
], ],
) )
assert len(response.choices[0].message.content) == 62 assert response.choices[0].message.content == None
assert ( assert response.choices[0].message.tool_calls == {
response.choices[0].message.content "id": 0,
== """{"function":{"format": "celsius", "location": "New York, NY"}}""" "type": "function",
) "function": {
"description": None,
"name": "tools",
"parameters": {"format": "celsius", "location": "New York, NY"},
},
}
assert response == response_snapshot assert response == response_snapshot

View File

@ -581,6 +581,7 @@ mod deserialize_tool_choice {
Err(de::Error::custom("function key not found in tool choice")) Err(de::Error::custom("function key not found in tool choice"))
} }
} }
Value::Null => Ok(Some(ToolType::OneOf)),
_ => Err(de::Error::custom("invalid token format")), _ => Err(de::Error::custom("invalid token format")),
} }
} }