From c8f208117141e3f1746246664d25aa9a9cde9696 Mon Sep 17 00:00:00 2001 From: drbh Date: Fri, 16 Feb 2024 17:18:21 +0000 Subject: [PATCH] feat: minimal tool support and chat client --- clients/python/text_generation/client.py | 96 +++++++++++++ clients/python/text_generation/types.py | 71 +++++++++- integration-tests/conftest.py | 14 +- ...st_flash_llama_grammar_no_tools_regex.json | 24 ++++ .../test_flash_llama_grammar_tools_regex.json | 24 ++++ integration-tests/models/test_tools_llama.py | 126 ++++++++++++++++++ proto/generate.proto | 1 + .../utils/logits_process.py | 2 +- 8 files changed, 355 insertions(+), 3 deletions(-) create mode 100644 integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_no_tools_regex.json create mode 100644 integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_regex.json create mode 100644 integration-tests/models/test_tools_llama.py diff --git a/clients/python/text_generation/client.py b/clients/python/text_generation/client.py index bbccbf1d..932b3e32 100644 --- a/clients/python/text_generation/client.py +++ b/clients/python/text_generation/client.py @@ -11,6 +11,10 @@ from text_generation.types import ( Request, Parameters, Grammar, + ChatRequest, + ChatComplete, + Message, + Tool, ) from text_generation.errors import parse_error @@ -59,6 +63,52 @@ class Client: self.cookies = cookies self.timeout = timeout + def chat( + self, + messages: List[Message], + frequency_penalty: Optional[float] = None, + logit_bias: Optional[List[float]] = None, + logprobs: Optional[bool] = None, + top_logprobs: Optional[int] = None, + max_tokens: Optional[int] = None, + n: Optional[int] = None, + presence_penalty: Optional[float] = None, + stream: bool = False, + seed: Optional[int] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + tools: Optional[List[Tool]] = None, + ): + """ """ + request = ChatRequest( + model="tgi", + messages=messages, + frequency_penalty=frequency_penalty, + logit_bias=logit_bias, + logprobs=logprobs, + top_logprobs=top_logprobs, + max_tokens=max_tokens, + n=n, + presence_penalty=presence_penalty, + stream=stream, + seed=seed, + temperature=temperature, + top_p=top_p, + tools=tools, + ) + + resp = requests.post( + f"{self.base_url}/v1/chat/completions", + json=request.dict(), + headers=self.headers, + cookies=self.cookies, + timeout=self.timeout, + ) + payload = resp.json() + if resp.status_code != 200: + raise parse_error(resp.status_code, payload) + return ChatComplete(**payload) + def generate( self, prompt: str, @@ -313,6 +363,52 @@ class AsyncClient: self.cookies = cookies self.timeout = ClientTimeout(timeout * 60) + async def chat( + self, + messages: List[Message], + frequency_penalty: Optional[float] = None, + logit_bias: Optional[List[float]] = None, + logprobs: Optional[bool] = None, + top_logprobs: Optional[int] = None, + max_tokens: Optional[int] = None, + n: Optional[int] = None, + presence_penalty: Optional[float] = None, + stream: bool = False, + seed: Optional[int] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + tools: Optional[List[Tool]] = None, + ): + """ """ + print("chat") + request = ChatRequest( + model="tgi", + messages=messages, + frequency_penalty=frequency_penalty, + logit_bias=logit_bias, + logprobs=logprobs, + top_logprobs=top_logprobs, + max_tokens=max_tokens, + n=n, + presence_penalty=presence_penalty, + stream=stream, + seed=seed, + temperature=temperature, + top_p=top_p, + tools=tools, + ) + print(self.base_url) + async with ClientSession( + headers=self.headers, cookies=self.cookies, timeout=self.timeout + ) as session: + async with session.post( + f"{self.base_url}/v1/chat/completions", json=request.dict() + ) as resp: + payload = await resp.json() + if resp.status != 200: + raise parse_error(resp.status, payload) + return ChatComplete(**payload) + async def generate( self, prompt: str, diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py index 911114ee..1c6a1c47 100644 --- a/clients/python/text_generation/types.py +++ b/clients/python/text_generation/types.py @@ -1,6 +1,6 @@ from enum import Enum from pydantic import BaseModel, validator -from typing import Optional, List, Union +from typing import Optional, List, Union, Any from text_generation.errors import ValidationError @@ -19,6 +19,75 @@ class Grammar(BaseModel): value: Union[str, dict] +class Message(BaseModel): + # Role of the message sender + role: str + # Content of the message + content: str + # Optional name of the message sender + name: Optional[str] = None + + +class Tool(BaseModel): + # Type of the tool + type: str + # Function details of the tool + function: dict + + +class ChatCompletionComplete(BaseModel): + # Index of the chat completion + index: int + # Message associated with the chat completion + message: Message + # Log probabilities for the chat completion + logprobs: Optional[Any] + # Reason for completion + finish_reason: str + + +class ChatComplete(BaseModel): + # Chat completion details + id: str + object: str + created: int + model: str + system_fingerprint: str + choices: List[ChatCompletionComplete] + usage: Any + + +class ChatRequest(BaseModel): + # Model identifier + model: str + # List of messages in the conversation + messages: List[Message] + # Penalty for frequency of new tokens + frequency_penalty: Optional[float] = None + # Bias values for token selection + logit_bias: Optional[List[float]] = None + # Whether to return log probabilities + logprobs: Optional[bool] = None + # Number of most likely tokens to return at each position + top_logprobs: Optional[int] = None + # Maximum number of tokens to generate + max_tokens: Optional[int] = None + # Number of chat completion choices to generate + n: Optional[int] = None + # Penalty for presence of new tokens + presence_penalty: Optional[float] = None + # Flag to indicate streaming response + stream: bool = False + # Random sampling seed + seed: Optional[int] = None + # Sampling temperature + temperature: Optional[float] = None + # Top-p value for nucleus sampling + top_p: Optional[float] = None + # List of tools to be used + tools: Optional[List[Tool]] = None + + class Parameters(BaseModel): # Activate logits sampling do_sample: bool = False diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index 80457bc2..a9645153 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -23,6 +23,7 @@ from text_generation.types import ( Token, BestOfSequence, Grammar, + ChatComplete, ) DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None) @@ -59,7 +60,8 @@ class ResponseComparator(JSONSnapshotExtension): ) -> bool: def convert_data(data): data = json.loads(data) - + if isinstance(data, Dict) and "choices" in data: + return ChatComplete(**data) if isinstance(data, Dict): return Response(**data) if isinstance(data, List): @@ -144,6 +146,11 @@ class ResponseComparator(JSONSnapshotExtension): ) ) + def eq_chat_complete(response: ChatComplete, other: ChatComplete) -> bool: + return ( + response.choices[0].message.content == other.choices[0].message.content + ) + def eq_response(response: Response, other: Response) -> bool: return response.generated_text == other.generated_text and eq_details( response.details, other.details @@ -157,6 +164,11 @@ class ResponseComparator(JSONSnapshotExtension): if not isinstance(snapshot_data, List): snapshot_data = [snapshot_data] + if isinstance(serialized_data[0], ChatComplete): + return len(snapshot_data) == len(serialized_data) and all( + [eq_chat_complete(r, o) for r, o in zip(serialized_data, snapshot_data)] + ) + return len(snapshot_data) == len(serialized_data) and all( [eq_response(r, o) for r, o in zip(serialized_data, snapshot_data)] ) diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_no_tools_regex.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_no_tools_regex.json new file mode 100644 index 00000000..0ff1630f --- /dev/null +++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_no_tools_regex.json @@ -0,0 +1,24 @@ +{ + "choices": [ + { + "finish_reason": "length", + "index": 0, + "logprobs": null, + "message": { + "content": "As for the weather in Brooklyn, New York, it can vary depending on the location within the borough. According to climatereporter.com, the average temperature in August is 73 degrees Fahrenheit (23 degrees Celsius), while the humidity is 62%. In the winter (December to February), the temperature averages between 20 and 45 degrees Fahrenheit (6 to 8 degrees Celsius), with significant", + "name": null, + "role": "assistant" + } + } + ], + "created": 1708103426, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "1.4.0-native", + "usage": { + "completion_tokens": 100, + "prompt_tokens": 82, + "total_tokens": 182 + } +} diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_regex.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_regex.json new file mode 100644 index 00000000..d3a868f2 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_regex.json @@ -0,0 +1,24 @@ +{ + "choices": [ + { + "finish_reason": "eos_token", + "index": 0, + "logprobs": null, + "message": { + "content": "{\"function\":{\"format\": \"celsius\", \"location\": \"Brooklyn, NYC\", \"num_days\": 1255}}", + "name": null, + "role": "assistant" + } + } + ], + "created": 1708103426, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "1.4.0-native", + "usage": { + "completion_tokens": 33, + "prompt_tokens": 321, + "total_tokens": 354 + } +} diff --git a/integration-tests/models/test_tools_llama.py b/integration-tests/models/test_tools_llama.py new file mode 100644 index 00000000..d12ed648 --- /dev/null +++ b/integration-tests/models/test_tools_llama.py @@ -0,0 +1,126 @@ +import pytest +import json + +from text_generation.types import GrammarType + + +@pytest.fixture(scope="module") +def flash_llama_grammar_tools_handle(launcher): + with launcher( + "TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_shard=2, disable_grammar_support=False + ) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_llama_grammar_tools(flash_llama_grammar_tools_handle): + await flash_llama_grammar_tools_handle.health(300) + return flash_llama_grammar_tools_handle.client + + +# tools to be used in the following tests +tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "format": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The temperature unit to use. Infer this from the users location.", + }, + }, + "required": ["location", "format"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "get_n_day_weather_forecast", + "description": "Get an N-day weather forecast", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "format": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The temperature unit to use. Infer this from the users location.", + }, + "num_days": { + "type": "integer", + "description": "The number of days to forecast", + }, + }, + "required": ["location", "format", "num_days"], + }, + }, + }, +] + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_grammar_no_tools_regex( + flash_llama_grammar_tools, response_snapshot +): + response = await flash_llama_grammar_tools.chat( + max_tokens=100, + seed=0, + messages=[ + { + "role": "system", + "content": "Youre a helpful assistant! Answer the users question best you can.", + }, + { + "role": "user", + "content": "What is the weather like in Brooklyn, New York?", + }, + ], + ) + + assert ( + response.choices[0].message.content + == "As for the weather in Brooklyn, New York, it can vary depending on the location within the borough. According to climatereporter.com, the average temperature in August is 73 degrees Fahrenheit (23 degrees Celsius), while the humidity is 62%. In the winter (December to February), the temperature averages between 20 and 45 degrees Fahrenheit (6 to 8 degrees Celsius), with significant" + ) + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_grammar_tools_regex( + flash_llama_grammar_tools, response_snapshot +): + response = await flash_llama_grammar_tools.chat( + max_tokens=100, + seed=0, + tools=tools, + messages=[ + { + "role": "system", + "content": "Youre a helpful assistant! Answer the users question best you can.", + }, + { + "role": "user", + "content": "What is the weather like in Brooklyn, New York?", + }, + ], + ) + assert len(response.choices[0].message.content) == 81 + assert ( + response.choices[0].message.content + == """{"function":{"format": "celsius", "location": "Brooklyn, NYC", "num_days": 1255}}""" + ) + assert response == response_snapshot diff --git a/proto/generate.proto b/proto/generate.proto index 0490029f..1c252599 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -55,6 +55,7 @@ enum GrammarType { GRAMMAR_TYPE_NONE = 0; GRAMMAR_TYPE_JSON = 1; GRAMMAR_TYPE_REGEX = 2; + GRAMMAR_TYPE_OPTIONAL_JSON = 3; } message NextTokenChooserParameters { diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py index c9a32ff7..c0ccd83f 100644 --- a/server/text_generation_server/utils/logits_process.py +++ b/server/text_generation_server/utils/logits_process.py @@ -513,7 +513,7 @@ class GrammarLogitProcessor(LogitsProcessor): start_time = time.time() if grammar_type == GrammarType.GRAMMAR_TYPE_JSON: schema = build_regex_from_object(schema) - elif grammar_type == GrammarType.OPTIONAL_GRAMMAR_TYPE_REGEX: + elif grammar_type == GrammarType.GRAMMAR_TYPE_OPTIONAL_JSON: # TODO: use a better method to handle optional grammars schema = f"({build_regex_from_object(schema)})|.*" elif grammar_type == GrammarType.GRAMMAR_TYPE_REGEX: