feat: support streaming and improve docs

This commit is contained in:
drbh 2024-02-28 02:32:02 +00:00
parent 7c04b6d664
commit 0fc7237380
10 changed files with 218 additions and 38 deletions

View File

@ -3,7 +3,7 @@ import requests
from aiohttp import ClientSession, ClientTimeout
from pydantic import ValidationError
from typing import Dict, Optional, List, AsyncIterator, Iterator
from typing import Dict, Optional, List, AsyncIterator, Iterator, Union
from text_generation.types import (
StreamResponse,
@ -12,6 +12,7 @@ from text_generation.types import (
Parameters,
Grammar,
ChatRequest,
ChatCompletionChunk,
ChatComplete,
Message,
Tool,
@ -134,18 +135,42 @@ class Client:
tools=tools,
tool_choice=tool_choice,
)
if not stream:
resp = requests.post(
f"{self.base_url}/v1/chat/completions",
json=request.dict(),
headers=self.headers,
cookies=self.cookies,
timeout=self.timeout,
)
payload = resp.json()
if resp.status_code != 200:
raise parse_error(resp.status_code, payload)
return ChatComplete(**payload)
else:
return self._chat_stream_response(request)
def _chat_stream_response(self, request):
resp = requests.post(
f"{self.base_url}/v1/chat/completions",
json=request.dict(),
headers=self.headers,
cookies=self.cookies,
timeout=self.timeout,
stream=True,
)
# iterate over the SSE stream and yield parsed chunks
for byte_payload in resp.iter_lines():
if byte_payload == b"\n":
continue
payload = byte_payload.decode("utf-8")
if payload.startswith("data:"):
json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
try:
response = ChatCompletionChunk(**json_payload)
yield response
except ValidationError:
raise parse_error(resp.status_code, json_payload)
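For illustration (not part of the commit), a minimal sketch of how the new synchronous streaming path could be consumed; the base URL, prompt, and `max_tokens` value are placeholders, and a running TGI server is assumed:

```python
# Sketch only: consume Client.chat(..., stream=True) as added above.
from text_generation import Client

client = Client("http://localhost:8080")  # assumed local TGI endpoint

# With stream=True, chat() returns a generator of ChatCompletionChunk objects.
for chunk in client.chat(
    messages=[{"role": "user", "content": "What is deep learning?"}],
    max_tokens=64,
    stream=True,
):
    delta = chunk.choices[0].delta
    if delta.content is not None:
        print(delta.content, end="", flush=True)
```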
def generate(
self,
@ -417,7 +442,7 @@ class AsyncClient:
top_p: Optional[float] = None,
tools: Optional[List[Tool]] = None,
tool_choice: Optional[str] = None,
):
) -> Union[ChatComplete, AsyncIterator[ChatCompletionChunk]]:
"""
Given a list of messages, generate a response asynchronously
@ -472,6 +497,12 @@ class AsyncClient:
tools=tools,
tool_choice=tool_choice,
)
if not stream:
return await self._chat_single_response(request)
else:
return self._chat_stream_response(request)
async def _chat_single_response(self, request):
async with ClientSession(
headers=self.headers, cookies=self.cookies, timeout=self.timeout
) as session:
@ -483,6 +514,25 @@ class AsyncClient:
raise parse_error(resp.status, payload)
return ChatComplete(**payload)
async def _chat_stream_response(self, request):
async with ClientSession(
headers=self.headers, cookies=self.cookies, timeout=self.timeout
) as session:
async with session.post(
f"{self.base_url}/v1/chat/completions", json=request.dict()
) as resp:
async for byte_payload in resp.content:
if byte_payload == b"\n":
continue
payload = byte_payload.decode("utf-8")
if payload.startswith("data:"):
json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
try:
response = ChatCompletionChunk(**json_payload)
yield response
except ValidationError:
raise parse_error(resp.status, json_payload)
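The asynchronous path mirrors this; again a sketch under the same assumptions, using the `await client.chat(...)` then `async for` pattern that the integration test below also follows:

```python
# Sketch only: consume AsyncClient.chat(..., stream=True) as added above.
import asyncio
from text_generation import AsyncClient

async def main():
    client = AsyncClient("http://localhost:8080")  # assumed local TGI endpoint
    responses = await client.chat(
        messages=[{"role": "user", "content": "What is deep learning?"}],
        max_tokens=64,
        stream=True,
    )
    # responses is an async iterator of ChatCompletionChunk objects.
    async for chunk in responses:
        delta = chunk.choices[0].delta
        if delta.content is not None:
            print(delta.content, end="", flush=True)

asyncio.run(main())
```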
async def generate(
self,
prompt: str,

View File

@ -59,6 +59,40 @@ class ChatCompletionComplete(BaseModel):
usage: Any
class Function(BaseModel):
name: Optional[str]
arguments: str
class ChoiceDeltaToolCall(BaseModel):
index: int
id: str
type: str
function: Function
class ChoiceDelta(BaseModel):
role: str
content: Optional[str]
tool_calls: Optional[ChoiceDeltaToolCall]
class Choice(BaseModel):
index: int
delta: ChoiceDelta
logprobs: Optional[dict] = None
finish_reason: Optional[str] = None
class ChatCompletionChunk(BaseModel):
id: str
object: str
created: int
model: str
system_fingerprint: str
choices: List[Choice]
class ChatComplete(BaseModel):
# Chat completion details
id: str
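To show how the new models fit together (example only; the payload is invented but follows the shape of the snapshots in this commit), a single decoded `data:` line from the stream validates directly against `ChatCompletionChunk`:

```python
# Sketch: parse one SSE "data:" line into the new ChatCompletionChunk model.
import json
from text_generation.types import ChatCompletionChunk

line = (
    'data: {"id": "", "object": "text_completion", "created": 1709087088, '
    '"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "system_fingerprint": "1.4.2-native", '
    '"choices": [{"index": 0, "delta": {"role": "assistant", "content": "Hello", '
    '"tool_calls": null}, "logprobs": null, "finish_reason": null}]}'
)

chunk = ChatCompletionChunk(**json.loads(line[len("data: "):]))
print(chunk.choices[0].delta.content)  # -> "Hello"
```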

View File

@ -1,8 +1,8 @@
# Guidance
Text Generation Inference (TGI) now supports the Messages API, which is fully compatible with the OpenAI Chat Completion API. This feature is available starting from version `1.4.3`. You can use OpenAI's client libraries or third-party libraries expecting OpenAI schema to interact with TGI's Messages API. Below are some examples of how to utilize this compatibility.
Text Generation Inference (TGI) now supports [Grammar and Constraints](#grammar-and-constraints) and [Tools and Functions](#tools-and-functions) to help developers guide the LLM's responses and enhance its capabilities.
Whether you're a developer, a data scientist, or just a curious mind, we've made it super easy (and fun!) to start integrating advanced text generation capabilities into your applications.
These features are available starting from version `1.4.3`. They are accessible via the `text-generation` client library and are compatible with OpenAI's client libraries. The following guide will walk you through the new features and how to use them!
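For example (a sketch, assuming a TGI server listening locally on port 8080 and the `openai` Python package installed), the standard OpenAI client can stream chat completions from TGI:

```python
from openai import OpenAI

# Point the OpenAI client at the local TGI server; the API key is a placeholder.
client = OpenAI(base_url="http://localhost:8080/v1", api_key="-")

stream = client.chat.completions.create(
    model="tgi",
    messages=[{"role": "user", "content": "What is deep learning?"}],
    stream=True,
)
for chunk in stream:
    print(chunk.choices[0].delta.content or "", end="")
```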
### Quick Start
@ -214,7 +214,6 @@ if __name__ == "__main__":
```
## Tools and Functions 🛠️
### The Tools Parameter

View File

@ -14,8 +14,8 @@
"name": "tools",
"parameters": {
"format": "celsius",
"location": "San Francisco",
"num_days": 2
"location": "New York, NY",
"num_days": 14
}
},
"id": 0,
@ -25,14 +25,14 @@
"usage": null
}
],
"created": 1708957016,
"created": 1709079417,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "1.4.2-native",
"usage": {
"completion_tokens": 36,
"prompt_tokens": 313,
"total_tokens": 349
"completion_tokens": 29,
"prompt_tokens": 316,
"total_tokens": 345
}
}

View File

@ -14,8 +14,8 @@
"name": "tools",
"parameters": {
"format": "celsius",
"location": "San Francisco",
"num_days": 2
"location": "New York, NY",
"num_days": 14
}
},
"id": 0,
@ -25,14 +25,14 @@
"usage": null
}
],
"created": 1708957016,
"created": 1709079492,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "1.4.2-native",
"usage": {
"completion_tokens": 36,
"prompt_tokens": 313,
"total_tokens": 349
"completion_tokens": 29,
"prompt_tokens": 316,
"total_tokens": 345
}
}

View File

@ -24,14 +24,14 @@
"usage": null
}
],
"created": 1708957017,
"created": 1709079493,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "1.4.2-native",
"usage": {
"completion_tokens": 21,
"prompt_tokens": 184,
"total_tokens": 205
"prompt_tokens": 187,
"total_tokens": 208
}
}

View File

@ -0,0 +1,27 @@
{
"choices": [
{
"delta": {
"content": null,
"role": "assistant",
"tool_calls": {
"function": {
"arguments": "</s>",
"name": null
},
"id": "",
"index": 20,
"type": "function"
}
},
"finish_reason": "eos_token",
"index": 20,
"logprobs": null
}
],
"created": 1709087088,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "1.4.2-native"
}

View File

@ -124,8 +124,8 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna
"name": "tools",
"parameters": {
"format": "celsius",
"location": "San Francisco",
"num_days": 2,
"location": "New York, NY",
"num_days": 14,
},
},
"id": 0,
@ -163,8 +163,8 @@ async def test_flash_llama_grammar_tools_auto(
"name": "tools",
"parameters": {
"format": "celsius",
"location": "San Francisco",
"num_days": 2,
"location": "New York, NY",
"num_days": 14,
},
},
"id": 0,
@ -206,3 +206,36 @@ async def test_flash_llama_grammar_tools_choice(
},
}
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_stream(
flash_llama_grammar_tools, response_snapshot
):
responses = await flash_llama_grammar_tools.chat(
max_tokens=100,
seed=1,
tools=tools,
tool_choice="get_current_weather",
presence_penalty=-1.1,
messages=[
{
"role": "system",
"content": "Youre a helpful assistant! Answer the users question best you can.",
},
{
"role": "user",
"content": "What is the weather like in Paris, France?",
},
],
stream=True,
)
count = 0
async for response in responses:
print(response)
count += 1
assert count == 20
assert response == response_snapshot

View File

@ -415,15 +415,35 @@ pub(crate) struct ChatCompletionChoice {
pub(crate) struct ChatCompletionDelta {
#[schema(example = "user")]
pub role: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
#[schema(example = "What is Deep Learning?")]
pub content: String,
pub content: Option<String>,
// default to None
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<DeltaToolCall>,
}
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)]
pub(crate) struct DeltaToolCall {
pub index: u32,
pub id: String,
pub r#type: String,
pub function: Function,
}
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)]
pub(crate) struct Function {
pub name: Option<String>,
pub arguments: String,
}
#[allow(clippy::too_many_arguments)]
impl ChatCompletionChunk {
pub(crate) fn new(
model: String,
system_fingerprint: String,
delta: String,
delta: Option<String>,
tool_calls: Option<Vec<String>>,
created: u64,
index: u32,
logprobs: Option<ChatCompletionLogprobs>,
@ -440,6 +460,15 @@ impl ChatCompletionChunk {
delta: ChatCompletionDelta {
role: "assistant".to_string(),
content: delta,
tool_calls: tool_calls.map(|tc| DeltaToolCall {
index,
id: String::new(),
r#type: "function".to_string(),
function: Function {
name: None,
arguments: tc[0].to_string(),
},
}),
},
logprobs,
finish_reason,
@ -626,8 +655,8 @@ where
state.end()
}
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
pub(crate) struct Function {
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default)]
pub(crate) struct FunctionDefinition {
#[serde(default)]
pub description: Option<String>,
pub name: String,
@ -640,7 +669,7 @@ pub(crate) struct Tool {
#[schema(example = "function")]
pub r#type: String,
// Grab the tool as generic JSON for debugging purposes.
pub function: Function,
pub function: FunctionDefinition,
}
#[derive(Clone, Serialize, Deserialize)]
@ -651,11 +680,11 @@ pub(crate) struct ChatTemplateInputs<'a> {
add_generation_prompt: bool,
}
#[derive(Clone, Deserialize, Serialize, ToSchema)]
#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)]
pub(crate) struct ToolCall {
pub id: u32,
pub r#type: String,
pub function: Function,
pub function: FunctionDefinition,
}
#[derive(Clone, Deserialize, ToSchema, Serialize)]

View File

@ -10,7 +10,7 @@ use crate::{
HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails,
StreamResponse, Token, TokenizeResponse, Usage, Validation, VertexRequest, VertexResponse,
};
use crate::{Function, FunctionRef, FunctionsMap, Properties, ToolCall, ToolType, Tools};
use crate::{FunctionDefinition, FunctionRef, FunctionsMap, Properties, ToolCall, ToolType, Tools};
use axum::extract::Extension;
use axum::http::{HeaderMap, Method, StatusCode};
use axum::response::sse::{Event, KeepAlive, Sse};
@ -699,11 +699,19 @@ async fn chat_completions(
ChatCompletionLogprobs::from((stream_token.token.clone(), stream_token.top_tokens))
});
// replace the content with the tool calls if grammar is present
let (content, tool_calls) = if tool_grammar.is_some() {
(None, Some(vec![stream_token.token.text]))
} else {
(Some(stream_token.token.text), None)
};
event
.json_data(ChatCompletionChunk::new(
model_id.clone(),
system_fingerprint.clone(),
stream_token.token.text,
content,
tool_calls,
current_time,
stream_token.index,
logprobs,
@ -756,7 +764,7 @@ async fn chat_completions(
let tool_call = Some(ToolCall {
id: 0,
r#type: "function".to_string(),
function: Function {
function: FunctionDefinition {
description: None,
name: "tools".to_string(),
parameters: gen_text_value.get("function").map_or_else(