2023-01-31 16:04:00 +00:00
|
|
|
/// Batching and inference logic
|
|
|
|
use crate::validation::{Validation, ValidationError};
|
feat: supports openai chat completions API (#1427)
This PR adds support to make TGI a drop-in replacement for OpenAI
clients by exposing the same HTTP interface.
Notes
- TGI inits a single model at startup so the `model` field is unused in
HTTP requests.
- `max_tokens` and `stream` should work as expected, but other params may
be unimplemented or unsupported.
General approach
- fetch the `tokenizer_config` at startup from the hub
- pass `tokenizer_config` into `Infer` so we have it at request time
- use the `chat_template` on the config to format chat request
- parse the Jinja template and render the chat string (see the sketch after this list)
- pass inputs into existing generate function
- wrap generation output in expected structure before returning
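For illustration, here is a minimal, self-contained sketch of the render step (not the actual router code), assuming a made-up ChatML-style template string in place of a real Hub `chat_template`; it only shows how a `messages` array is rendered into a single prompt string with `minijinja`:
```rust
// Sketch only: render a chat request into a prompt string with minijinja.
// The template below is a hypothetical ChatML-like example, not a real Hub template.
use minijinja::{context, Environment};

fn main() {
    let template_str =
        "{% for m in messages %}<|{{ m.role }}|>\n{{ m.content }}\n{% endfor %}<|assistant|>\n";

    let mut env = Environment::new();
    env.add_template("chat", template_str).unwrap();

    let prompt = env
        .get_template("chat")
        .unwrap()
        .render(context! {
            messages => vec![
                context! { role => "system", content => "You are a helpful assistant." },
                context! { role => "user", content => "What is deep learning?" },
            ]
        })
        .unwrap();

    // The rendered `prompt` is what would be passed to the existing generate path.
    println!("{prompt}");
}
```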
# How to test
### Streaming curl
```bash
curl localhost:3000/v1/chat/completions \
    -X POST \
    -d '{
        "model": "tgi",
        "messages": [
            {
                "role": "system",
                "content": "You are a helpful assistant."
            },
            {
                "role": "user",
                "content": "What is deep learning?"
            }
        ],
        "stream": true,
        "max_tokens": 20
    }' \
    -H 'Content-Type: application/json'
```
It is also possible to use the `openai` Python library and change the
base URL:
### 🌊 STREAMING REQUEST
```python
from openai import OpenAI

# init the client but point it to TGI
client = OpenAI(
    base_url="http://localhost:3000/v1",
    api_key="not needed for a local LLM"
)

chat_completion = client.chat.completions.create(
    model="tgi",
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is deep learning?"}
    ],
    stream=True
)

# iterate and print stream
for message in chat_completion:
    print(message)

# ChatCompletionChunk(id='', choices=[Choice(delta=ChoiceDelta(content=' that', function_call=None, role='assistant', tool_calls=None), finish_reason=None, index=2, logprobs=None)], created=1704486761, model='', object='text_completion', system_fingerprint='')
```
### 🚗 SYNCHRONOUS REQUEST
```python
from openai import OpenAI

# init the client but point it to TGI
client = OpenAI(
    base_url="http://localhost:3000/v1",
    api_key="not needed for a local LLM"
)

chat_completion = client.chat.completions.create(
    model="tgi",
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is deep learning?"}
    ],
    stream=False
)

print(chat_completion)

# ChatCompletion(id='', choices=[Choice(finish_reason=None, index=0, logprobs=None, message=ChatCompletionMessage(content='\nDeep learning is a new field of research that has been gaining traction in the last ...', role='assistant', function_call=None, tool_calls=None))], created=1704486762, model='', object='text_completion', system_fingerprint='', usage=CompletionUsage(completion_tokens=100, prompt_tokens=76, total_tokens=176))
```
## How to run dev
```bash
cd text-generation-inference/server
MASTER_ADDR=127.0.0.1 MASTER_PORT=5555 text-generation-server serve --trust-remote-code gpt2
```
***Note:*** many of the existing `chat_templates` use non-standard Jinja
(i.e. adding a `raise` to the template), which will throw an error when
parsing; hence `upstage/SOLAR-10.7B-Instruct-v1.0` is used here since it has a
valid template.
```bash
cd text-generation-inference/router
cargo run -- --tokenizer-name upstage/SOLAR-10.7B-Instruct-v1.0
```
Trigger a request:
```bash
curl localhost:3000/v1/chat/completions \
-X POST \
-d '{ "model": "gpt-3.5-turbo", "messages": [ { "role": "system", "content": "You are a helpful assistant." }, { "role": "user", "content": "What is the IP address of the Google DNS servers?" } ], "stream": true, "max_tokens": 20, "logprobs": true }' \
-H 'Content-Type: application/json'
```
^ supports `stream: true` and `stream: false` requests
2024-01-16 10:07:41 +00:00
|
|
|
use crate::HubTokenizerConfig;
|
|
|
|
use crate::{ChatRequest, GenerateRequest, GenerateStreamResponse, PrefillToken};
|
2023-02-02 13:59:27 +00:00
|
|
|
use crate::{Entry, Queue, Token};
|
2023-03-09 14:30:54 +00:00
|
|
|
use futures::future::try_join_all;
|
feat: supports openai chat completions API (#1427)
2024-01-16 10:07:41 +00:00
|
|
|
use minijinja::{Environment, ErrorKind, Template};
|
2023-01-31 16:04:00 +00:00
|
|
|
use nohash_hasher::IntMap;
|
2023-04-26 18:23:54 +00:00
|
|
|
use std::sync::{
|
|
|
|
atomic::{AtomicBool, Ordering},
|
|
|
|
Arc,
|
|
|
|
};
|
2023-01-31 16:04:00 +00:00
|
|
|
use text_generation_client::{
|
2023-12-11 11:46:30 +00:00
|
|
|
Batch, CachedBatch, ClientError, GeneratedText, Generation, ShardedClient, Tokens,
|
2023-01-31 16:04:00 +00:00
|
|
|
};
|
|
|
|
use thiserror::Error;
|
2023-10-23 13:51:12 +00:00
|
|
|
use tokio::sync::mpsc::error::SendError;
|
feat: supports openai chat completions API (#1427)
2024-01-16 10:07:41 +00:00
|
|
|
use tokio::sync::{mpsc, Notify, Semaphore, TryAcquireError};
|
2023-01-31 16:04:00 +00:00
|
|
|
use tokio::time::Instant;
|
2023-10-23 13:51:12 +00:00
|
|
|
use tokio_stream::wrappers::UnboundedReceiverStream;
|
|
|
|
use tokio_stream::StreamExt;
|
2023-02-13 12:02:45 +00:00
|
|
|
use tracing::{info_span, instrument, Instrument, Span};
|
2023-01-31 16:04:00 +00:00
|
|
|
|
|
|
|
/// Inference struct
|
|
|
|
#[derive(Clone)]
|
|
|
|
pub struct Infer {
|
|
|
|
/// Validation
|
|
|
|
validation: Validation,
|
2023-02-02 13:59:27 +00:00
|
|
|
/// Request queue
|
|
|
|
queue: Queue,
|
2023-01-31 16:04:00 +00:00
|
|
|
/// Shared state
|
|
|
|
shared: Arc<Shared>,
|
|
|
|
/// Inference limit
|
|
|
|
limit_concurrent_requests: Arc<Semaphore>,
|
feat: supports openai chat completions API (#1427)
2024-01-16 10:07:41 +00:00
|
|
|
/// Chat template
|
|
|
|
template: Option<Template<'static, 'static>>,
|
2023-01-31 16:04:00 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
/// Infer shared state
|
|
|
|
struct Shared {
|
|
|
|
/// Batching background Tokio task notifier
|
|
|
|
batching_task: Notify,
|
|
|
|
}
|
|
|
|
|
|
|
|
impl Infer {
|
2023-04-26 18:23:54 +00:00
|
|
|
#[allow(clippy::too_many_arguments)]
|
2023-01-31 16:04:00 +00:00
|
|
|
pub(crate) fn new(
|
|
|
|
client: ShardedClient,
|
|
|
|
validation: Validation,
|
2023-04-24 15:59:00 +00:00
|
|
|
waiting_served_ratio: f32,
|
2023-06-30 17:09:59 +00:00
|
|
|
max_batch_prefill_tokens: u32,
|
2023-04-24 15:59:00 +00:00
|
|
|
max_batch_total_tokens: u32,
|
2023-01-31 16:04:00 +00:00
|
|
|
max_waiting_tokens: usize,
|
|
|
|
max_concurrent_requests: usize,
|
2023-04-24 15:59:00 +00:00
|
|
|
requires_padding: bool,
|
2023-09-28 07:55:47 +00:00
|
|
|
window_size: Option<u32>,
|
2023-12-11 11:46:30 +00:00
|
|
|
speculate: u32,
|
2023-04-26 18:23:54 +00:00
|
|
|
generation_health: Arc<AtomicBool>,
|
feat: supports openai chat completions API (#1427)
2024-01-16 10:07:41 +00:00
|
|
|
tokenizer_config: HubTokenizerConfig,
|
2023-01-31 16:04:00 +00:00
|
|
|
) -> Self {
|
|
|
|
// Infer shared state
|
2023-12-11 11:46:30 +00:00
|
|
|
let queue = Queue::new(requires_padding, 16, window_size, speculate);
|
2023-01-31 16:04:00 +00:00
|
|
|
let shared = Arc::new(Shared {
|
|
|
|
batching_task: Notify::new(),
|
|
|
|
});
|
|
|
|
|
|
|
|
// Spawn batching background task that contains all the inference logic
|
|
|
|
tokio::spawn(batching_task(
|
|
|
|
client,
|
2023-04-24 15:59:00 +00:00
|
|
|
waiting_served_ratio,
|
2023-06-30 17:09:59 +00:00
|
|
|
max_batch_prefill_tokens,
|
2023-04-24 15:59:00 +00:00
|
|
|
max_batch_total_tokens,
|
2023-01-31 16:04:00 +00:00
|
|
|
max_waiting_tokens,
|
2023-02-02 13:59:27 +00:00
|
|
|
queue.clone(),
|
2023-01-31 16:04:00 +00:00
|
|
|
shared.clone(),
|
2023-04-26 18:23:54 +00:00
|
|
|
generation_health,
|
2023-01-31 16:04:00 +00:00
|
|
|
));
|
|
|
|
|
|
|
|
// Inference limit with a semaphore
|
|
|
|
let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
|
|
|
|
|
feat: supports openai chat completions API (#1427)
2024-01-16 10:07:41 +00:00
|
|
|
let template = tokenizer_config.chat_template.map(|t| {
|
|
|
|
let env = Box::new(Environment::new());
|
|
|
|
let template_str = t.into_boxed_str();
|
|
|
|
// leaking env and template_str as read-only, static resources for performance.
|
|
|
|
Box::leak(env)
|
|
|
|
.template_from_str(Box::leak(template_str))
|
|
|
|
.unwrap()
|
|
|
|
});
|
|
|
|
|
2023-01-31 16:04:00 +00:00
|
|
|
Self {
|
|
|
|
validation,
|
2023-02-02 13:59:27 +00:00
|
|
|
queue,
|
2023-01-31 16:04:00 +00:00
|
|
|
shared,
|
|
|
|
limit_concurrent_requests: semaphore,
|
feat: supports openai chat completions API (#1427)
2024-01-16 10:07:41 +00:00
|
|
|
template,
|
2023-01-31 16:04:00 +00:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2023-02-02 13:59:27 +00:00
|
|
|
/// Add a new request to the queue and return a stream of InferStreamResponse
|
2023-11-20 09:33:44 +00:00
|
|
|
#[instrument(skip_all)]
|
2023-01-31 16:04:00 +00:00
|
|
|
pub(crate) async fn generate_stream(
|
|
|
|
&self,
|
|
|
|
request: GenerateRequest,
|
feat: supports openai chat completions API (#1427)
2024-01-16 10:07:41 +00:00
|
|
|
) -> Result<GenerateStreamResponse, InferError> {
|
2023-01-31 16:04:00 +00:00
|
|
|
// Limit concurrent requests by acquiring a permit from the semaphore
|
2023-02-13 12:02:45 +00:00
|
|
|
let permit = self
|
|
|
|
.clone()
|
|
|
|
.limit_concurrent_requests
|
|
|
|
.try_acquire_owned()
|
|
|
|
.map_err(|err| {
|
2023-02-16 16:18:53 +00:00
|
|
|
metrics::increment_counter!("tgi_request_failure", "err" => "overloaded");
|
2023-02-13 12:02:45 +00:00
|
|
|
tracing::error!("{err}");
|
|
|
|
err
|
|
|
|
})?;
|
2023-01-31 16:04:00 +00:00
|
|
|
|
|
|
|
// Validate request
|
2023-04-09 18:22:27 +00:00
|
|
|
let valid_request = self.validation.validate(request).await.map_err(|err| {
|
|
|
|
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
|
|
|
|
tracing::error!("{err}");
|
|
|
|
err
|
|
|
|
})?;
|
2023-01-31 16:04:00 +00:00
|
|
|
|
|
|
|
// MPSC channel to communicate with the background batching task
|
2023-10-23 13:51:12 +00:00
|
|
|
let (response_tx, response_rx) = mpsc::unbounded_channel();
|
2024-01-11 18:01:43 +00:00
|
|
|
let input_length = valid_request.input_length;
|
2023-01-31 16:04:00 +00:00
|
|
|
|
2023-02-02 13:59:27 +00:00
|
|
|
// Append the request to the queue
|
|
|
|
self.queue.append(Entry {
|
2023-01-31 16:04:00 +00:00
|
|
|
request: valid_request,
|
|
|
|
response_tx,
|
2023-02-13 12:02:45 +00:00
|
|
|
span: Span::current(),
|
|
|
|
temp_span: None,
|
|
|
|
queue_time: Instant::now(),
|
2023-01-31 16:04:00 +00:00
|
|
|
batch_time: None,
|
|
|
|
});
|
|
|
|
|
2023-02-02 13:59:27 +00:00
|
|
|
// Notify the background task that we have a new entry in the queue that needs
|
2023-01-31 16:04:00 +00:00
|
|
|
// to be batched
|
|
|
|
self.shared.batching_task.notify_one();
|
|
|
|
|
|
|
|
// Return stream
|
2024-01-11 18:01:43 +00:00
|
|
|
Ok((
|
|
|
|
permit,
|
|
|
|
input_length,
|
|
|
|
UnboundedReceiverStream::new(response_rx),
|
|
|
|
))
|
2023-01-31 16:04:00 +00:00
|
|
|
}
|
|
|
|
|
feat: supports openai chat completions API (#1427)
2024-01-16 10:07:41 +00:00
|
|
|
/// Apply the chat template to the chat request
|
|
|
|
#[instrument(skip_all)]
|
|
|
|
pub(crate) fn apply_chat_template(&self, chat: ChatRequest) -> Result<String, InferError> {
|
|
|
|
self.template
|
|
|
|
.as_ref()
|
|
|
|
.ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
|
|
|
|
.render(chat)
|
|
|
|
.map_err(|e| {
|
|
|
|
metrics::increment_counter!("tgi_request_failure", "err" => "template");
|
|
|
|
tracing::error!("{e}");
|
|
|
|
InferError::TemplateError(e)
|
|
|
|
})
|
|
|
|
}
|
|
|
|
|
2023-02-02 13:59:27 +00:00
|
|
|
/// Add a new request to the queue and return an InferResponse
|
2023-11-20 09:33:44 +00:00
|
|
|
#[instrument(skip_all)]
|
2023-01-31 16:04:00 +00:00
|
|
|
pub(crate) async fn generate(
|
|
|
|
&self,
|
|
|
|
request: GenerateRequest,
|
|
|
|
) -> Result<InferResponse, InferError> {
|
2023-08-28 09:43:47 +00:00
|
|
|
let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0);
|
|
|
|
|
2023-04-20 09:07:40 +00:00
|
|
|
// Create stream and keep semaphore permit as long as generate lives
|
2024-01-11 18:01:43 +00:00
|
|
|
let (_permit, _input_length, mut stream) = self.generate_stream(request).await?;
|
2023-01-31 16:04:00 +00:00
|
|
|
|
|
|
|
// Return values
|
|
|
|
let mut result_prefill = Vec::new();
|
|
|
|
let mut result_tokens = Vec::new();
|
2023-08-28 09:43:47 +00:00
|
|
|
let mut result_top_tokens = Vec::new();
|
2023-01-31 16:04:00 +00:00
|
|
|
let mut result_generated_text = None;
|
|
|
|
let mut result_start = None;
|
|
|
|
let mut result_queued = None;
|
|
|
|
|
|
|
|
// Iterate on stream
|
|
|
|
while let Some(response) = stream.next().await {
|
|
|
|
match response? {
|
|
|
|
// Add prefill tokens
|
|
|
|
InferStreamResponse::Prefill(tokens) => {
|
|
|
|
// Create Token objects
|
|
|
|
// We do that here instead of in the Python code as Rust for loops are faster
|
|
|
|
result_prefill = tokens
|
|
|
|
.ids
|
|
|
|
.into_iter()
|
|
|
|
.zip(tokens.logprobs.into_iter())
|
|
|
|
.zip(tokens.texts.into_iter())
|
2023-02-24 14:55:57 +00:00
|
|
|
.map(|((id, logprob), text)| PrefillToken { id, text, logprob })
|
2023-01-31 16:04:00 +00:00
|
|
|
.collect();
|
|
|
|
}
|
|
|
|
// Push last token
|
2023-08-28 09:43:47 +00:00
|
|
|
InferStreamResponse::Intermediate { token, top_tokens } => {
|
|
|
|
result_tokens.push(token);
|
|
|
|
result_top_tokens.push(top_tokens);
|
|
|
|
}
|
2023-01-31 16:04:00 +00:00
|
|
|
// Final message
|
|
|
|
// Set return values
|
|
|
|
InferStreamResponse::End {
|
|
|
|
token,
|
|
|
|
generated_text,
|
|
|
|
start,
|
|
|
|
queued,
|
2023-08-28 09:43:47 +00:00
|
|
|
top_tokens,
|
2023-01-31 16:04:00 +00:00
|
|
|
} => {
|
|
|
|
result_tokens.push(token);
|
2023-08-28 09:43:47 +00:00
|
|
|
result_top_tokens.push(top_tokens);
|
2023-01-31 16:04:00 +00:00
|
|
|
result_generated_text = Some(generated_text);
|
|
|
|
result_start = Some(start);
|
|
|
|
result_queued = Some(queued)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// Check that we received an `InferStreamResponse::End` message
|
|
|
|
if let (Some(generated_text), Some(queued), Some(start)) =
|
|
|
|
(result_generated_text, result_queued, result_start)
|
|
|
|
{
|
|
|
|
Ok(InferResponse {
|
|
|
|
prefill: result_prefill,
|
2024-01-11 18:01:43 +00:00
|
|
|
_input_length,
|
2023-01-31 16:04:00 +00:00
|
|
|
tokens: result_tokens,
|
|
|
|
generated_text,
|
|
|
|
queued,
|
|
|
|
start,
|
2023-08-28 09:43:47 +00:00
|
|
|
top_tokens: if use_top_tokens {
|
|
|
|
result_top_tokens
|
|
|
|
} else {
|
|
|
|
Vec::new()
|
|
|
|
},
|
2023-01-31 16:04:00 +00:00
|
|
|
})
|
|
|
|
} else {
|
2023-02-13 12:02:45 +00:00
|
|
|
let err = InferError::IncompleteGeneration;
|
2023-02-16 16:18:53 +00:00
|
|
|
metrics::increment_counter!("tgi_request_failure", "err" => "incomplete");
|
2023-02-13 12:02:45 +00:00
|
|
|
tracing::error!("{err}");
|
|
|
|
Err(err)
|
2023-01-31 16:04:00 +00:00
|
|
|
}
|
|
|
|
}
|
2023-03-09 14:30:54 +00:00
|
|
|
/// Add best_of new requests to the queue and return an InferResponse of the sequence with
|
|
|
|
/// the highest log probability per token
|
2023-11-20 09:33:44 +00:00
|
|
|
#[instrument(skip(self, request))]
|
2023-03-09 14:30:54 +00:00
|
|
|
pub(crate) async fn generate_best_of(
|
|
|
|
&self,
|
|
|
|
request: GenerateRequest,
|
|
|
|
best_of: usize,
|
|
|
|
) -> Result<(InferResponse, Vec<InferResponse>), InferError> {
|
|
|
|
// validate best_of parameter separately
|
|
|
|
let best_of = self.validation.validate_best_of(best_of)?;
|
|
|
|
|
|
|
|
// create multiple generate requests
|
|
|
|
let mut infer_responses: Vec<InferResponse> =
|
|
|
|
try_join_all((0..best_of).map(|_| self.generate(request.clone()))).await?;
|
|
|
|
|
|
|
|
// get the sequence with the highest log probability per token
|
|
|
|
let mut max_index = 0;
|
|
|
|
let mut max_logprob: f32 = f32::MIN;
|
|
|
|
|
|
|
|
for (i, response) in infer_responses.iter().enumerate() {
|
|
|
|
// mean logprobs of the generated tokens
|
|
|
|
let sequence_logprob = response
|
|
|
|
.tokens
|
|
|
|
.iter()
|
|
|
|
.map(|token| token.logprob)
|
|
|
|
.sum::<f32>()
|
|
|
|
/ response.tokens.len() as f32;
|
|
|
|
|
|
|
|
// set best sequence
|
|
|
|
if sequence_logprob > max_logprob {
|
|
|
|
max_index = i;
|
|
|
|
max_logprob = sequence_logprob;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
let best_response = infer_responses.remove(max_index);
|
|
|
|
Ok((best_response, infer_responses))
|
|
|
|
}
|
2023-01-31 16:04:00 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
/// Batching logic
|
|
|
|
/// Will be launched in a background Tokio task
|
|
|
|
///
|
|
|
|
/// Batches requests and sends them to the inference server
|
2023-06-30 17:09:59 +00:00
|
|
|
#[allow(clippy::too_many_arguments)]
|
2023-01-31 16:04:00 +00:00
|
|
|
async fn batching_task(
|
|
|
|
mut client: ShardedClient,
|
2023-04-24 15:59:00 +00:00
|
|
|
waiting_served_ratio: f32,
|
2023-06-30 17:09:59 +00:00
|
|
|
max_batch_prefill_tokens: u32,
|
2023-04-24 15:59:00 +00:00
|
|
|
max_batch_total_tokens: u32,
|
2023-01-31 16:04:00 +00:00
|
|
|
max_waiting_tokens: usize,
|
2023-02-02 13:59:27 +00:00
|
|
|
queue: Queue,
|
2023-01-31 16:04:00 +00:00
|
|
|
shared: Arc<Shared>,
|
2023-04-26 18:23:54 +00:00
|
|
|
generation_health: Arc<AtomicBool>,
|
2023-01-31 16:04:00 +00:00
|
|
|
) {
|
|
|
|
// Infinite loop
|
|
|
|
loop {
|
|
|
|
// Wait for a notification from the Infer struct
|
|
|
|
shared.batching_task.notified().await;
|
|
|
|
|
2023-02-02 13:59:27 +00:00
|
|
|
// Get the next batch from the queue
|
2023-01-31 16:04:00 +00:00
|
|
|
// This batch might be smaller than the maximum batch size if there are not enough requests
|
2023-02-02 13:59:27 +00:00
|
|
|
// waiting in the queue
|
2023-06-30 17:09:59 +00:00
|
|
|
while let Some((mut entries, batch, span)) = queue
|
|
|
|
.next_batch(None, max_batch_prefill_tokens, max_batch_total_tokens)
|
|
|
|
.await
|
2023-04-24 15:59:00 +00:00
|
|
|
{
|
2023-04-26 18:23:54 +00:00
|
|
|
let mut cached_batch = prefill(&mut client, batch, &mut entries, &generation_health)
|
2023-02-13 12:02:45 +00:00
|
|
|
.instrument(span)
|
|
|
|
.await;
|
2023-01-31 16:04:00 +00:00
|
|
|
let mut waiting_tokens = 1;
|
|
|
|
|
|
|
|
// We loop until we do not receive any cached batch from the inference server (== until
|
|
|
|
// all requests have met their stopping criteria)
|
|
|
|
while let Some(batch) = cached_batch {
|
|
|
|
// Get current batch info
|
|
|
|
let batch_size = batch.size;
|
2023-04-24 15:59:00 +00:00
|
|
|
let batch_max_tokens = batch.max_tokens;
|
2023-01-31 16:04:00 +00:00
|
|
|
let mut batches = vec![batch];
|
2023-02-16 16:18:53 +00:00
|
|
|
metrics::gauge!("tgi_batch_current_size", batch_size as f64);
|
2023-04-24 15:59:00 +00:00
|
|
|
metrics::gauge!("tgi_batch_current_max_tokens", batch_max_tokens as f64);
|
|
|
|
|
|
|
|
let min_size = if waiting_tokens >= max_waiting_tokens {
|
|
|
|
// If we didn't onboard any new requests for at least max_waiting_tokens iterations, we try
|
|
|
|
// to add a new batch even though its size might be small
|
|
|
|
None
|
|
|
|
} else {
|
|
|
|
// Minimum batch size
|
|
|
|
Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
|
|
|
|
};
|
|
|
|
|
2023-06-30 17:09:59 +00:00
|
|
|
let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
|
2023-04-24 15:59:00 +00:00
|
|
|
|
|
|
|
// Try to get a new batch
|
2023-06-30 17:09:59 +00:00
|
|
|
if let Some((mut new_entries, new_batch, span)) = queue
|
|
|
|
.next_batch(min_size, max_batch_prefill_tokens, token_budget)
|
|
|
|
.await
|
2023-04-24 15:59:00 +00:00
|
|
|
{
|
|
|
|
// Tracking metrics
|
|
|
|
if min_size.is_some() {
|
|
|
|
metrics::increment_counter!("tgi_batch_concat", "reason" => "backpressure");
|
|
|
|
} else {
|
|
|
|
metrics::increment_counter!("tgi_batch_concat", "reason" => "wait_exceeded");
|
|
|
|
}
|
2023-01-31 16:04:00 +00:00
|
|
|
|
2023-04-24 15:59:00 +00:00
|
|
|
entries.iter_mut().for_each(|(_, entry)| {
|
|
|
|
// Create a new span to add the info that this entry is waiting
|
|
|
|
// because a new batch is being computed
|
|
|
|
let entry_waiting_span = info_span!(parent: &entry.span, "waiting");
|
|
|
|
// Add relationships
|
|
|
|
span.follows_from(&entry_waiting_span);
|
|
|
|
entry_waiting_span.follows_from(&span);
|
|
|
|
// Update entry
|
|
|
|
entry.temp_span = Some(entry_waiting_span);
|
|
|
|
});
|
|
|
|
|
|
|
|
// Generate one token for this new batch to have the attention past in cache
|
2023-04-26 18:23:54 +00:00
|
|
|
let new_cached_batch =
|
|
|
|
prefill(&mut client, new_batch, &mut new_entries, &generation_health)
|
|
|
|
.instrument(span)
|
|
|
|
.await;
|
2023-04-24 15:59:00 +00:00
|
|
|
// Reset waiting counter
|
|
|
|
waiting_tokens = 1;
|
|
|
|
// Extend current batch with the new batch
|
|
|
|
if let Some(new_cached_batch) = new_cached_batch {
|
|
|
|
entries.extend(new_entries);
|
|
|
|
batches.push(new_cached_batch);
|
2023-01-31 16:04:00 +00:00
|
|
|
}
|
|
|
|
}
|
2023-04-24 15:59:00 +00:00
|
|
|
|
2023-02-13 12:02:45 +00:00
|
|
|
// Create span for this batch to add context to inference calls
|
|
|
|
let next_batch_size = entries.len();
|
|
|
|
let next_batch_span =
|
|
|
|
info_span!(parent: None, "batch", batch_size = next_batch_size);
|
|
|
|
entries.iter_mut().for_each(|(_, entry)| {
|
|
|
|
// Create a new span to link the batch back to this entry
|
2023-04-20 09:07:40 +00:00
|
|
|
let entry_batch_span = info_span!(parent: &entry.span, "infer");
|
2023-03-16 11:12:26 +00:00
|
|
|
// Add relationships
|
|
|
|
next_batch_span.follows_from(&entry_batch_span);
|
2023-02-13 12:02:45 +00:00
|
|
|
entry_batch_span.follows_from(&next_batch_span);
|
|
|
|
// Update entry
|
|
|
|
entry.temp_span = Some(entry_batch_span);
|
|
|
|
});
|
2023-01-31 16:04:00 +00:00
|
|
|
|
2023-04-26 18:23:54 +00:00
|
|
|
cached_batch = decode(&mut client, batches, &mut entries, &generation_health)
|
2023-02-13 12:02:45 +00:00
|
|
|
.instrument(next_batch_span)
|
|
|
|
.await;
|
2023-01-31 16:04:00 +00:00
|
|
|
waiting_tokens += 1;
|
|
|
|
}
|
2023-02-16 16:18:53 +00:00
|
|
|
metrics::gauge!("tgi_batch_current_size", 0.0);
|
2023-04-24 15:59:00 +00:00
|
|
|
metrics::gauge!("tgi_batch_current_max_tokens", 0.0);
|
2023-01-31 16:04:00 +00:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2023-02-13 12:02:45 +00:00
|
|
|
#[instrument(skip_all)]
|
2023-02-16 16:18:53 +00:00
|
|
|
async fn prefill(
|
|
|
|
client: &mut ShardedClient,
|
|
|
|
batch: Batch,
|
2023-01-31 16:04:00 +00:00
|
|
|
entries: &mut IntMap<u64, Entry>,
|
2023-04-26 18:23:54 +00:00
|
|
|
generation_health: &Arc<AtomicBool>,
|
2023-05-24 17:19:57 +00:00
|
|
|
) -> Option<CachedBatch> {
|
2023-02-16 16:18:53 +00:00
|
|
|
let start_time = Instant::now();
|
2023-03-28 09:29:35 +00:00
|
|
|
let batch_id = batch.id;
|
2023-04-09 18:13:28 +00:00
|
|
|
metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill");
|
2023-02-16 16:18:53 +00:00
|
|
|
|
|
|
|
match client.prefill(batch).await {
|
2023-12-14 14:59:38 +00:00
|
|
|
Ok((generations, next_batch, timings)) => {
|
2023-04-26 18:23:54 +00:00
|
|
|
// Update health
|
|
|
|
generation_health.store(true, Ordering::SeqCst);
|
2023-12-14 14:59:38 +00:00
|
|
|
|
|
|
|
let start_filtering_time = Instant::now();
|
2023-04-24 15:59:00 +00:00
|
|
|
// Send generated tokens and filter stopped entries
|
2023-04-20 09:07:40 +00:00
|
|
|
filter_send_generations(generations, entries);
|
|
|
|
|
|
|
|
// Filter next batch and remove requests that were stopped
|
2023-04-24 15:59:00 +00:00
|
|
|
let next_batch = filter_batch(client, next_batch, entries).await;
|
2023-04-20 09:07:40 +00:00
|
|
|
|
2023-12-14 14:59:38 +00:00
|
|
|
metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "prefill");
|
|
|
|
metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "prefill");
|
|
|
|
metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "prefill");
|
2023-04-09 18:13:28 +00:00
|
|
|
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill");
|
2023-02-16 16:18:53 +00:00
|
|
|
metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill");
|
|
|
|
next_batch
|
|
|
|
}
|
|
|
|
// If we have an error, we discard the whole batch
|
|
|
|
Err(err) => {
|
2023-04-26 18:23:54 +00:00
|
|
|
// Update health
|
|
|
|
generation_health.store(false, Ordering::SeqCst);
|
2023-03-28 09:29:35 +00:00
|
|
|
let _ = client.clear_cache(Some(batch_id)).await;
|
2023-02-16 16:18:53 +00:00
|
|
|
send_errors(err, entries);
|
|
|
|
metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill");
|
|
|
|
None
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
#[instrument(skip_all)]
|
|
|
|
async fn decode(
|
|
|
|
client: &mut ShardedClient,
|
2023-05-24 17:19:57 +00:00
|
|
|
batches: Vec<CachedBatch>,
|
2023-02-16 16:18:53 +00:00
|
|
|
entries: &mut IntMap<u64, Entry>,
|
2023-04-26 18:23:54 +00:00
|
|
|
generation_health: &Arc<AtomicBool>,
|
2023-05-24 17:19:57 +00:00
|
|
|
) -> Option<CachedBatch> {
|
2023-02-16 16:18:53 +00:00
|
|
|
let start_time = Instant::now();
|
2023-04-20 09:07:40 +00:00
|
|
|
let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
|
2023-04-09 18:13:28 +00:00
|
|
|
metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode");
|
2023-02-16 16:18:53 +00:00
|
|
|
|
|
|
|
match client.decode(batches).await {
|
2023-12-14 14:59:38 +00:00
|
|
|
Ok((generations, next_batch, timings)) => {
|
2023-04-26 18:23:54 +00:00
|
|
|
// Update health
|
|
|
|
generation_health.store(true, Ordering::SeqCst);
|
2023-12-14 14:59:38 +00:00
|
|
|
|
|
|
|
let start_filtering_time = Instant::now();
|
2023-04-24 15:59:00 +00:00
|
|
|
// Send generated tokens and filter stopped entries
|
2023-04-20 09:07:40 +00:00
|
|
|
filter_send_generations(generations, entries);
|
|
|
|
|
|
|
|
// Filter next batch and remove requests that were stopped
|
2023-04-24 15:59:00 +00:00
|
|
|
let next_batch = filter_batch(client, next_batch, entries).await;
|
2023-04-20 09:07:40 +00:00
|
|
|
|
2023-12-14 14:59:38 +00:00
|
|
|
if let Some(concat_duration) = timings.concat {
|
|
|
|
metrics::histogram!("tgi_batch_concat_duration", concat_duration.as_secs_f64(), "method" => "decode");
|
|
|
|
}
|
|
|
|
metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "decode");
|
|
|
|
metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "decode");
|
|
|
|
metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "decode");
|
2023-04-09 18:13:28 +00:00
|
|
|
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode");
|
2023-02-16 16:18:53 +00:00
|
|
|
metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode");
|
2023-01-31 16:04:00 +00:00
|
|
|
next_batch
|
|
|
|
}
|
|
|
|
// If we have an error, we discard the whole batch
|
|
|
|
Err(err) => {
|
2023-04-26 18:23:54 +00:00
|
|
|
generation_health.store(false, Ordering::SeqCst);
|
2023-04-20 09:07:40 +00:00
|
|
|
for id in batch_ids {
|
|
|
|
let _ = client.clear_cache(Some(id)).await;
|
|
|
|
}
|
2023-02-13 12:02:45 +00:00
|
|
|
send_errors(err, entries);
|
2023-02-16 16:18:53 +00:00
|
|
|
metrics::increment_counter!("tgi_batch_inference_failure", "method" => "decode");
|
2023-01-31 16:04:00 +00:00
|
|
|
None
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}

/// Filter a `batch` and remove all requests not present in `entries`
#[instrument(skip_all)]
async fn filter_batch(
    client: &mut ShardedClient,
    next_batch: Option<CachedBatch>,
    entries: &IntMap<u64, Entry>,
) -> Option<CachedBatch> {
    let mut batch = next_batch?;

    // No need to filter
    if batch.size as usize == entries.len() {
        return Some(batch);
    }

    let id = batch.id;

    // Retain only requests that are still in entries
    batch.request_ids.retain(|id| entries.contains_key(id));

    if batch.request_ids.is_empty() {
        // All requests have been filtered out
        // Next batch is now empty
        // Clear it from the Python shards cache
        // We unwrap here as we need to panic since we cannot recover if this method fails
        client.clear_cache(Some(id)).await.unwrap();
        None
    } else {
        // Filter Python shard cache
        // We unwrap here as we need to panic since we cannot recover if this method fails
        client.filter_batch(id, batch.request_ids).await.unwrap()
    }
}

/// Send one or multiple `InferStreamResponse` to Infer for all `entries`
/// and filter entries
#[instrument(skip_all)]
fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
    generations.into_iter().for_each(|generation| {
        let id = generation.request_id;
        // Get entry
        // We can `expect` here as the request id should always be in the entries
        let entry = entries
            .get(&id)
            .expect("ID not found in entries. This is a bug.");

        // Create and enter a span to link this function back to the entry
        let _span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered();
        // Send generation responses back to the infer task
        // If we receive an error from the response channel, it means that the client
        // dropped the request and we need to stop generating, hence the unwrap_or(true)
        let stopped = send_responses(generation, entry).map_err(|err| {
            tracing::error!("Entry response channel error.");
            metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
            err
        }).unwrap_or(true);
        if stopped {
            entries.remove(&id).expect("ID not found in entries. This is a bug.");
        }
    });
}

/// Send responses through the `entry` response channel
fn send_responses(
    generation: Generation,
    entry: &Entry,
) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> {
    // Return directly if the channel is disconnected
    if entry.response_tx.is_closed() {
        metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
        return Ok(true);
    }

    let mut stopped = false;

    if let Some(prefill_tokens) = generation.prefill_tokens {
        // Send message
        entry
            .response_tx
            .send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?;
    }

    // Build a `Token` for every id generated in this step; the last one is attached
    // to the final `End` message below
    let tokens_ = generation.tokens.expect("Non empty tokens in generation");
    let n = tokens_.ids.len();
    metrics::histogram!("tgi_request_skipped_tokens", (n - 1) as f64);
    let mut iterator = tokens_
        .ids
        .into_iter()
        .zip(tokens_.logprobs)
        .zip(tokens_.texts)
        .zip(tokens_.is_special)
        .enumerate()
        .peekable();
    while let Some((i, (((id, logprob), text), special))) = iterator.next() {
        let token = Token {
            id,
            text,
            logprob,
            special,
        };
        let top_tokens = if let Some(top_tokens_) = generation.top_tokens.get(i) {
            top_tokens_
                .ids
                .iter()
                .zip(top_tokens_.logprobs.iter())
                .zip(top_tokens_.texts.iter())
                .zip(top_tokens_.is_special.iter())
                .map(|(((&id, &logprob), text), &special)| Token {
                    id,
                    text: text.to_string(),
                    logprob,
                    special,
                })
                .collect()
        } else {
            vec![]
        };
        match (&generation.generated_text, iterator.peek()) {
            (Some(generated_text), None) => {
                // Generation has ended
                stopped = true;
                // Send message
                entry.response_tx.send(Ok(InferStreamResponse::End {
                    token,
                    top_tokens,
                    generated_text: generated_text.clone(),
                    queued: entry.queue_time,
                    start: entry.batch_time.unwrap(),
                }))?;
            }
            _ => {
                // Send message
                entry
                    .response_tx
                    .send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?;
            }
        }
    }

    Ok(stopped)
}

/// Send errors to Infer for all `entries`
#[instrument(skip_all)]
fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
    entries.drain().for_each(|(_, entry)| {
        // Create and enter a span to link this function back to the entry
        let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
        let err = InferError::GenerationError(error.to_string());
        metrics::increment_counter!("tgi_request_failure", "err" => "generation");
        tracing::error!("{err}");

        // unwrap_or is valid here as we don't care if the receiver is gone.
        entry
            .response_tx
            .send(Err(err))
            .unwrap_or(());
    });
}

#[derive(Debug)]
pub(crate) enum InferStreamResponse {
    // Optional first message
    Prefill(Tokens),
    // Intermediate messages
    Intermediate {
        token: Token,
        top_tokens: Vec<Token>,
    },
    // Last message
    End {
        token: Token,
        top_tokens: Vec<Token>,
        generated_text: GeneratedText,
        start: Instant,
        queued: Instant,
    },
}
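
// Illustrative sketch only (`drain_stream` is not part of the router; it assumes the
// receiving side holds the `tokio::sync::mpsc::UnboundedReceiver` paired with
// `entry.response_tx`). It shows how the three message kinds produced by
// `send_responses` are typically consumed.
#[allow(dead_code)]
async fn drain_stream(
    mut rx: tokio::sync::mpsc::UnboundedReceiver<Result<InferStreamResponse, InferError>>,
) {
    while let Some(message) = rx.recv().await {
        match message {
            // Optional first message, only present when prefill details were requested
            Ok(InferStreamResponse::Prefill(_tokens)) => {}
            // One message per generated token while the request is still running
            Ok(InferStreamResponse::Intermediate { token, .. }) => {
                tracing::debug!("token: {}", token.text);
            }
            // Final message: carries the full generated text and timing information
            Ok(InferStreamResponse::End { generated_text, .. }) => {
                tracing::debug!("done: {}", generated_text.text);
                break;
            }
            Err(err) => {
                tracing::error!("{err}");
                break;
            }
        }
    }
}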

#[derive(Debug)]
pub(crate) struct InferResponse {
    /// input_length is the input length as perceived by the rust tokenizer in the
    /// validation pathway. It is redundant with prefill.len() but prefill
    /// has data only if the user asked for it. This will always be filled.
    pub(crate) _input_length: u32,
    pub(crate) prefill: Vec<PrefillToken>,
    pub(crate) tokens: Vec<Token>,
    pub(crate) generated_text: GeneratedText,
    pub(crate) queued: Instant,
    pub(crate) start: Instant,
    pub(crate) top_tokens: Vec<Vec<Token>>,
}

#[derive(Debug, Error)]
pub enum InferError {
    #[error("Request failed during generation: {0}")]
    GenerationError(String),
    #[error("Model is overloaded")]
    Overloaded(#[from] TryAcquireError),
    #[error("Input validation error: {0}")]
    ValidationError(#[from] ValidationError),
    #[error("Incomplete generation")]
    IncompleteGeneration,
#[error("Template error: {0}")]
|
|
|
|
TemplateError(#[from] minijinja::Error),
|
2023-01-31 16:04:00 +00:00
|
|
|
}

impl InferError {
    pub(crate) fn error_type(&self) -> &str {
        match self {
            InferError::GenerationError(_) => "generation",
            InferError::Overloaded(_) => "overloaded",
            InferError::ValidationError(_) => "validation",
            InferError::IncompleteGeneration => "incomplete_generation",
            InferError::TemplateError(_) => "template_error",
        }
    }
}
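
// Illustrative sketch only (`log_inference_error` is not part of the router). It shows
// the intended use of `error_type`: turning an `InferError` into a stable label for
// logs or metrics instead of matching on the variants at every call site.
#[allow(dead_code)]
fn log_inference_error(err: &InferError) {
    tracing::error!(error_type = err.error_type(), "{err}");
}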