feat: support streaming and improve docs

drbh 2024-02-28 02:32:02 +00:00
parent 7c04b6d664
commit 0fc7237380
10 changed files with 218 additions and 38 deletions

View File

@@ -3,7 +3,7 @@ import requests
 from aiohttp import ClientSession, ClientTimeout
 from pydantic import ValidationError
-from typing import Dict, Optional, List, AsyncIterator, Iterator
+from typing import Dict, Optional, List, AsyncIterator, Iterator, Union
 from text_generation.types import (
     StreamResponse,
@@ -12,6 +12,7 @@ from text_generation.types import (
     Parameters,
     Grammar,
     ChatRequest,
+    ChatCompletionChunk,
     ChatComplete,
     Message,
     Tool,
@@ -134,18 +135,42 @@ class Client:
             tools=tools,
             tool_choice=tool_choice,
         )
-        resp = requests.post(
-            f"{self.base_url}/v1/chat/completions",
-            json=request.dict(),
-            headers=self.headers,
-            cookies=self.cookies,
-            timeout=self.timeout,
-        )
-        payload = resp.json()
-        if resp.status_code != 200:
-            raise parse_error(resp.status_code, payload)
-        return ChatComplete(**payload)
+        if not stream:
+            resp = requests.post(
+                f"{self.base_url}/v1/chat/completions",
+                json=request.dict(),
+                headers=self.headers,
+                cookies=self.cookies,
+                timeout=self.timeout,
+            )
+            payload = resp.json()
+            if resp.status_code != 200:
+                raise parse_error(resp.status_code, payload)
+            return ChatComplete(**payload)
+        else:
+            return self._chat_stream_response(request)
+
+    def _chat_stream_response(self, request):
+        resp = requests.post(
+            f"{self.base_url}/v1/chat/completions",
+            json=request.dict(),
+            headers=self.headers,
+            cookies=self.cookies,
+            timeout=self.timeout,
+            stream=True,
+        )
+        # iterate and print stream
+        for byte_payload in resp.iter_lines():
+            if byte_payload == b"\n":
+                continue
+            payload = byte_payload.decode("utf-8")
+            if payload.startswith("data:"):
+                json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
+                try:
+                    response = ChatCompletionChunk(**json_payload)
+                    yield response
+                except ValidationError:
+                    raise parse_error(resp.status, json_payload)
 
     def generate(
         self,
@@ -417,7 +442,7 @@ class AsyncClient:
         top_p: Optional[float] = None,
         tools: Optional[List[Tool]] = None,
         tool_choice: Optional[str] = None,
-    ):
+    ) -> Union[ChatComplete, AsyncIterator[ChatCompletionChunk]]:
        """
        Given a list of messages, generate a response asynchronously
@@ -472,6 +497,12 @@ class AsyncClient:
             tools=tools,
             tool_choice=tool_choice,
         )
+        if not stream:
+            return await self._chat_single_response(request)
+        else:
+            return self._chat_stream_response(request)
+
+    async def _chat_single_response(self, request):
         async with ClientSession(
             headers=self.headers, cookies=self.cookies, timeout=self.timeout
         ) as session:
@@ -483,6 +514,25 @@ class AsyncClient:
                     raise parse_error(resp.status, payload)
                 return ChatComplete(**payload)
 
+    async def _chat_stream_response(self, request):
+        async with ClientSession(
+            headers=self.headers, cookies=self.cookies, timeout=self.timeout
+        ) as session:
+            async with session.post(
+                f"{self.base_url}/v1/chat/completions", json=request.dict()
+            ) as resp:
+                async for byte_payload in resp.content:
+                    if byte_payload == b"\n":
+                        continue
+                    payload = byte_payload.decode("utf-8")
+                    if payload.startswith("data:"):
+                        json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
+                        try:
+                            response = ChatCompletionChunk(**json_payload)
+                            yield response
+                        except ValidationError:
+                            raise parse_error(resp.status, json_payload)
 
     async def generate(
         self,
         prompt: str,
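
Taken together, these client changes make `chat` return either a `ChatComplete` or an iterator of `ChatCompletionChunk`s. Below is a minimal usage sketch of the new behaviour; the endpoint URL, prompt, and `max_tokens` value are illustrative assumptions, not values from this diff:

```python
from text_generation import Client

# Hypothetical local TGI endpoint; adjust to your deployment.
client = Client("http://localhost:8080")

# Non-streaming: returns a single ChatComplete object.
complete = client.chat(
    messages=[{"role": "user", "content": "What is deep learning?"}],
    max_tokens=64,
)
print(complete.choices[0].message.content)

# Streaming: returns a generator of ChatCompletionChunk objects.
for chunk in client.chat(
    messages=[{"role": "user", "content": "What is deep learning?"}],
    max_tokens=64,
    stream=True,
):
    # delta.content can be None (e.g. on tool-call chunks), so guard it.
    print(chunk.choices[0].delta.content or "", end="")
```

With `AsyncClient`, the same call either awaits to a `ChatComplete` or yields chunks via `async for`, matching the new `Union[ChatComplete, AsyncIterator[ChatCompletionChunk]]` annotation.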

View File

@@ -59,6 +59,40 @@ class ChatCompletionComplete(BaseModel):
     usage: Any
 
+
+class Function(BaseModel):
+    name: Optional[str]
+    arguments: str
+
+
+class ChoiceDeltaToolCall(BaseModel):
+    index: int
+    id: str
+    type: str
+    function: Function
+
+
+class ChoiceDelta(BaseModel):
+    role: str
+    content: Optional[str]
+    tool_calls: Optional[ChoiceDeltaToolCall]
+
+
+class Choice(BaseModel):
+    index: int
+    delta: ChoiceDelta
+    logprobs: Optional[dict] = None
+    finish_reason: Optional[str] = None
+
+
+class ChatCompletionChunk(BaseModel):
+    id: str
+    object: str
+    created: int
+    model: str
+    system_fingerprint: str
+    choices: List[Choice]
+
+
 class ChatComplete(BaseModel):
     # Chat completion details
     id: str
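
These Pydantic models mirror the OpenAI streaming schema. As a quick sanity check, a hand-written payload in the shape of a streamed chunk can be validated against them; the values below are illustrative, not taken from a real response:

```python
from text_generation.types import ChatCompletionChunk

# Illustrative chunk payload shaped like the JSON carried in each SSE "data:" line.
payload = {
    "id": "",
    "object": "text_completion",
    "created": 1709087088,
    "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "system_fingerprint": "1.4.2-native",
    "choices": [
        {
            "index": 0,
            "delta": {"role": "assistant", "content": "Deep", "tool_calls": None},
            "logprobs": None,
            "finish_reason": None,
        }
    ],
}

chunk = ChatCompletionChunk(**payload)  # nested dicts are coerced into the models above
print(chunk.choices[0].delta.content)   # -> "Deep"
```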

View File

@@ -1,8 +1,8 @@
 # Guidance
 
-Text Generation Inference (TGI) now supports the Messages API, which is fully compatible with the OpenAI Chat Completion API. This feature is available starting from version `1.4.3`. You can use OpenAI's client libraries or third-party libraries expecting OpenAI schema to interact with TGI's Messages API. Below are some examples of how to utilize this compatibility.
+Text Generation Inference (TGI) now supports [Grammar and Constraints](#grammar-and-constraints) and [Tools and Functions](#tools-and-functions) to help developers guide the LLM's responses and enhance its capabilities.
 
-Whether you're a developer, a data scientist, or just a curious mind, we've made it super easy (and fun!) to start integrating advanced text generation capabilities into your applications.
+These features are available starting from version `1.4.3`. They are accessible via the text-generation-client library and are compatible with OpenAI's client libraries. The following guide will walk you through the new features and how to use them!
 
 ### Quick Start
 
@@ -214,7 +214,6 @@ if __name__ == "__main__":
 ```
 
 ## Tools and Functions 🛠️
 
 ### The Tools Parameter
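
Because the guide advertises compatibility with OpenAI's client libraries, the new streaming endpoint can also be exercised through the `openai` package. This is a rough sketch assuming `openai>=1.0` and a TGI server at `http://localhost:8080`; the base URL, dummy API key, and model name are placeholders rather than values from this commit:

```python
from openai import OpenAI

# Placeholder endpoint and key; adjust to your deployment.
client = OpenAI(base_url="http://localhost:8080/v1", api_key="-")

stream = client.chat.completions.create(
    model="tgi",  # placeholder name; the served model is whatever TGI was launched with
    messages=[{"role": "user", "content": "What is deep learning?"}],
    stream=True,
)
for chunk in stream:
    # delta.content may be None on some chunks, so guard it.
    print(chunk.choices[0].delta.content or "", end="")
```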

View File

@@ -14,8 +14,8 @@
             "name": "tools",
             "parameters": {
               "format": "celsius",
-              "location": "San Francisco",
-              "num_days": 2
+              "location": "New York, NY",
+              "num_days": 14
             }
           },
           "id": 0,
@@ -25,14 +25,14 @@
       "usage": null
     }
   ],
-  "created": 1708957016,
+  "created": 1709079417,
   "id": "",
   "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
   "object": "text_completion",
   "system_fingerprint": "1.4.2-native",
   "usage": {
-    "completion_tokens": 36,
-    "prompt_tokens": 313,
-    "total_tokens": 349
+    "completion_tokens": 29,
+    "prompt_tokens": 316,
+    "total_tokens": 345
   }
 }

View File

@@ -14,8 +14,8 @@
             "name": "tools",
             "parameters": {
               "format": "celsius",
-              "location": "San Francisco",
-              "num_days": 2
+              "location": "New York, NY",
+              "num_days": 14
             }
           },
           "id": 0,
@@ -25,14 +25,14 @@
       "usage": null
     }
   ],
-  "created": 1708957016,
+  "created": 1709079492,
   "id": "",
   "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
   "object": "text_completion",
   "system_fingerprint": "1.4.2-native",
   "usage": {
-    "completion_tokens": 36,
-    "prompt_tokens": 313,
-    "total_tokens": 349
+    "completion_tokens": 29,
+    "prompt_tokens": 316,
+    "total_tokens": 345
   }
 }

View File

@@ -24,14 +24,14 @@
       "usage": null
     }
   ],
-  "created": 1708957017,
+  "created": 1709079493,
   "id": "",
   "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
   "object": "text_completion",
   "system_fingerprint": "1.4.2-native",
   "usage": {
     "completion_tokens": 21,
-    "prompt_tokens": 184,
-    "total_tokens": 205
+    "prompt_tokens": 187,
+    "total_tokens": 208
   }
 }

View File

@@ -0,0 +1,27 @@
+{
+  "choices": [
+    {
+      "delta": {
+        "content": null,
+        "role": "assistant",
+        "tool_calls": {
+          "function": {
+            "arguments": "</s>",
+            "name": null
+          },
+          "id": "",
+          "index": 20,
+          "type": "function"
+        }
+      },
+      "finish_reason": "eos_token",
+      "index": 20,
+      "logprobs": null
+    }
+  ],
+  "created": 1709087088,
+  "id": "",
+  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+  "object": "text_completion",
+  "system_fingerprint": "1.4.2-native"
+}

View File

@@ -124,8 +124,8 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna
             "name": "tools",
             "parameters": {
                 "format": "celsius",
-                "location": "San Francisco",
-                "num_days": 2,
+                "location": "New York, NY",
+                "num_days": 14,
             },
         },
         "id": 0,
@@ -163,8 +163,8 @@ async def test_flash_llama_grammar_tools_auto(
             "name": "tools",
             "parameters": {
                 "format": "celsius",
-                "location": "San Francisco",
-                "num_days": 2,
+                "location": "New York, NY",
+                "num_days": 14,
             },
         },
         "id": 0,
@@ -206,3 +206,36 @@ async def test_flash_llama_grammar_tools_choice(
         },
     }
     assert response == response_snapshot
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_llama_grammar_tools_stream(
+    flash_llama_grammar_tools, response_snapshot
+):
+    responses = await flash_llama_grammar_tools.chat(
+        max_tokens=100,
+        seed=1,
+        tools=tools,
+        tool_choice="get_current_weather",
+        presence_penalty=-1.1,
+        messages=[
+            {
+                "role": "system",
+                "content": "Youre a helpful assistant! Answer the users question best you can.",
+            },
+            {
+                "role": "user",
+                "content": "What is the weather like in Paris, France?",
+            },
+        ],
+        stream=True,
+    )
+
+    count = 0
+    async for response in responses:
+        print(response)
+        count += 1
+
+    assert count == 20
+    assert response == response_snapshot

View File

@@ -415,15 +415,35 @@ pub(crate) struct ChatCompletionChoice {
 pub(crate) struct ChatCompletionDelta {
     #[schema(example = "user")]
     pub role: String,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
     #[schema(example = "What is Deep Learning?")]
-    pub content: String,
+    pub content: Option<String>,
+    // default to None
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub tool_calls: Option<DeltaToolCall>,
 }
 
+#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)]
+pub(crate) struct DeltaToolCall {
+    pub index: u32,
+    pub id: String,
+    pub r#type: String,
+    pub function: Function,
+}
+
+#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)]
+pub(crate) struct Function {
+    pub name: Option<String>,
+    pub arguments: String,
+}
+
+#[allow(clippy::too_many_arguments)]
 impl ChatCompletionChunk {
     pub(crate) fn new(
         model: String,
         system_fingerprint: String,
-        delta: String,
+        delta: Option<String>,
+        tool_calls: Option<Vec<String>>,
         created: u64,
         index: u32,
         logprobs: Option<ChatCompletionLogprobs>,
@@ -440,6 +460,15 @@ impl ChatCompletionChunk {
                 delta: ChatCompletionDelta {
                     role: "assistant".to_string(),
                     content: delta,
+                    tool_calls: tool_calls.map(|tc| DeltaToolCall {
+                        index,
+                        id: String::new(),
+                        r#type: "function".to_string(),
+                        function: Function {
+                            name: None,
+                            arguments: tc[0].to_string(),
+                        },
+                    }),
                 },
                 logprobs,
                 finish_reason,
@@ -626,8 +655,8 @@ where
     state.end()
 }
 
-#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
-pub(crate) struct Function {
+#[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default)]
+pub(crate) struct FunctionDefinition {
     #[serde(default)]
     pub description: Option<String>,
     pub name: String,
@@ -640,7 +669,7 @@ pub(crate) struct Tool {
     #[schema(example = "function")]
     pub r#type: String,
     // Grab the tool as generic JSON for debugging purposes.
-    pub function: Function,
+    pub function: FunctionDefinition,
 }
 
 #[derive(Clone, Serialize, Deserialize)]
@@ -651,11 +680,11 @@ pub(crate) struct ChatTemplateInputs<'a> {
     add_generation_prompt: bool,
 }
 
-#[derive(Clone, Deserialize, Serialize, ToSchema)]
+#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)]
 pub(crate) struct ToolCall {
     pub id: u32,
     pub r#type: String,
-    pub function: Function,
+    pub function: FunctionDefinition,
 }
 
 #[derive(Clone, Deserialize, ToSchema, Serialize)]

View File

@@ -10,7 +10,7 @@ use crate::{
     HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails,
     StreamResponse, Token, TokenizeResponse, Usage, Validation, VertexRequest, VertexResponse,
 };
-use crate::{Function, FunctionRef, FunctionsMap, Properties, ToolCall, ToolType, Tools};
+use crate::{FunctionDefinition, FunctionRef, FunctionsMap, Properties, ToolCall, ToolType, Tools};
 use axum::extract::Extension;
 use axum::http::{HeaderMap, Method, StatusCode};
 use axum::response::sse::{Event, KeepAlive, Sse};
@@ -699,11 +699,19 @@ async fn chat_completions(
                     ChatCompletionLogprobs::from((stream_token.token.clone(), stream_token.top_tokens))
                 });
 
+                // replace the content with the tool calls if grammar is present
+                let (content, tool_calls) = if tool_grammar.is_some() {
+                    (None, Some(vec![stream_token.token.text]))
+                } else {
+                    (Some(stream_token.token.text), None)
+                };
+
                 event
                     .json_data(ChatCompletionChunk::new(
                         model_id.clone(),
                         system_fingerprint.clone(),
-                        stream_token.token.text,
+                        content,
+                        tool_calls,
                         current_time,
                         stream_token.index,
                         logprobs,
@@ -756,7 +764,7 @@ async fn chat_completions(
                 let tool_call = Some(ToolCall {
                     id: 0,
                     r#type: "function".to_string(),
-                    function: Function {
+                    function: FunctionDefinition {
                         description: None,
                         name: "tools".to_string(),
                         parameters: gen_text_value.get("function").map_or_else(