Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 12:24:53 +00:00)
feat: support streaming and improve docs
This commit is contained in:
parent 7c04b6d664
commit 0fc7237380
@@ -3,7 +3,7 @@ import requests
 
 from aiohttp import ClientSession, ClientTimeout
 from pydantic import ValidationError
-from typing import Dict, Optional, List, AsyncIterator, Iterator
+from typing import Dict, Optional, List, AsyncIterator, Iterator, Union
 
 from text_generation.types import (
     StreamResponse,
@@ -12,6 +12,7 @@ from text_generation.types import (
     Parameters,
     Grammar,
     ChatRequest,
+    ChatCompletionChunk,
     ChatComplete,
     Message,
     Tool,
@@ -134,7 +135,7 @@ class Client:
             tools=tools,
             tool_choice=tool_choice,
         )
-        resp = requests.post(
-            f"{self.base_url}/v1/chat/completions",
-            json=request.dict(),
+        if not stream:
+            resp = requests.post(
+                f"{self.base_url}/v1/chat/completions",
+                json=request.dict(),
@@ -146,6 +147,30 @@ class Client:
             if resp.status_code != 200:
                 raise parse_error(resp.status_code, payload)
             return ChatComplete(**payload)
+        else:
+            return self._chat_stream_response(request)
+
+    def _chat_stream_response(self, request):
+        resp = requests.post(
+            f"{self.base_url}/v1/chat/completions",
+            json=request.dict(),
+            headers=self.headers,
+            cookies=self.cookies,
+            timeout=self.timeout,
+            stream=True,
+        )
+        # iterate over the SSE lines and yield each parsed chunk
+        for byte_payload in resp.iter_lines():
+            if byte_payload == b"\n":
+                continue
+            payload = byte_payload.decode("utf-8")
+            if payload.startswith("data:"):
+                json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
+                try:
+                    response = ChatCompletionChunk(**json_payload)
+                    yield response
+                except ValidationError:
+                    raise parse_error(resp.status_code, json_payload)
+
     def generate(
         self,
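For orientation alongside the client changes above, a minimal sketch of consuming the new synchronous streaming path; the base URL and prompt are illustrative assumptions, not part of this diff:

```python
# Sketch: synchronous streaming chat with the updated client.
# Base URL and prompt are assumptions for illustration.
from text_generation import Client

client = Client("http://localhost:8080")
chunks = client.chat(
    messages=[{"role": "user", "content": "What is deep learning?"}],
    max_tokens=64,
    stream=True,  # now returns an Iterator[ChatCompletionChunk]
)
for chunk in chunks:
    delta = chunk.choices[0].delta
    if delta.content is not None:
        print(delta.content, end="", flush=True)
```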
@@ -417,7 +442,7 @@ class AsyncClient:
         top_p: Optional[float] = None,
         tools: Optional[List[Tool]] = None,
         tool_choice: Optional[str] = None,
-    ):
+    ) -> Union[ChatComplete, AsyncIterator[ChatCompletionChunk]]:
         """
         Given a list of messages, generate a response asynchronously
 
@@ -472,6 +497,12 @@ class AsyncClient:
             tools=tools,
             tool_choice=tool_choice,
         )
+        if not stream:
+            return await self._chat_single_response(request)
+        else:
+            return self._chat_stream_response(request)
+
+    async def _chat_single_response(self, request):
         async with ClientSession(
             headers=self.headers, cookies=self.cookies, timeout=self.timeout
         ) as session:
@@ -483,6 +514,25 @@ class AsyncClient:
                 raise parse_error(resp.status, payload)
             return ChatComplete(**payload)
 
+    async def _chat_stream_response(self, request):
+        async with ClientSession(
+            headers=self.headers, cookies=self.cookies, timeout=self.timeout
+        ) as session:
+            async with session.post(
+                f"{self.base_url}/v1/chat/completions", json=request.dict()
+            ) as resp:
+                async for byte_payload in resp.content:
+                    if byte_payload == b"\n":
+                        continue
+                    payload = byte_payload.decode("utf-8")
+                    if payload.startswith("data:"):
+                        json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
+                        try:
+                            response = ChatCompletionChunk(**json_payload)
+                            yield response
+                        except ValidationError:
+                            raise parse_error(resp.status, json_payload)
+
     async def generate(
         self,
         prompt: str,
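The async path mirrors the sync one; a minimal usage sketch (base URL and prompt again assumed), matching how the integration test later in this diff calls it:

```python
# Sketch: asynchronous streaming chat, mirroring the sync example.
import asyncio

from text_generation import AsyncClient


async def main():
    client = AsyncClient("http://localhost:8080")
    responses = await client.chat(
        messages=[{"role": "user", "content": "What is deep learning?"}],
        max_tokens=64,
        stream=True,  # now returns an AsyncIterator[ChatCompletionChunk]
    )
    async for chunk in responses:
        delta = chunk.choices[0].delta
        if delta.content is not None:
            print(delta.content, end="", flush=True)


asyncio.run(main())
```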
@@ -59,6 +59,40 @@ class ChatCompletionComplete(BaseModel):
     usage: Any
 
 
+class Function(BaseModel):
+    name: Optional[str]
+    arguments: str
+
+
+class ChoiceDeltaToolCall(BaseModel):
+    index: int
+    id: str
+    type: str
+    function: Function
+
+
+class ChoiceDelta(BaseModel):
+    role: str
+    content: Optional[str]
+    tool_calls: Optional[ChoiceDeltaToolCall]
+
+
+class Choice(BaseModel):
+    index: int
+    delta: ChoiceDelta
+    logprobs: Optional[dict] = None
+    finish_reason: Optional[str] = None
+
+
+class ChatCompletionChunk(BaseModel):
+    id: str
+    object: str
+    created: int
+    model: str
+    system_fingerprint: str
+    choices: List[Choice]
+
+
 class ChatComplete(BaseModel):
     # Chat completion details
     id: str
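To make the wire format concrete, a small sketch of validating one SSE `data:` line against these models, the same way the client code above does; the payload is an invented example shaped like TGI's streaming chunks:

```python
# Sketch: parsing a single SSE "data:" line into the new pydantic models.
import json

from text_generation.types import ChatCompletionChunk

line = (
    'data: {"id": "", "object": "text_completion", "created": 1709087088, '
    '"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", '
    '"system_fingerprint": "1.4.2-native", "choices": [{"index": 0, '
    '"delta": {"role": "assistant", "content": "Deep", "tool_calls": null}, '
    '"logprobs": null, "finish_reason": null}]}'
)
payload = json.loads(line.lstrip("data:").rstrip("\n"))
chunk = ChatCompletionChunk(**payload)
print(chunk.choices[0].delta.content)  # -> Deep
```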
@@ -1,8 +1,8 @@
 # Guidance
 
-Text Generation Inference (TGI) now supports the Messages API, which is fully compatible with the OpenAI Chat Completion API. This feature is available starting from version `1.4.3`. You can use OpenAI's client libraries or third-party libraries expecting OpenAI schema to interact with TGI's Messages API. Below are some examples of how to utilize this compatibility.
+Text Generation Inference (TGI) now supports [Grammar and Constraints](#grammar-and-constraints) and [Tools and Functions](#tools-and-functions) to help developers guide the LLM's responses and enhance its capabilities.
 
-Whether you're a developer, a data scientist, or just a curious mind, we've made it super easy (and fun!) to start integrating advanced text generation capabilities into your applications.
+These features are available starting from version `1.4.3`. They are accessible via the `text-generation` client library and are compatible with OpenAI's client libraries. The following guide will walk you through the new features and how to use them!
 
 ### Quick Start
 
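Since the rewritten intro still leans on OpenAI-client compatibility, a minimal sketch of pointing OpenAI's Python client at a TGI endpoint; the base_url, api_key placeholder, and model name are assumptions for illustration:

```python
# Sketch: the Messages API is OpenAI-compatible, so OpenAI's client can
# target a TGI endpoint directly.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8080/v1", api_key="-")

stream = client.chat.completions.create(
    model="tgi",  # assumption: TGI serves one model, so the name is nominal
    messages=[{"role": "user", "content": "What is deep learning?"}],
    stream=True,
)
for chunk in stream:
    print(chunk.choices[0].delta.content or "", end="")
```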
@@ -214,7 +214,6 @@ if __name__ == "__main__":
 
 ```
 
-
 ## Tools and Functions 🛠️
 
 ### The Tools Parameter
@@ -14,8 +14,8 @@
       "name": "tools",
       "parameters": {
         "format": "celsius",
-        "location": "San Francisco",
-        "num_days": 2
+        "location": "New York, NY",
+        "num_days": 14
       }
     },
     "id": 0,
@@ -25,14 +25,14 @@
       "usage": null
     }
   ],
-  "created": 1708957016,
+  "created": 1709079417,
   "id": "",
   "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
  "object": "text_completion",
   "system_fingerprint": "1.4.2-native",
   "usage": {
-    "completion_tokens": 36,
-    "prompt_tokens": 313,
-    "total_tokens": 349
+    "completion_tokens": 29,
+    "prompt_tokens": 316,
+    "total_tokens": 345
   }
 }
@@ -14,8 +14,8 @@
       "name": "tools",
       "parameters": {
         "format": "celsius",
-        "location": "San Francisco",
-        "num_days": 2
+        "location": "New York, NY",
+        "num_days": 14
       }
     },
     "id": 0,
@@ -25,14 +25,14 @@
       "usage": null
     }
   ],
-  "created": 1708957016,
+  "created": 1709079492,
   "id": "",
   "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
   "object": "text_completion",
   "system_fingerprint": "1.4.2-native",
   "usage": {
-    "completion_tokens": 36,
-    "prompt_tokens": 313,
-    "total_tokens": 349
+    "completion_tokens": 29,
+    "prompt_tokens": 316,
+    "total_tokens": 345
   }
 }
@@ -24,14 +24,14 @@
       "usage": null
     }
   ],
-  "created": 1708957017,
+  "created": 1709079493,
   "id": "",
   "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
   "object": "text_completion",
   "system_fingerprint": "1.4.2-native",
   "usage": {
     "completion_tokens": 21,
-    "prompt_tokens": 184,
-    "total_tokens": 205
+    "prompt_tokens": 187,
+    "total_tokens": 208
   }
 }
@@ -0,0 +1,27 @@
+{
+  "choices": [
+    {
+      "delta": {
+        "content": null,
+        "role": "assistant",
+        "tool_calls": {
+          "function": {
+            "arguments": "</s>",
+            "name": null
+          },
+          "id": "",
+          "index": 20,
+          "type": "function"
+        }
+      },
+      "finish_reason": "eos_token",
+      "index": 20,
+      "logprobs": null
+    }
+  ],
+  "created": 1709087088,
+  "id": "",
+  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+  "object": "text_completion",
+  "system_fingerprint": "1.4.2-native"
+}
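The new snapshot above is only the final chunk of a streamed tool call; the argument string arrives one fragment per chunk. A minimal sketch of reassembling it on the client side, using the pydantic models added earlier in this diff (`chunks` is assumed to come from `client.chat(..., stream=True)`):

```python
# Sketch: reassembling a streamed tool call from chunk deltas.
def collect_tool_arguments(chunks):
    """Concatenate the argument fragments spread across streamed deltas."""
    fragments = []
    for chunk in chunks:
        delta = chunk.choices[0].delta
        if delta.tool_calls is not None:
            fragments.append(delta.tool_calls.function.arguments)
    return "".join(fragments)
```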
@@ -124,8 +124,8 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot
                 "name": "tools",
                 "parameters": {
                     "format": "celsius",
-                    "location": "San Francisco",
-                    "num_days": 2,
+                    "location": "New York, NY",
+                    "num_days": 14,
                 },
             },
             "id": 0,
@@ -163,8 +163,8 @@ async def test_flash_llama_grammar_tools_auto(
                 "name": "tools",
                 "parameters": {
                     "format": "celsius",
-                    "location": "San Francisco",
-                    "num_days": 2,
+                    "location": "New York, NY",
+                    "num_days": 14,
                 },
             },
             "id": 0,
@@ -206,3 +206,36 @@ async def test_flash_llama_grammar_tools_choice(
         },
     }
     assert response == response_snapshot
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_llama_grammar_tools_stream(
+    flash_llama_grammar_tools, response_snapshot
+):
+    responses = await flash_llama_grammar_tools.chat(
+        max_tokens=100,
+        seed=1,
+        tools=tools,
+        tool_choice="get_current_weather",
+        presence_penalty=-1.1,
+        messages=[
+            {
+                "role": "system",
+                "content": "Youre a helpful assistant! Answer the users question best you can.",
+            },
+            {
+                "role": "user",
+                "content": "What is the weather like in Paris, France?",
+            },
+        ],
+        stream=True,
+    )
+
+    count = 0
+    async for response in responses:
+        print(response)
+        count += 1
+
+    assert count == 20
+    assert response == response_snapshot
@@ -415,15 +415,35 @@ pub(crate) struct ChatCompletionChoice {
 pub(crate) struct ChatCompletionDelta {
     #[schema(example = "user")]
     pub role: String,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
     #[schema(example = "What is Deep Learning?")]
-    pub content: String,
+    pub content: Option<String>,
+    // default to None
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub tool_calls: Option<DeltaToolCall>,
 }
 
+#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)]
+pub(crate) struct DeltaToolCall {
+    pub index: u32,
+    pub id: String,
+    pub r#type: String,
+    pub function: Function,
+}
+
+#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)]
+pub(crate) struct Function {
+    pub name: Option<String>,
+    pub arguments: String,
+}
+
+#[allow(clippy::too_many_arguments)]
 impl ChatCompletionChunk {
     pub(crate) fn new(
         model: String,
         system_fingerprint: String,
-        delta: String,
+        delta: Option<String>,
+        tool_calls: Option<Vec<String>>,
         created: u64,
         index: u32,
         logprobs: Option<ChatCompletionLogprobs>,
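To see what the `skip_serializing_if = "Option::is_none"` attributes imply on the wire, a small sketch of the two delta shapes a client can receive; both payloads are invented for illustration:

```python
# Sketch: a streamed delta carries either assistant text or a tool-call
# fragment; whichever side is None is omitted from the serialized JSON.
text_delta = {"role": "assistant", "content": "Deep"}
tool_delta = {
    "role": "assistant",
    "tool_calls": {
        "index": 0,
        "id": "",
        "type": "function",
        "function": {"name": None, "arguments": "{\"location\": "},
    },
}


def delta_fragment(delta: dict) -> str:
    # prefer the tool-call argument fragment when present
    if "tool_calls" in delta:
        return delta["tool_calls"]["function"]["arguments"]
    return delta.get("content") or ""
```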
@@ -440,6 +460,15 @@ impl ChatCompletionChunk {
             delta: ChatCompletionDelta {
                 role: "assistant".to_string(),
                 content: delta,
+                tool_calls: tool_calls.map(|tc| DeltaToolCall {
+                    index,
+                    id: String::new(),
+                    r#type: "function".to_string(),
+                    function: Function {
+                        name: None,
+                        arguments: tc[0].to_string(),
+                    },
+                }),
             },
             logprobs,
             finish_reason,
@@ -626,8 +655,8 @@ where
     state.end()
 }
 
-#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
-pub(crate) struct Function {
+#[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default)]
+pub(crate) struct FunctionDefinition {
     #[serde(default)]
     pub description: Option<String>,
     pub name: String,
@@ -640,7 +669,7 @@ pub(crate) struct Tool {
     #[schema(example = "function")]
     pub r#type: String,
     // Grab the tool as generic JSON for debugging purposes.
-    pub function: Function,
+    pub function: FunctionDefinition,
 }
 
 #[derive(Clone, Serialize, Deserialize)]
@@ -651,11 +680,11 @@ pub(crate) struct ChatTemplateInputs<'a> {
     add_generation_prompt: bool,
 }
 
-#[derive(Clone, Deserialize, Serialize, ToSchema)]
+#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)]
 pub(crate) struct ToolCall {
     pub id: u32,
     pub r#type: String,
-    pub function: Function,
+    pub function: FunctionDefinition,
 }
 
 #[derive(Clone, Deserialize, ToSchema, Serialize)]
@@ -10,7 +10,7 @@ use crate::{
     HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails,
     StreamResponse, Token, TokenizeResponse, Usage, Validation, VertexRequest, VertexResponse,
 };
-use crate::{Function, FunctionRef, FunctionsMap, Properties, ToolCall, ToolType, Tools};
+use crate::{FunctionDefinition, FunctionRef, FunctionsMap, Properties, ToolCall, ToolType, Tools};
 use axum::extract::Extension;
 use axum::http::{HeaderMap, Method, StatusCode};
 use axum::response::sse::{Event, KeepAlive, Sse};
@@ -699,11 +699,19 @@ async fn chat_completions(
                 ChatCompletionLogprobs::from((stream_token.token.clone(), stream_token.top_tokens))
             });
 
+            // replace the content with the tool calls if grammar is present
+            let (content, tool_calls) = if tool_grammar.is_some() {
+                (None, Some(vec![stream_token.token.text]))
+            } else {
+                (Some(stream_token.token.text), None)
+            };
+
             event
                 .json_data(ChatCompletionChunk::new(
                     model_id.clone(),
                     system_fingerprint.clone(),
-                    stream_token.token.text,
+                    content,
+                    tool_calls,
                     current_time,
                     stream_token.index,
                     logprobs,
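The routing in this hunk is small but central: with an active tool grammar, every generated token is surfaced as a tool-call argument fragment rather than assistant content. A Python paraphrase of the same decision, for illustration only:

```python
# Sketch (paraphrase of the Rust routing above): pick where a generated
# token's text goes in the outgoing chunk.
def route_token(token_text: str, tool_grammar_active: bool):
    """Return the (content, tool_calls) pair fed into the chunk constructor."""
    if tool_grammar_active:
        return None, [token_text]
    return token_text, None
```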
@@ -756,7 +764,7 @@ async fn chat_completions(
     let tool_call = Some(ToolCall {
         id: 0,
         r#type: "function".to_string(),
-        function: Function {
+        function: FunctionDefinition {
             description: None,
             name: "tools".to_string(),
             parameters: gen_text_value.get("function").map_or_else(