Mirror of https://github.com/huggingface/text-generation-inference.git
(synced 2025-09-11 12:24:53 +00:00)

commit c8f2081171 (parent 0f500f6d14)

    feat: minimal tool support and chat client

The commit adds a chat() method (sync and async) to the Python client, new chat and tool request/response types, an optional-JSON grammar type, and an integration test with response snapshots.
clients/python/text_generation/client.py

@@ -11,6 +11,10 @@ from text_generation.types import (
     Request,
     Parameters,
     Grammar,
+    ChatRequest,
+    ChatComplete,
+    Message,
+    Tool,
 )
 from text_generation.errors import parse_error

@@ -59,6 +63,52 @@ class Client:
         self.cookies = cookies
         self.timeout = timeout

+    def chat(
+        self,
+        messages: List[Message],
+        frequency_penalty: Optional[float] = None,
+        logit_bias: Optional[List[float]] = None,
+        logprobs: Optional[bool] = None,
+        top_logprobs: Optional[int] = None,
+        max_tokens: Optional[int] = None,
+        n: Optional[int] = None,
+        presence_penalty: Optional[float] = None,
+        stream: bool = False,
+        seed: Optional[int] = None,
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+        tools: Optional[List[Tool]] = None,
+    ):
+        """ """
+        request = ChatRequest(
+            model="tgi",
+            messages=messages,
+            frequency_penalty=frequency_penalty,
+            logit_bias=logit_bias,
+            logprobs=logprobs,
+            top_logprobs=top_logprobs,
+            max_tokens=max_tokens,
+            n=n,
+            presence_penalty=presence_penalty,
+            stream=stream,
+            seed=seed,
+            temperature=temperature,
+            top_p=top_p,
+            tools=tools,
+        )
+
+        resp = requests.post(
+            f"{self.base_url}/v1/chat/completions",
+            json=request.dict(),
+            headers=self.headers,
+            cookies=self.cookies,
+            timeout=self.timeout,
+        )
+        payload = resp.json()
+        if resp.status_code != 200:
+            raise parse_error(resp.status_code, payload)
+        return ChatComplete(**payload)
+
     def generate(
         self,
         prompt: str,
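For reference, a minimal sketch of how the new synchronous method could be called against a running TGI server; the base URL and the served model are assumptions, not part of the diff:

    from text_generation import Client

    # Assumes a TGI server is already serving a chat-capable model locally.
    client = Client("http://127.0.0.1:8080")
    complete = client.chat(
        messages=[{"role": "user", "content": "What is deep learning?"}],
        max_tokens=64,
        seed=0,
    )
    print(complete.choices[0].message.content)

Dicts are accepted for messages because pydantic coerces them into Message models when ChatRequest is constructed.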
@@ -313,6 +363,52 @@ class AsyncClient:
         self.cookies = cookies
         self.timeout = ClientTimeout(timeout * 60)

+    async def chat(
+        self,
+        messages: List[Message],
+        frequency_penalty: Optional[float] = None,
+        logit_bias: Optional[List[float]] = None,
+        logprobs: Optional[bool] = None,
+        top_logprobs: Optional[int] = None,
+        max_tokens: Optional[int] = None,
+        n: Optional[int] = None,
+        presence_penalty: Optional[float] = None,
+        stream: bool = False,
+        seed: Optional[int] = None,
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+        tools: Optional[List[Tool]] = None,
+    ):
+        """ """
+        print("chat")
+        request = ChatRequest(
+            model="tgi",
+            messages=messages,
+            frequency_penalty=frequency_penalty,
+            logit_bias=logit_bias,
+            logprobs=logprobs,
+            top_logprobs=top_logprobs,
+            max_tokens=max_tokens,
+            n=n,
+            presence_penalty=presence_penalty,
+            stream=stream,
+            seed=seed,
+            temperature=temperature,
+            top_p=top_p,
+            tools=tools,
+        )
+        print(self.base_url)
+        async with ClientSession(
+            headers=self.headers, cookies=self.cookies, timeout=self.timeout
+        ) as session:
+            async with session.post(
+                f"{self.base_url}/v1/chat/completions", json=request.dict()
+            ) as resp:
+                payload = await resp.json()
+                if resp.status != 200:
+                    raise parse_error(resp.status, payload)
+                return ChatComplete(**payload)
+
     async def generate(
         self,
         prompt: str,
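A corresponding sketch for the async variant, again assuming a local server; the asyncio.run wrapper is for illustration only:

    import asyncio
    from text_generation import AsyncClient

    async def main():
        # Assumes a TGI server is already serving a chat-capable model locally.
        client = AsyncClient("http://127.0.0.1:8080")
        complete = await client.chat(
            messages=[{"role": "user", "content": "What is deep learning?"}],
            max_tokens=64,
        )
        print(complete.choices[0].message.content)

    asyncio.run(main())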
clients/python/text_generation/types.py

@@ -1,6 +1,6 @@
 from enum import Enum
 from pydantic import BaseModel, validator
-from typing import Optional, List, Union
+from typing import Optional, List, Union, Any

 from text_generation.errors import ValidationError

@@ -19,6 +19,75 @@ class Grammar(BaseModel):
     value: Union[str, dict]


+class Message(BaseModel):
+    # Role of the message sender
+    role: str
+    # Content of the message
+    content: str
+    # Optional name of the message sender
+    name: Optional[str] = None
+
+
+class Tool(BaseModel):
+    # Type of the tool
+    type: str
+    # Function details of the tool
+    function: dict
+
+
+class ChatCompletionComplete(BaseModel):
+    # Index of the chat completion
+    index: int
+    # Message associated with the chat completion
+    message: Message
+    # Log probabilities for the chat completion
+    logprobs: Optional[Any]
+    # Reason for completion
+    finish_reason: str
+
+
+class ChatComplete(BaseModel):
+    # Chat completion details
+    id: str
+    object: str
+    created: int
+    model: str
+    system_fingerprint: str
+    choices: List[ChatCompletionComplete]
+    usage: Any
+
+
+class ChatRequest(BaseModel):
+    # Model identifier
+    model: str
+    # List of messages in the conversation
+    messages: List[Message]
+    # Penalty for frequency of new tokens
+    frequency_penalty: Optional[float] = None
+    # Bias values for token selection
+    logit_bias: Optional[List[float]] = None
+    # Whether to return log probabilities
+    logprobs: Optional[bool] = None
+    # Number of most likely tokens to return at each position
+    top_logprobs: Optional[int] = None
+    # Maximum number of tokens to generate
+    max_tokens: Optional[int] = None
+    # Number of chat completion choices to generate
+    n: Optional[int] = None
+    # Penalty for presence of new tokens
+    presence_penalty: Optional[float] = None
+    # Flag to indicate streaming response
+    stream: bool = False
+    # Random sampling seed
+    seed: Optional[int] = None
+    # Sampling temperature
+    temperature: Optional[float] = None
+    # Top-p value for nucleus sampling
+    top_p: Optional[float] = None
+    # List of tools to be used
+    tools: Optional[List[Tool]] = None
+
+
 class Parameters(BaseModel):
     # Activate logits sampling
     do_sample: bool = False
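To illustrate how the new types compose, a small sketch constructing a ChatRequest by hand; all field values here are illustrative, not taken from the diff:

    from text_generation.types import ChatRequest, Message, Tool

    request = ChatRequest(
        model="tgi",
        messages=[Message(role="user", content="What's the weather?")],
        tools=[
            Tool(
                type="function",
                function={"name": "get_current_weather", "parameters": {}},
            )
        ],
        max_tokens=32,
    )
    # Serializes to the JSON body the client POSTs to /v1/chat/completions.
    print(request.dict())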
integration-tests/conftest.py

@@ -23,6 +23,7 @@ from text_generation.types import (
     Token,
     BestOfSequence,
     Grammar,
+    ChatComplete,
 )

 DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
@@ -59,7 +60,8 @@ class ResponseComparator(JSONSnapshotExtension):
     ) -> bool:
         def convert_data(data):
             data = json.loads(data)
+            if isinstance(data, Dict) and "choices" in data:
+                return ChatComplete(**data)
             if isinstance(data, Dict):
                 return Response(**data)
             if isinstance(data, List):
@@ -144,6 +146,11 @@ class ResponseComparator(JSONSnapshotExtension):
             )
         )

+        def eq_chat_complete(response: ChatComplete, other: ChatComplete) -> bool:
+            return (
+                response.choices[0].message.content == other.choices[0].message.content
+            )
+
         def eq_response(response: Response, other: Response) -> bool:
             return response.generated_text == other.generated_text and eq_details(
                 response.details, other.details
@@ -157,6 +164,11 @@ class ResponseComparator(JSONSnapshotExtension):
         if not isinstance(snapshot_data, List):
             snapshot_data = [snapshot_data]

+        if isinstance(serialized_data[0], ChatComplete):
+            return len(snapshot_data) == len(serialized_data) and all(
+                [eq_chat_complete(r, o) for r, o in zip(serialized_data, snapshot_data)]
+            )
+
         return len(snapshot_data) == len(serialized_data) and all(
             [eq_response(r, o) for r, o in zip(serialized_data, snapshot_data)]
         )
New file: response snapshot for the no-tools test

@@ -0,0 +1,24 @@
+{
+  "choices": [
+    {
+      "finish_reason": "length",
+      "index": 0,
+      "logprobs": null,
+      "message": {
+        "content": "As for the weather in Brooklyn, New York, it can vary depending on the location within the borough. According to climatereporter.com, the average temperature in August is 73 degrees Fahrenheit (23 degrees Celsius), while the humidity is 62%. In the winter (December to February), the temperature averages between 20 and 45 degrees Fahrenheit (6 to 8 degrees Celsius), with significant",
+        "name": null,
+        "role": "assistant"
+      }
+    }
+  ],
+  "created": 1708103426,
+  "id": "",
+  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+  "object": "text_completion",
+  "system_fingerprint": "1.4.0-native",
+  "usage": {
+    "completion_tokens": 100,
+    "prompt_tokens": 82,
+    "total_tokens": 182
+  }
+}
New file: response snapshot for the tools test

@@ -0,0 +1,24 @@
+{
+  "choices": [
+    {
+      "finish_reason": "eos_token",
+      "index": 0,
+      "logprobs": null,
+      "message": {
+        "content": "{\"function\":{\"format\": \"celsius\", \"location\": \"Brooklyn, NYC\", \"num_days\": 1255}}",
+        "name": null,
+        "role": "assistant"
+      }
+    }
+  ],
+  "created": 1708103426,
+  "id": "",
+  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+  "object": "text_completion",
+  "system_fingerprint": "1.4.0-native",
+  "usage": {
+    "completion_tokens": 33,
+    "prompt_tokens": 321,
+    "total_tokens": 354
+  }
+}
integration-tests/models/test_tools_llama.py (new file, 126 lines)

@@ -0,0 +1,126 @@
+import pytest
+import json
+
+from text_generation.types import GrammarType
+
+
+@pytest.fixture(scope="module")
+def flash_llama_grammar_tools_handle(launcher):
+    with launcher(
+        "TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_shard=2, disable_grammar_support=False
+    ) as handle:
+        yield handle
+
+
+@pytest.fixture(scope="module")
+async def flash_llama_grammar_tools(flash_llama_grammar_tools_handle):
+    await flash_llama_grammar_tools_handle.health(300)
+    return flash_llama_grammar_tools_handle.client
+
+
+# tools to be used in the following tests
+tools = [
+    {
+        "type": "function",
+        "function": {
+            "name": "get_current_weather",
+            "description": "Get the current weather",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "location": {
+                        "type": "string",
+                        "description": "The city and state, e.g. San Francisco, CA",
+                    },
+                    "format": {
+                        "type": "string",
+                        "enum": ["celsius", "fahrenheit"],
+                        "description": "The temperature unit to use. Infer this from the users location.",
+                    },
+                },
+                "required": ["location", "format"],
+            },
+        },
+    },
+    {
+        "type": "function",
+        "function": {
+            "name": "get_n_day_weather_forecast",
+            "description": "Get an N-day weather forecast",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "location": {
+                        "type": "string",
+                        "description": "The city and state, e.g. San Francisco, CA",
+                    },
+                    "format": {
+                        "type": "string",
+                        "enum": ["celsius", "fahrenheit"],
+                        "description": "The temperature unit to use. Infer this from the users location.",
+                    },
+                    "num_days": {
+                        "type": "integer",
+                        "description": "The number of days to forecast",
+                    },
+                },
+                "required": ["location", "format", "num_days"],
+            },
+        },
+    },
+]
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_llama_grammar_no_tools_regex(
+    flash_llama_grammar_tools, response_snapshot
+):
+    response = await flash_llama_grammar_tools.chat(
+        max_tokens=100,
+        seed=0,
+        messages=[
+            {
+                "role": "system",
+                "content": "Youre a helpful assistant! Answer the users question best you can.",
+            },
+            {
+                "role": "user",
+                "content": "What is the weather like in Brooklyn, New York?",
+            },
+        ],
+    )
+
+    assert (
+        response.choices[0].message.content
+        == "As for the weather in Brooklyn, New York, it can vary depending on the location within the borough. According to climatereporter.com, the average temperature in August is 73 degrees Fahrenheit (23 degrees Celsius), while the humidity is 62%. In the winter (December to February), the temperature averages between 20 and 45 degrees Fahrenheit (6 to 8 degrees Celsius), with significant"
+    )
+    assert response == response_snapshot
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_llama_grammar_tools_regex(
+    flash_llama_grammar_tools, response_snapshot
+):
+    response = await flash_llama_grammar_tools.chat(
+        max_tokens=100,
+        seed=0,
+        tools=tools,
+        messages=[
+            {
+                "role": "system",
+                "content": "Youre a helpful assistant! Answer the users question best you can.",
+            },
+            {
+                "role": "user",
+                "content": "What is the weather like in Brooklyn, New York?",
+            },
+        ],
+    )
+    assert len(response.choices[0].message.content) == 81
+    assert (
+        response.choices[0].message.content
+        == """{"function":{"format": "celsius", "location": "Brooklyn, NYC", "num_days": 1255}}"""
+    )
+    assert response == response_snapshot
proto/generate.proto

@@ -55,6 +55,7 @@ enum GrammarType {
   GRAMMAR_TYPE_NONE = 0;
   GRAMMAR_TYPE_JSON = 1;
   GRAMMAR_TYPE_REGEX = 2;
+  GRAMMAR_TYPE_OPTIONAL_JSON = 3;
 }

 message NextTokenChooserParameters {
server/text_generation_server/utils/logits_process.py

@@ -513,7 +513,7 @@ class GrammarLogitProcessor(LogitsProcessor):
         start_time = time.time()
         if grammar_type == GrammarType.GRAMMAR_TYPE_JSON:
             schema = build_regex_from_object(schema)
-        elif grammar_type == GrammarType.OPTIONAL_GRAMMAR_TYPE_REGEX:
+        elif grammar_type == GrammarType.GRAMMAR_TYPE_OPTIONAL_JSON:
             # TODO: use a better method to handle optional grammars
             schema = f"({build_regex_from_object(schema)})|.*"
         elif grammar_type == GrammarType.GRAMMAR_TYPE_REGEX:
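The GRAMMAR_TYPE_OPTIONAL_JSON branch makes the grammar non-binding by alternating the schema-derived regex with ".*". A toy illustration of the same pattern using Python's re module; the schema regex below is a stand-in, not what build_regex_from_object actually emits:

    import re

    # Stand-in for build_regex_from_object(schema).
    json_regex = r'\{"answer": "(yes|no)"\}'
    optional = f"({json_regex})|.*"

    # Both schema-conforming and free-form outputs fully match the optional pattern.
    assert re.fullmatch(optional, '{"answer": "yes"}')
    assert re.fullmatch(optional, "any free-form text")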