From 0fc7237380831eb2fcd6c622b9e210708930386c Mon Sep 17 00:00:00 2001
From: drbh
Date: Wed, 28 Feb 2024 02:32:02 +0000
Subject: [PATCH] feat: support streaming and improve docs

---
 clients/python/text_generation/client.py | 62 +++++++++++++++++--
 clients/python/text_generation/types.py | 34 ++++++++++
 docs/source/guidance.md | 5 +-
 .../test_flash_llama_grammar_tools.json | 12 ++--
 .../test_flash_llama_grammar_tools_auto.json | 12 ++--
 ...test_flash_llama_grammar_tools_choice.json | 6 +-
 ...test_flash_llama_grammar_tools_stream.json | 27 ++++++++
 integration-tests/models/test_tools_llama.py | 41 ++++++++++--
 router/src/lib.rs | 43 ++++++++++---
 router/src/server.rs | 14 ++++-
 10 files changed, 218 insertions(+), 38 deletions(-)
 create mode 100644 integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_stream.json

diff --git a/clients/python/text_generation/client.py b/clients/python/text_generation/client.py
index 51ebbc24..09660de3 100644
--- a/clients/python/text_generation/client.py
+++ b/clients/python/text_generation/client.py
@@ -3,7 +3,7 @@
 import requests
 from aiohttp import ClientSession, ClientTimeout
 from pydantic import ValidationError
-from typing import Dict, Optional, List, AsyncIterator, Iterator
+from typing import Dict, Optional, List, AsyncIterator, Iterator, Union

 from text_generation.types import (
     StreamResponse,
@@ -12,6 +12,7 @@ from text_generation.types import (
     Parameters,
     Grammar,
     ChatRequest,
+    ChatCompletionChunk,
     ChatComplete,
     Message,
     Tool,
@@ -134,18 +135,42 @@ class Client:
             tools=tools,
             tool_choice=tool_choice,
         )
+        if not stream:
+            resp = requests.post(
+                f"{self.base_url}/v1/chat/completions",
+                json=request.dict(),
+                headers=self.headers,
+                cookies=self.cookies,
+                timeout=self.timeout,
+            )
+            payload = resp.json()
+            if resp.status_code != 200:
+                raise parse_error(resp.status_code, payload)
+            return ChatComplete(**payload)
+        else:
+            return self._chat_stream_response(request)
+
+    def _chat_stream_response(self, request):
         resp = requests.post(
             f"{self.base_url}/v1/chat/completions",
             json=request.dict(),
             headers=self.headers,
             cookies=self.cookies,
             timeout=self.timeout,
+            stream=True,
         )
-        payload = resp.json()
-        if resp.status_code != 200:
-            raise parse_error(resp.status_code, payload)
-        return ChatComplete(**payload)
+        # iterate over the server-sent event stream
+        for byte_payload in resp.iter_lines():
+            if byte_payload == b"\n":
+                continue
+            payload = byte_payload.decode("utf-8")
+            if payload.startswith("data:"):
+                json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
+                try:
+                    response = ChatCompletionChunk(**json_payload)
+                    yield response
+                except ValidationError:
+                    raise parse_error(resp.status_code, json_payload)

     def generate(
         self,
         prompt: str,
@@ -417,7 +442,7 @@ class AsyncClient:
         top_p: Optional[float] = None,
         tools: Optional[List[Tool]] = None,
         tool_choice: Optional[str] = None,
-    ):
+    ) -> Union[ChatComplete, AsyncIterator[ChatCompletionChunk]]:
         """
         Given a list of messages, generate a response asynchronously
@@ -472,6 +497,12 @@
             tools=tools,
             tool_choice=tool_choice,
         )
+        if not stream:
+            return await self._chat_single_response(request)
+        else:
+            return self._chat_stream_response(request)
+
+    async def _chat_single_response(self, request):
         async with ClientSession(
             headers=self.headers, cookies=self.cookies, timeout=self.timeout
         ) as session:
@@ -483,6 +514,25 @@
                 raise parse_error(resp.status, payload)
             return ChatComplete(**payload)

+    async def _chat_stream_response(self, request):
+        async with ClientSession(
+            headers=self.headers, cookies=self.cookies, timeout=self.timeout
+        ) as session:
+            async with session.post(
+                f"{self.base_url}/v1/chat/completions", json=request.dict()
+            ) as resp:
+                async for byte_payload in resp.content:
+                    if byte_payload == b"\n":
+                        continue
+                    payload = byte_payload.decode("utf-8")
+                    if payload.startswith("data:"):
+                        json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
+                        try:
+                            response = ChatCompletionChunk(**json_payload)
+                            yield response
+                        except ValidationError:
+                            raise parse_error(resp.status, json_payload)
+
     async def generate(
         self,
         prompt: str,
diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py
index 8ca46654..4a308cef 100644
--- a/clients/python/text_generation/types.py
+++ b/clients/python/text_generation/types.py
@@ -59,6 +59,40 @@ class ChatCompletionComplete(BaseModel):
     usage: Any


+class Function(BaseModel):
+    name: Optional[str]
+    arguments: str
+
+
+class ChoiceDeltaToolCall(BaseModel):
+    index: int
+    id: str
+    type: str
+    function: Function
+
+
+class ChoiceDelta(BaseModel):
+    role: str
+    content: Optional[str]
+    tool_calls: Optional[ChoiceDeltaToolCall]
+
+
+class Choice(BaseModel):
+    index: int
+    delta: ChoiceDelta
+    logprobs: Optional[dict] = None
+    finish_reason: Optional[str] = None
+
+
+class ChatCompletionChunk(BaseModel):
+    id: str
+    object: str
+    created: int
+    model: str
+    system_fingerprint: str
+    choices: List[Choice]
+
+
 class ChatComplete(BaseModel):
     # Chat completion details
     id: str
diff --git a/docs/source/guidance.md b/docs/source/guidance.md
index 2e9bbec5..42a12371 100644
--- a/docs/source/guidance.md
+++ b/docs/source/guidance.md
@@ -1,8 +1,8 @@
 # Guidance

-Text Generation Inference (TGI) now supports the Messages API, which is fully compatible with the OpenAI Chat Completion API. This feature is available starting from version `1.4.3`. You can use OpenAI's client libraries or third-party libraries expecting OpenAI schema to interact with TGI's Messages API. Below are some examples of how to utilize this compatibility.
+Text Generation Inference (TGI) now supports [Grammar and Constraints](#grammar-and-constraints) and [Tools and Functions](#tools-and-functions) to help developers guide the LLM's responses and enhance its capabilities.

-Whether you're a developer, a data scientist, or just a curious mind, we've made it super easy (and fun!) to start integrating advanced text generation capabilities into your applications.
+These features are available starting from version `1.4.3`. They are accessible via the text-generation-client library and are compatible with OpenAI's client libraries. The following guide will walk you through the new features and how to use them!
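As a quick orientation before the guide itself, here is a minimal sketch of consuming the new streaming chat path with the updated synchronous `Client`. It assumes a TGI server is already running; the endpoint URL, prompt, and `max_tokens` value are placeholders rather than fixed parts of this patch.

```python
from text_generation import Client

# Placeholder endpoint; point this at your own TGI server.
client = Client("http://localhost:3000")

# With stream=True the patched Client.chat returns an iterator of
# ChatCompletionChunk objects instead of a single ChatComplete response.
for chunk in client.chat(
    messages=[{"role": "user", "content": "What is deep learning?"}],
    max_tokens=64,
    stream=True,
):
    delta = chunk.choices[0].delta
    if delta.content is not None:
        print(delta.content, end="", flush=True)
```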
### Quick Start @@ -214,7 +214,6 @@ if __name__ == "__main__": ``` - ## Tools and Functions 🛠️ ### The Tools Parameter diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools.json index a89501ca..9b9e33c6 100644 --- a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools.json +++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools.json @@ -14,8 +14,8 @@ "name": "tools", "parameters": { "format": "celsius", - "location": "San Francisco", - "num_days": 2 + "location": "New York, NY", + "num_days": 14 } }, "id": 0, @@ -25,14 +25,14 @@ "usage": null } ], - "created": 1708957016, + "created": 1709079417, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", "system_fingerprint": "1.4.2-native", "usage": { - "completion_tokens": 36, - "prompt_tokens": 313, - "total_tokens": 349 + "completion_tokens": 29, + "prompt_tokens": 316, + "total_tokens": 345 } } diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_auto.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_auto.json index a89501ca..de32c970 100644 --- a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_auto.json +++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_auto.json @@ -14,8 +14,8 @@ "name": "tools", "parameters": { "format": "celsius", - "location": "San Francisco", - "num_days": 2 + "location": "New York, NY", + "num_days": 14 } }, "id": 0, @@ -25,14 +25,14 @@ "usage": null } ], - "created": 1708957016, + "created": 1709079492, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", "system_fingerprint": "1.4.2-native", "usage": { - "completion_tokens": 36, - "prompt_tokens": 313, - "total_tokens": 349 + "completion_tokens": 29, + "prompt_tokens": 316, + "total_tokens": 345 } } diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_choice.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_choice.json index 83642258..3551e205 100644 --- a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_choice.json +++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_choice.json @@ -24,14 +24,14 @@ "usage": null } ], - "created": 1708957017, + "created": 1709079493, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", "system_fingerprint": "1.4.2-native", "usage": { "completion_tokens": 21, - "prompt_tokens": 184, - "total_tokens": 205 + "prompt_tokens": 187, + "total_tokens": 208 } } diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_stream.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_stream.json new file mode 100644 index 00000000..c367cc6f --- /dev/null +++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_stream.json @@ -0,0 +1,27 @@ +{ + "choices": [ + { + "delta": { + "content": null, + "role": "assistant", + "tool_calls": { + "function": { + "arguments": "", + "name": null + }, + "id": "", + "index": 20, + "type": "function" + } + }, + "finish_reason": "eos_token", + "index": 20, + 
"logprobs": null + } + ], + "created": 1709087088, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "1.4.2-native" +} diff --git a/integration-tests/models/test_tools_llama.py b/integration-tests/models/test_tools_llama.py index ecabf534..0901c7ac 100644 --- a/integration-tests/models/test_tools_llama.py +++ b/integration-tests/models/test_tools_llama.py @@ -124,8 +124,8 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna "name": "tools", "parameters": { "format": "celsius", - "location": "San Francisco", - "num_days": 2, + "location": "New York, NY", + "num_days": 14, }, }, "id": 0, @@ -163,8 +163,8 @@ async def test_flash_llama_grammar_tools_auto( "name": "tools", "parameters": { "format": "celsius", - "location": "San Francisco", - "num_days": 2, + "location": "New York, NY", + "num_days": 14, }, }, "id": 0, @@ -206,3 +206,36 @@ async def test_flash_llama_grammar_tools_choice( }, } assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_grammar_tools_stream( + flash_llama_grammar_tools, response_snapshot +): + responses = await flash_llama_grammar_tools.chat( + max_tokens=100, + seed=1, + tools=tools, + tool_choice="get_current_weather", + presence_penalty=-1.1, + messages=[ + { + "role": "system", + "content": "Youre a helpful assistant! Answer the users question best you can.", + }, + { + "role": "user", + "content": "What is the weather like in Paris, France?", + }, + ], + stream=True, + ) + + count = 0 + async for response in responses: + print(response) + count += 1 + + assert count == 20 + assert response == response_snapshot diff --git a/router/src/lib.rs b/router/src/lib.rs index 98424497..d89bacb5 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -415,15 +415,35 @@ pub(crate) struct ChatCompletionChoice { pub(crate) struct ChatCompletionDelta { #[schema(example = "user")] pub role: String, + #[serde(default, skip_serializing_if = "Option::is_none")] #[schema(example = "What is Deep Learning?")] - pub content: String, + pub content: Option, + // default to None + #[serde(default, skip_serializing_if = "Option::is_none")] + pub tool_calls: Option, } +#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)] +pub(crate) struct DeltaToolCall { + pub index: u32, + pub id: String, + pub r#type: String, + pub function: Function, +} + +#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)] +pub(crate) struct Function { + pub name: Option, + pub arguments: String, +} + +#[allow(clippy::too_many_arguments)] impl ChatCompletionChunk { pub(crate) fn new( model: String, system_fingerprint: String, - delta: String, + delta: Option, + tool_calls: Option>, created: u64, index: u32, logprobs: Option, @@ -440,6 +460,15 @@ impl ChatCompletionChunk { delta: ChatCompletionDelta { role: "assistant".to_string(), content: delta, + tool_calls: tool_calls.map(|tc| DeltaToolCall { + index, + id: String::new(), + r#type: "function".to_string(), + function: Function { + name: None, + arguments: tc[0].to_string(), + }, + }), }, logprobs, finish_reason, @@ -626,8 +655,8 @@ where state.end() } -#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)] -pub(crate) struct Function { +#[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default)] +pub(crate) struct FunctionDefinition { #[serde(default)] pub description: Option, pub name: String, @@ -640,7 +669,7 @@ pub(crate) struct Tool { #[schema(example = "function")] pub r#type: String, // 
Grab the tool as generic JSON for debugging purposes. - pub function: Function, + pub function: FunctionDefinition, } #[derive(Clone, Serialize, Deserialize)] @@ -651,11 +680,11 @@ pub(crate) struct ChatTemplateInputs<'a> { add_generation_prompt: bool, } -#[derive(Clone, Deserialize, Serialize, ToSchema)] +#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)] pub(crate) struct ToolCall { pub id: u32, pub r#type: String, - pub function: Function, + pub function: FunctionDefinition, } #[derive(Clone, Deserialize, ToSchema, Serialize)] diff --git a/router/src/server.rs b/router/src/server.rs index e3254625..2efa9284 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -10,7 +10,7 @@ use crate::{ HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse, Usage, Validation, VertexRequest, VertexResponse, }; -use crate::{Function, FunctionRef, FunctionsMap, Properties, ToolCall, ToolType, Tools}; +use crate::{FunctionDefinition, FunctionRef, FunctionsMap, Properties, ToolCall, ToolType, Tools}; use axum::extract::Extension; use axum::http::{HeaderMap, Method, StatusCode}; use axum::response::sse::{Event, KeepAlive, Sse}; @@ -699,11 +699,19 @@ async fn chat_completions( ChatCompletionLogprobs::from((stream_token.token.clone(), stream_token.top_tokens)) }); + // replace the content with the tool calls if grammar is present + let (content, tool_calls) = if tool_grammar.is_some() { + (None, Some(vec![stream_token.token.text])) + } else { + (Some(stream_token.token.text), None) + }; + event .json_data(ChatCompletionChunk::new( model_id.clone(), system_fingerprint.clone(), - stream_token.token.text, + content, + tool_calls, current_time, stream_token.index, logprobs, @@ -756,7 +764,7 @@ async fn chat_completions( let tool_call = Some(ToolCall { id: 0, r#type: "function".to_string(), - function: Function { + function: FunctionDefinition { description: None, name: "tools".to_string(), parameters: gen_text_value.get("function").map_or_else(
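With the router changes above, a tool-constrained request streams the generated text through `delta.tool_calls.function.arguments` and leaves `delta.content` unset. The sketch below shows one way to accumulate those fragments with the updated `AsyncClient`; the endpoint URL and the weather tool definition are illustrative placeholders modeled on the new integration test, not fixed parts of the API.

```python
import asyncio

from text_generation import AsyncClient

# Illustrative tool definition, mirroring the integration test's weather tool.
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {"type": "string", "description": "City and state"},
                    "format": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                },
                "required": ["location", "format"],
            },
        },
    }
]


async def main():
    # Placeholder endpoint; point this at your own TGI server.
    client = AsyncClient("http://localhost:3000")
    chunks = await client.chat(
        messages=[
            {"role": "user", "content": "What is the weather like in Paris, France?"}
        ],
        tools=tools,
        tool_choice="get_current_weather",
        max_tokens=100,
        stream=True,
    )

    # Each ChatCompletionChunk carries a tool-call delta; the argument
    # fragments concatenate into the grammar-constrained tool call text.
    arguments = ""
    async for chunk in chunks:
        delta = chunk.choices[0].delta
        if delta.tool_calls is not None:
            arguments += delta.tool_calls.function.arguments
    print(arguments)


asyncio.run(main())
```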