feat: minimal tool support and chat client

drbh committed 2024-02-16 17:18:21 +00:00
parent 0f500f6d14
commit c8f2081171
8 changed files with 355 additions and 3 deletions

View File

@@ -11,6 +11,10 @@ from text_generation.types import (
Request,
Parameters,
Grammar,
ChatRequest,
ChatComplete,
Message,
Tool,
)
from text_generation.errors import parse_error
@@ -59,6 +63,52 @@ class Client:
self.cookies = cookies
self.timeout = timeout
def chat(
self,
messages: List[Message],
frequency_penalty: Optional[float] = None,
logit_bias: Optional[List[float]] = None,
logprobs: Optional[bool] = None,
top_logprobs: Optional[int] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
presence_penalty: Optional[float] = None,
stream: bool = False,
seed: Optional[int] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
tools: Optional[List[Tool]] = None,
):
""" """
request = ChatRequest(
model="tgi",
messages=messages,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
top_logprobs=top_logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
stream=stream,
seed=seed,
temperature=temperature,
top_p=top_p,
tools=tools,
)
resp = requests.post(
f"{self.base_url}/v1/chat/completions",
json=request.dict(),
headers=self.headers,
cookies=self.cookies,
timeout=self.timeout,
)
payload = resp.json()
if resp.status_code != 200:
raise parse_error(resp.status_code, payload)
return ChatComplete(**payload)
def generate(
self,
prompt: str,
@@ -313,6 +363,52 @@ class AsyncClient:
self.cookies = cookies
self.timeout = ClientTimeout(timeout * 60)
async def chat(
self,
messages: List[Message],
frequency_penalty: Optional[float] = None,
logit_bias: Optional[List[float]] = None,
logprobs: Optional[bool] = None,
top_logprobs: Optional[int] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
presence_penalty: Optional[float] = None,
stream: bool = False,
seed: Optional[int] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
tools: Optional[List[Tool]] = None,
):
""" """
print("chat")
request = ChatRequest(
model="tgi",
messages=messages,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
top_logprobs=top_logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
stream=stream,
seed=seed,
temperature=temperature,
top_p=top_p,
tools=tools,
)
async with ClientSession(
headers=self.headers, cookies=self.cookies, timeout=self.timeout
) as session:
async with session.post(
f"{self.base_url}/v1/chat/completions", json=request.dict()
) as resp:
payload = await resp.json()
if resp.status != 200:
raise parse_error(resp.status, payload)
return ChatComplete(**payload)
async def generate(
self,
prompt: str,
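For reference, a minimal usage sketch of the new chat API (hedged: it assumes a TGI server is already listening at http://localhost:8080, and that Client is importable from the package root):

# Hypothetical usage sketch for the new chat() method.
from text_generation import Client
from text_generation.types import Message

client = Client("http://localhost:8080")  # assumed local TGI server
response = client.chat(
    messages=[Message(role="user", content="What is the weather like in Brooklyn?")],
    max_tokens=100,
    seed=0,
)
print(response.choices[0].message.content)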

View File

@@ -1,6 +1,6 @@
from enum import Enum
from pydantic import BaseModel, validator
-from typing import Optional, List, Union
+from typing import Optional, List, Union, Any
from text_generation.errors import ValidationError
@@ -19,6 +19,75 @@ class Grammar(BaseModel):
value: Union[str, dict]
class Message(BaseModel):
# Role of the message sender
role: str
# Content of the message
content: str
# Optional name of the message sender
name: Optional[str] = None
class Tool(BaseModel):
# Type of the tool
type: str
# Function details of the tool
function: dict
class ChatCompletionComplete(BaseModel):
# Index of the chat completion
index: int
# Message associated with the chat completion
message: Message
# Log probabilities for the chat completion
logprobs: Optional[Any]
# Reason for completion
finish_reason: str
class ChatComplete(BaseModel):
# Chat completion details
id: str
object: str
created: int
model: str
system_fingerprint: str
choices: List[ChatCompletionComplete]
usage: Any
class ChatRequest(BaseModel):
# Model identifier
model: str
# List of messages in the conversation
messages: List[Message]
# Penalty applied to new tokens based on their frequency in the text so far
frequency_penalty: Optional[float] = None
# Bias values for token selection
logit_bias: Optional[List[float]] = None
# Whether to return log probabilities
logprobs: Optional[bool] = None
# Number of most likely tokens to return at each position
top_logprobs: Optional[int] = None
# Maximum number of tokens to generate
max_tokens: Optional[int] = None
# Number of chat completion choices to generate
n: Optional[int] = None
# Penalty applied to new tokens based on whether they already appear in the text
presence_penalty: Optional[float] = None
# Flag to indicate streaming response
stream: bool = False
# Random sampling seed
seed: Optional[int] = None
# Sampling temperature
temperature: Optional[float] = None
# Top-p value for nucleus sampling
top_p: Optional[float] = None
# List of tools to be used
tools: Optional[List[Tool]] = None
class Parameters(BaseModel):
# Activate logits sampling
do_sample: bool = False
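Taken together, these models mirror the OpenAI-style chat schema. A small sketch of how they compose into the request body the client posts (field values are illustrative, and the function schema is a placeholder):

from text_generation.types import ChatRequest, Message, Tool

request = ChatRequest(
    model="tgi",
    messages=[Message(role="user", content="What is the weather like in Brooklyn?")],
    tools=[
        Tool(
            type="function",
            # Placeholder function schema; real schemas follow JSON Schema,
            # as in the integration test below.
            function={"name": "get_current_weather", "parameters": {"type": "object"}},
        )
    ],
    max_tokens=100,
)
payload = request.dict()  # serialized body for POST /v1/chat/completions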

View File

@@ -23,6 +23,7 @@ from text_generation.types import (
Token,
BestOfSequence,
Grammar,
ChatComplete,
)
DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
@@ -59,7 +60,8 @@ class ResponseComparator(JSONSnapshotExtension):
) -> bool:
def convert_data(data):
data = json.loads(data)
if isinstance(data, Dict) and "choices" in data:
return ChatComplete(**data)
if isinstance(data, Dict):
return Response(**data)
if isinstance(data, List):
@@ -144,6 +146,11 @@ class ResponseComparator(JSONSnapshotExtension):
)
)
def eq_chat_complete(response: ChatComplete, other: ChatComplete) -> bool:
return (
response.choices[0].message.content == other.choices[0].message.content
)
def eq_response(response: Response, other: Response) -> bool:
return response.generated_text == other.generated_text and eq_details(
response.details, other.details
@@ -157,6 +164,11 @@ class ResponseComparator(JSONSnapshotExtension):
if not isinstance(snapshot_data, List):
snapshot_data = [snapshot_data]
if isinstance(serialized_data[0], ChatComplete):
return len(snapshot_data) == len(serialized_data) and all(
[eq_chat_complete(r, o) for r, o in zip(serialized_data, snapshot_data)]
)
return len(snapshot_data) == len(serialized_data) and all(
[eq_response(r, o) for r, o in zip(serialized_data, snapshot_data)]
)
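In effect, the comparator dispatches on the top-level "choices" key: any dict carrying it is hydrated as a ChatComplete, everything else falls through to Response. A standalone sketch of that dispatch (payload abbreviated from the snapshots below):

import json

raw = '{"id": "", "object": "text_completion", "created": 0, "model": "tgi", "system_fingerprint": "", "choices": [], "usage": null}'
data = json.loads(raw)
is_chat = isinstance(data, dict) and "choices" in data  # True -> ChatComplete(**data)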

View File

@@ -0,0 +1,24 @@
{
"choices": [
{
"finish_reason": "length",
"index": 0,
"logprobs": null,
"message": {
"content": "As for the weather in Brooklyn, New York, it can vary depending on the location within the borough. According to climatereporter.com, the average temperature in August is 73 degrees Fahrenheit (23 degrees Celsius), while the humidity is 62%. In the winter (December to February), the temperature averages between 20 and 45 degrees Fahrenheit (6 to 8 degrees Celsius), with significant",
"name": null,
"role": "assistant"
}
}
],
"created": 1708103426,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "1.4.0-native",
"usage": {
"completion_tokens": 100,
"prompt_tokens": 82,
"total_tokens": 182
}
}

View File

@@ -0,0 +1,24 @@
{
"choices": [
{
"finish_reason": "eos_token",
"index": 0,
"logprobs": null,
"message": {
"content": "{\"function\":{\"format\": \"celsius\", \"location\": \"Brooklyn, NYC\", \"num_days\": 1255}}",
"name": null,
"role": "assistant"
}
}
],
"created": 1708103426,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "1.4.0-native",
"usage": {
"completion_tokens": 33,
"prompt_tokens": 321,
"total_tokens": 354
}
}
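Note that in this snapshot the tool call comes back as JSON text inside message.content rather than as a structured tool_calls field, so a caller would decode it manually. A sketch using the content above:

import json

content = '{"function":{"format": "celsius", "location": "Brooklyn, NYC", "num_days": 1255}}'
call = json.loads(content)["function"]
print(call["location"], call["format"], call["num_days"])  # Brooklyn, NYC celsius 1255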

View File

@@ -0,0 +1,126 @@
import pytest
import json
from text_generation.types import GrammarType
@pytest.fixture(scope="module")
def flash_llama_grammar_tools_handle(launcher):
with launcher(
"TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_shard=2, disable_grammar_support=False
) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_llama_grammar_tools(flash_llama_grammar_tools_handle):
await flash_llama_grammar_tools_handle.health(300)
return flash_llama_grammar_tools_handle.client
# tools to be used in the following tests
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"format": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "The temperature unit to use. Infer this from the users location.",
},
},
"required": ["location", "format"],
},
},
},
{
"type": "function",
"function": {
"name": "get_n_day_weather_forecast",
"description": "Get an N-day weather forecast",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"format": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "The temperature unit to use. Infer this from the users location.",
},
"num_days": {
"type": "integer",
"description": "The number of days to forecast",
},
},
"required": ["location", "format", "num_days"],
},
},
},
]
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_no_tools_regex(
flash_llama_grammar_tools, response_snapshot
):
response = await flash_llama_grammar_tools.chat(
max_tokens=100,
seed=0,
messages=[
{
"role": "system",
"content": "Youre a helpful assistant! Answer the users question best you can.",
},
{
"role": "user",
"content": "What is the weather like in Brooklyn, New York?",
},
],
)
assert (
response.choices[0].message.content
== "As for the weather in Brooklyn, New York, it can vary depending on the location within the borough. According to climatereporter.com, the average temperature in August is 73 degrees Fahrenheit (23 degrees Celsius), while the humidity is 62%. In the winter (December to February), the temperature averages between 20 and 45 degrees Fahrenheit (6 to 8 degrees Celsius), with significant"
)
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_regex(
flash_llama_grammar_tools, response_snapshot
):
response = await flash_llama_grammar_tools.chat(
max_tokens=100,
seed=0,
tools=tools,
messages=[
{
"role": "system",
"content": "Youre a helpful assistant! Answer the users question best you can.",
},
{
"role": "user",
"content": "What is the weather like in Brooklyn, New York?",
},
],
)
assert len(response.choices[0].message.content) == 81
assert (
response.choices[0].message.content
== """{"function":{"format": "celsius", "location": "Brooklyn, NYC", "num_days": 1255}}"""
)
assert response == response_snapshot

View File

@@ -55,6 +55,7 @@ enum GrammarType {
GRAMMAR_TYPE_NONE = 0;
GRAMMAR_TYPE_JSON = 1;
GRAMMAR_TYPE_REGEX = 2;
GRAMMAR_TYPE_OPTIONAL_JSON = 3;
}
message NextTokenChooserParameters {

View File

@@ -513,7 +513,7 @@ class GrammarLogitProcessor(LogitsProcessor):
start_time = time.time()
if grammar_type == GrammarType.GRAMMAR_TYPE_JSON:
schema = build_regex_from_object(schema)
-elif grammar_type == GrammarType.OPTIONAL_GRAMMAR_TYPE_REGEX:
+elif grammar_type == GrammarType.GRAMMAR_TYPE_OPTIONAL_JSON:
# TODO: use a better method to handle optional grammars
schema = f"({build_regex_from_object(schema)})|.*"
elif grammar_type == GrammarType.GRAMMAR_TYPE_REGEX:
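The new GRAMMAR_TYPE_OPTIONAL_JSON branch makes the grammar optional by wrapping the schema-derived regex in an alternation with .*, so the model may either emit schema-conforming JSON or plain text. A self-contained sketch of that pattern (the toy regex is a hypothetical stand-in for build_regex_from_object output):

import re

# Hypothetical stand-in for the regex compiled from a JSON schema.
json_regex = r'\{"location": "[^"]*"\}'

# Optional grammar, as built in the diff above: schema-conforming JSON OR any text.
optional = re.compile(f"({json_regex})|.*")

assert optional.fullmatch('{"location": "Brooklyn, NYC"}')  # structured output accepted
assert optional.fullmatch("It is sunny in Brooklyn today.")  # free text also accepted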