mirror of https://github.com/huggingface/text-generation-inference.git
feat: support streaming and improve docs
This commit is contained in:
parent 7c04b6d664
commit 0fc7237380
@@ -3,7 +3,7 @@ import requests

 from aiohttp import ClientSession, ClientTimeout
 from pydantic import ValidationError
-from typing import Dict, Optional, List, AsyncIterator, Iterator
+from typing import Dict, Optional, List, AsyncIterator, Iterator, Union

 from text_generation.types import (
     StreamResponse,
@@ -12,6 +12,7 @@ from text_generation.types import (
     Parameters,
     Grammar,
     ChatRequest,
+    ChatCompletionChunk,
     ChatComplete,
     Message,
     Tool,
@@ -134,18 +135,42 @@ class Client:
             tools=tools,
             tool_choice=tool_choice,
         )
+        if not stream:
+            resp = requests.post(
+                f"{self.base_url}/v1/chat/completions",
+                json=request.dict(),
+                headers=self.headers,
+                cookies=self.cookies,
+                timeout=self.timeout,
+            )
+            payload = resp.json()
+            if resp.status_code != 200:
+                raise parse_error(resp.status_code, payload)
+            return ChatComplete(**payload)
+        else:
+            return self._chat_stream_response(request)
+
+    def _chat_stream_response(self, request):
         resp = requests.post(
             f"{self.base_url}/v1/chat/completions",
             json=request.dict(),
             headers=self.headers,
             cookies=self.cookies,
             timeout=self.timeout,
+            stream=True,
         )
-        payload = resp.json()
-        if resp.status_code != 200:
-            raise parse_error(resp.status_code, payload)
-        return ChatComplete(**payload)
+        # iterate over the SSE stream and yield parsed chunks
+        for byte_payload in resp.iter_lines():
+            if byte_payload == b"\n":
+                continue
+            payload = byte_payload.decode("utf-8")
+            if payload.startswith("data:"):
+                json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
+                try:
+                    response = ChatCompletionChunk(**json_payload)
+                    yield response
+                except ValidationError:
+                    raise parse_error(resp.status_code, json_payload)

     def generate(
         self,
@@ -417,7 +442,7 @@ class AsyncClient:
         top_p: Optional[float] = None,
         tools: Optional[List[Tool]] = None,
         tool_choice: Optional[str] = None,
-    ):
+    ) -> Union[ChatComplete, AsyncIterator[ChatCompletionChunk]]:
         """
         Given a list of messages, generate a response asynchronously

@@ -472,6 +497,12 @@ class AsyncClient:
             tools=tools,
             tool_choice=tool_choice,
         )
+        if not stream:
+            return await self._chat_single_response(request)
+        else:
+            return self._chat_stream_response(request)
+
+    async def _chat_single_response(self, request):
         async with ClientSession(
             headers=self.headers, cookies=self.cookies, timeout=self.timeout
         ) as session:
@@ -483,6 +514,25 @@ class AsyncClient:
                 raise parse_error(resp.status, payload)
             return ChatComplete(**payload)

+    async def _chat_stream_response(self, request):
+        async with ClientSession(
+            headers=self.headers, cookies=self.cookies, timeout=self.timeout
+        ) as session:
+            async with session.post(
+                f"{self.base_url}/v1/chat/completions", json=request.dict()
+            ) as resp:
+                async for byte_payload in resp.content:
+                    if byte_payload == b"\n":
+                        continue
+                    payload = byte_payload.decode("utf-8")
+                    if payload.startswith("data:"):
+                        json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
+                        try:
+                            response = ChatCompletionChunk(**json_payload)
+                            yield response
+                        except ValidationError:
+                            raise parse_error(resp.status, json_payload)

     async def generate(
         self,
         prompt: str,
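For orientation, here is a minimal sketch of how the new streaming path in `Client.chat` might be driven. It assumes a TGI server (>= 1.4.3) listening on localhost:8080; the URL and prompt are placeholders, not part of the commit.

```python
from text_generation import Client

client = Client("http://localhost:8080")  # assumed local TGI endpoint

# stream=True now returns an iterator of ChatCompletionChunk
for chunk in client.chat(
    messages=[{"role": "user", "content": "What is deep learning?"}],
    max_tokens=64,
    stream=True,
):
    delta = chunk.choices[0].delta
    if delta.content:
        print(delta.content, end="", flush=True)
```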
@@ -59,6 +59,40 @@ class ChatCompletionComplete(BaseModel):
     usage: Any


+class Function(BaseModel):
+    name: Optional[str]
+    arguments: str
+
+
+class ChoiceDeltaToolCall(BaseModel):
+    index: int
+    id: str
+    type: str
+    function: Function
+
+
+class ChoiceDelta(BaseModel):
+    role: str
+    content: Optional[str]
+    tool_calls: Optional[ChoiceDeltaToolCall]
+
+
+class Choice(BaseModel):
+    index: int
+    delta: ChoiceDelta
+    logprobs: Optional[dict] = None
+    finish_reason: Optional[str] = None
+
+
+class ChatCompletionChunk(BaseModel):
+    id: str
+    object: str
+    created: int
+    model: str
+    system_fingerprint: str
+    choices: List[Choice]
+
+
 class ChatComplete(BaseModel):
     # Chat completion details
     id: str
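As a rough illustration of how these Pydantic models fit the wire format, the sketch below parses one hand-written SSE `data:` line into a `ChatCompletionChunk`, the same way the client code above does; the payload values are assumptions shaped after the snapshots further down, not real server output.

```python
import json

from text_generation.types import ChatCompletionChunk

# a hypothetical SSE line in the shape the router streams back
line = (
    'data: {"id": "", "object": "text_completion", "created": 1709087088, '
    '"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "system_fingerprint": "1.4.2-native", '
    '"choices": [{"index": 0, "delta": {"role": "assistant", "content": "Deep", '
    '"tool_calls": null}, "logprobs": null, "finish_reason": null}]}'
)

chunk = ChatCompletionChunk(**json.loads(line.lstrip("data:").strip()))
print(chunk.choices[0].delta.content)  # -> Deep
```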
@@ -1,8 +1,8 @@
 # Guidance

-Text Generation Inference (TGI) now supports the Messages API, which is fully compatible with the OpenAI Chat Completion API. This feature is available starting from version `1.4.3`. You can use OpenAI's client libraries or third-party libraries expecting OpenAI schema to interact with TGI's Messages API. Below are some examples of how to utilize this compatibility.
+Text Generation Inference (TGI) now supports [Grammar and Constraints](#grammar-and-constraints) and [Tools and Functions](#tools-and-functions) to help developers guide the LLM's responses and enhance its capabilities.

-Whether you're a developer, a data scientist, or just a curious mind, we've made it super easy (and fun!) to start integrating advanced text generation capabilities into your applications.
+These features are available starting from version `1.4.3`. They are accessible via the text-generation-client library and are compatible with OpenAI's client libraries. The following guide will walk you through the new features and how to use them!

 ### Quick Start

@ -214,7 +214,6 @@ if __name__ == "__main__":
|
||||
|
||||
```
|
||||
|
||||
|
||||
## Tools and Functions 🛠️
|
||||
|
||||
### The Tools Parameter
|
||||
|
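To make the guide's tools workflow concrete, here is a compact sketch of a non-streaming tool call through the Python client; the weather-tool schema and server URL are illustrative assumptions written in the OpenAI function-calling format the guide describes.

```python
from text_generation import Client

client = Client("http://localhost:8080")  # assumed local TGI endpoint

# a hypothetical tool definition in the OpenAI function-calling format
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather for a location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {"type": "string", "description": "City and state"},
                    "format": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                },
                "required": ["location", "format"],
            },
        },
    }
]

response = client.chat(
    messages=[{"role": "user", "content": "What is the weather like in New York?"}],
    tools=tools,
    tool_choice="get_current_weather",
)
print(response)  # a ChatComplete whose message carries the tool call
```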
@@ -14,8 +14,8 @@
         "name": "tools",
         "parameters": {
           "format": "celsius",
-          "location": "San Francisco",
-          "num_days": 2
+          "location": "New York, NY",
+          "num_days": 14
         }
       },
       "id": 0,
@@ -25,14 +25,14 @@
       "usage": null
     }
   ],
-  "created": 1708957016,
+  "created": 1709079417,
   "id": "",
   "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
   "object": "text_completion",
   "system_fingerprint": "1.4.2-native",
   "usage": {
-    "completion_tokens": 36,
-    "prompt_tokens": 313,
-    "total_tokens": 349
+    "completion_tokens": 29,
+    "prompt_tokens": 316,
+    "total_tokens": 345
   }
 }
@@ -14,8 +14,8 @@
         "name": "tools",
         "parameters": {
           "format": "celsius",
-          "location": "San Francisco",
-          "num_days": 2
+          "location": "New York, NY",
+          "num_days": 14
         }
       },
       "id": 0,
@@ -25,14 +25,14 @@
       "usage": null
     }
   ],
-  "created": 1708957016,
+  "created": 1709079492,
   "id": "",
   "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
   "object": "text_completion",
   "system_fingerprint": "1.4.2-native",
   "usage": {
-    "completion_tokens": 36,
-    "prompt_tokens": 313,
-    "total_tokens": 349
+    "completion_tokens": 29,
+    "prompt_tokens": 316,
+    "total_tokens": 345
   }
 }
@@ -24,14 +24,14 @@
       "usage": null
     }
   ],
-  "created": 1708957017,
+  "created": 1709079493,
   "id": "",
   "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
   "object": "text_completion",
   "system_fingerprint": "1.4.2-native",
   "usage": {
     "completion_tokens": 21,
-    "prompt_tokens": 184,
-    "total_tokens": 205
+    "prompt_tokens": 187,
+    "total_tokens": 208
   }
 }
@@ -0,0 +1,27 @@
+{
+  "choices": [
+    {
+      "delta": {
+        "content": null,
+        "role": "assistant",
+        "tool_calls": {
+          "function": {
+            "arguments": "</s>",
+            "name": null
+          },
+          "id": "",
+          "index": 20,
+          "type": "function"
+        }
+      },
+      "finish_reason": "eos_token",
+      "index": 20,
+      "logprobs": null
+    }
+  ],
+  "created": 1709087088,
+  "id": "",
+  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+  "object": "text_completion",
+  "system_fingerprint": "1.4.2-native"
+}
@@ -124,8 +124,8 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna
             "name": "tools",
             "parameters": {
                 "format": "celsius",
-                "location": "San Francisco",
-                "num_days": 2,
+                "location": "New York, NY",
+                "num_days": 14,
             },
         },
         "id": 0,
@@ -163,8 +163,8 @@ async def test_flash_llama_grammar_tools_auto(
             "name": "tools",
             "parameters": {
                 "format": "celsius",
-                "location": "San Francisco",
-                "num_days": 2,
+                "location": "New York, NY",
+                "num_days": 14,
             },
         },
         "id": 0,
@@ -206,3 +206,36 @@ async def test_flash_llama_grammar_tools_choice(
         },
     }
     assert response == response_snapshot
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_llama_grammar_tools_stream(
+    flash_llama_grammar_tools, response_snapshot
+):
+    responses = await flash_llama_grammar_tools.chat(
+        max_tokens=100,
+        seed=1,
+        tools=tools,
+        tool_choice="get_current_weather",
+        presence_penalty=-1.1,
+        messages=[
+            {
+                "role": "system",
+                "content": "Youre a helpful assistant! Answer the users question best you can.",
+            },
+            {
+                "role": "user",
+                "content": "What is the weather like in Paris, France?",
+            },
+        ],
+        stream=True,
+    )
+
+    count = 0
+    async for response in responses:
+        print(response)
+        count += 1
+
+    assert count == 20
+    assert response == response_snapshot
@@ -415,15 +415,35 @@ pub(crate) struct ChatCompletionChoice {
 pub(crate) struct ChatCompletionDelta {
     #[schema(example = "user")]
     pub role: String,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
     #[schema(example = "What is Deep Learning?")]
-    pub content: String,
+    pub content: Option<String>,
+    // default to None
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub tool_calls: Option<DeltaToolCall>,
 }

+#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)]
+pub(crate) struct DeltaToolCall {
+    pub index: u32,
+    pub id: String,
+    pub r#type: String,
+    pub function: Function,
+}
+
+#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)]
+pub(crate) struct Function {
+    pub name: Option<String>,
+    pub arguments: String,
+}
+
+#[allow(clippy::too_many_arguments)]
 impl ChatCompletionChunk {
     pub(crate) fn new(
         model: String,
         system_fingerprint: String,
-        delta: String,
+        delta: Option<String>,
+        tool_calls: Option<Vec<String>>,
         created: u64,
         index: u32,
         logprobs: Option<ChatCompletionLogprobs>,
@@ -440,6 +460,15 @@ impl ChatCompletionChunk {
             delta: ChatCompletionDelta {
                 role: "assistant".to_string(),
                 content: delta,
+                tool_calls: tool_calls.map(|tc| DeltaToolCall {
+                    index,
+                    id: String::new(),
+                    r#type: "function".to_string(),
+                    function: Function {
+                        name: None,
+                        arguments: tc[0].to_string(),
+                    },
+                }),
             },
             logprobs,
             finish_reason,
@@ -626,8 +655,8 @@ where
     state.end()
 }

-#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
-pub(crate) struct Function {
+#[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default)]
+pub(crate) struct FunctionDefinition {
     #[serde(default)]
     pub description: Option<String>,
     pub name: String,
@@ -640,7 +669,7 @@ pub(crate) struct Tool {
     #[schema(example = "function")]
     pub r#type: String,
     // Grab the tool as generic JSON for debugging purposes.
-    pub function: Function,
+    pub function: FunctionDefinition,
 }

 #[derive(Clone, Serialize, Deserialize)]
@@ -651,11 +680,11 @@ pub(crate) struct ChatTemplateInputs<'a> {
     add_generation_prompt: bool,
 }

-#[derive(Clone, Deserialize, Serialize, ToSchema)]
+#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)]
 pub(crate) struct ToolCall {
     pub id: u32,
     pub r#type: String,
-    pub function: Function,
+    pub function: FunctionDefinition,
 }

 #[derive(Clone, Deserialize, ToSchema, Serialize)]
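Because `DeltaToolCall` streams the function arguments fragment by fragment (`name` stays `None` while the text accumulates in `arguments`), a consumer has to stitch the fragments back together. A rough Python sketch under that assumption, using the chunk shape defined above:

```python
def collect_tool_arguments(chunks):
    """Concatenate streamed tool-call argument fragments into one string.

    `chunks` is assumed to be an iterator of ChatCompletionChunk objects,
    e.g. the return value of Client.chat(..., stream=True).
    """
    fragments = []
    for chunk in chunks:
        choice = chunk.choices[0]
        if choice.delta.tool_calls is not None:
            fragments.append(choice.delta.tool_calls.function.arguments)
        if choice.finish_reason is not None:  # e.g. "eos_token" in the snapshot above
            break
    return "".join(fragments)
```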
@@ -10,7 +10,7 @@ use crate::{
     HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails,
     StreamResponse, Token, TokenizeResponse, Usage, Validation, VertexRequest, VertexResponse,
 };
-use crate::{Function, FunctionRef, FunctionsMap, Properties, ToolCall, ToolType, Tools};
+use crate::{FunctionDefinition, FunctionRef, FunctionsMap, Properties, ToolCall, ToolType, Tools};
 use axum::extract::Extension;
 use axum::http::{HeaderMap, Method, StatusCode};
 use axum::response::sse::{Event, KeepAlive, Sse};
@@ -699,11 +699,19 @@ async fn chat_completions(
                         ChatCompletionLogprobs::from((stream_token.token.clone(), stream_token.top_tokens))
                     });

+                    // replace the content with the tool calls if grammar is present
+                    let (content, tool_calls) = if tool_grammar.is_some() {
+                        (None, Some(vec![stream_token.token.text]))
+                    } else {
+                        (Some(stream_token.token.text), None)
+                    };
+
                     event
                         .json_data(ChatCompletionChunk::new(
                             model_id.clone(),
                             system_fingerprint.clone(),
-                            stream_token.token.text,
+                            content,
+                            tool_calls,
                             current_time,
                             stream_token.index,
                             logprobs,
@@ -756,7 +764,7 @@ async fn chat_completions(
         let tool_call = Some(ToolCall {
             id: 0,
             r#type: "function".to_string(),
-            function: Function {
+            function: FunctionDefinition {
                 description: None,
                 name: "tools".to_string(),
                 parameters: gen_text_value.get("function").map_or_else(
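Since the `/v1/chat/completions` route keeps the OpenAI chunk schema, the stream can also be consumed with OpenAI's own client. A sketch, assuming `openai >= 1.0` and a local TGI server; the model name is a placeholder since TGI serves a single model:

```python
from openai import OpenAI

# TGI does not check the API key, so any placeholder value works
client = OpenAI(base_url="http://localhost:8080/v1", api_key="-")

stream = client.chat.completions.create(
    model="tgi",  # placeholder; TGI routes every request to its one loaded model
    messages=[{"role": "user", "content": "What is deep learning?"}],
    stream=True,
)

for chunk in stream:
    delta = chunk.choices[0].delta
    if delta.content:
        print(delta.content, end="", flush=True)
```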