This PR adds support for making TGI a drop-in replacement for OpenAI clients by exposing the same HTTP interface.

Notes
- TGI initializes a single model at startup, so the `model` field is unused in HTTP requests.
- `max_tokens` and `stream` should work as expected, but other parameters may be unimplemented or unsupported.

General approach
- fetch the `tokenizer_config` at startup from the hub
- pass the `tokenizer_config` into `Infer` so we have it at request time
- use the `chat_template` on the config to format the chat request (see the sketch after this list)
- parse the jinja template and render the chat string
- pass the inputs into the existing generate function
- wrap the generation output in the expected structure before returning
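To make the template-rendering step concrete, here is a minimal, self-contained sketch using the `minijinja` crate that the router depends on. It is illustrative only, not the PR's actual code: the toy template and the `<|role|>` markers are invented, since real `chat_template`s come from the model's `tokenizer_config.json`.

```rust
use minijinja::{context, Environment};

fn main() -> Result<(), minijinja::Error> {
    // Toy chat_template for illustration; in TGI the template string comes
    // from the tokenizer_config fetched from the hub at startup.
    let chat_template = "{% for m in messages %}<|{{ m.role }}|>{{ m.content }}\n{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}";

    let mut env = Environment::new();
    env.add_template("chat", chat_template)?;
    let tmpl = env.get_template("chat")?;

    // Render the chat messages into the single prompt string that is then
    // passed to the existing generate path.
    let prompt = tmpl.render(context! {
        messages => vec![
            context! { role => "system", content => "You are a helpful assistant." },
            context! { role => "user", content => "What is deep learning?" }
        ],
        add_generation_prompt => true
    })?;

    println!("{prompt}");
    Ok(())
}
```

Chat templates that rely on non-standard jinja extensions (such as `raise`) fail at the template-parsing step, which is the error mentioned in the dev notes below.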
} ], "stream": true, "max_tokens": 20, "logprobs": true }' \ -H 'Content-Type: application/json' ``` ^ supports `stream: true` and `stream: false` requests
The router crate's `Cargo.toml`:

```toml
[package]
name = "text-generation-router"
description = "Text Generation Webserver"
build = "build.rs"
version.workspace = true
edition.workspace = true
authors.workspace = true
homepage.workspace = true

[lib]
path = "src/lib.rs"

[[bin]]
name = "text-generation-router"
path = "src/main.rs"

[dependencies]
async-stream = "0.3.5"
axum = { version = "0.6.20", features = ["json"] }
axum-tracing-opentelemetry = "0.14.1"
text-generation-client = { path = "client" }
clap = { version = "4.4.5", features = ["derive", "env"] }
futures = "0.3.28"
hf-hub = { version = "0.3.0", features = ["tokio"] }
metrics = "0.21.1"
metrics-exporter-prometheus = { version = "0.12.1", features = [] }
nohash-hasher = "0.2.0"
opentelemetry = { version = "0.20.0", features = ["rt-tokio"] }
opentelemetry-otlp = "0.13.0"
rand = "0.8.5"
reqwest = { version = "0.11.20", features = [] }
serde = "1.0.188"
serde_json = "1.0.107"
thiserror = "1.0.48"
tokenizers = { version = "0.14.0", features = ["http"] }
tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
tokio-stream = "0.1.14"
tower-http = { version = "0.4.4", features = ["cors"] }
tracing = "0.1.37"
tracing-opentelemetry = "0.21.0"
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
utoipa = { version = "3.5.0", features = ["axum_extras"] }
utoipa-swagger-ui = { version = "3.1.5", features = ["axum"] }
ngrok = { version = "0.13.1", features = ["axum"], optional = true }
init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] }
minijinja = "1.0.10"
futures-util = "0.3.30"

[build-dependencies]
vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] }

[features]
default = ["ngrok"]
ngrok = ["dep:ngrok"]
```
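As a footnote on the dependencies above: `hf-hub` (with its `tokio` feature) covers the "fetch the `tokenizer_config` at startup from the hub" step of the general approach. The following is a minimal, standalone sketch of such a fetch, not the router's actual startup code; it assumes `#[tokio::main]` is available (i.e. tokio's `macros` feature in a scratch project) and simply prints the `chat_template` field of the downloaded config.

```rust
use hf_hub::api::tokio::Api;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Download tokenizer_config.json for the model the router will serve
    // (the example model from the dev instructions above).
    let api = Api::new()?;
    let repo = api.model("upstage/SOLAR-10.7B-Instruct-v1.0".to_string());
    let config_path = repo.get("tokenizer_config.json").await?;

    // The chat_template field of this file is what gets rendered per request.
    let raw = std::fs::read_to_string(config_path)?;
    let config: serde_json::Value = serde_json::from_str(&raw)?;
    println!("{}", config["chat_template"]);
    Ok(())
}
```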