Add simple continuous batching benchmark (#108)
Co-authored-by: Karol Damaszke <kdamaszke@habana.ai>
Co-authored-by: regisss <15324346+regisss@users.noreply.github.com>

This commit is contained in: parent 7f58680999, commit 9796b0e10d
README.md (updated)
@@ -55,20 +55,22 @@ To use [🤗 text-generation-inference](https://github.com/huggingface/text-generation-inference)

```bash
docker run -p 8080:80 -v $volume:/data --runtime=habana -e PT_HPU_ENABLE_LAZY_COLLECTIVES=true -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --ipc=host ghcr.io/huggingface/tgi-gaudi:1.2.1 --model-id $model --sharded true --num-shard 8
```
3. You can then send a simple request (a Python equivalent is sketched after this list):

```bash
curl 127.0.0.1:8080/generate \
    -X POST \
    -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":32}}' \
    -H 'Content-Type: application/json'
```
4. To run a static benchmark test, please refer to [TGI's benchmark tool](https://github.com/huggingface/text-generation-inference/tree/main/benchmark).

To run it on the same machine, you can do the following:
* `docker exec -it <docker name> bash`, picking the container started in step 2 (use `docker ps` to find its name)
* `text-generation-benchmark -t <model-id>`, passing the model ID from the `docker run` command
* after the tests complete, hit Ctrl+C to see the performance data summary

5. To run the continuous batching test, please refer to the [examples](https://github.com/huggingface/tgi-gaudi/tree/habana-main/examples).
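For reference, the simple request from step 3 can also be sent from Python using `huggingface_hub.InferenceClient`, the same client the benchmark scripts in this commit build on. This is a minimal sketch, assuming the server started in step 2 is listening on 127.0.0.1:8080:

```python
# Minimal sketch of the step-3 request in Python (assumes a local TGI server
# on port 8080, as started in step 2).
from huggingface_hub import InferenceClient

client = InferenceClient("http://127.0.0.1:8080")
output = client.text_generation("What is Deep Learning?", max_new_tokens=32)
print(output)
```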
## Adjusting TGI parameters

Maximum sequence length is controlled by two arguments:
examples/README.md (new file, 39 lines)
@@ -0,0 +1,39 @@
# TGI-Gaudi example

This example provides a simple way to use `tgi-gaudi` with continuous batching. It runs on a small dataset, [DIBT/10k_prompts_ranked](https://huggingface.co/datasets/DIBT/10k_prompts_ranked), and reports basic performance numbers.

## Get started

### Install

```
pip install -r requirements.txt
```

### Setup TGI server

More details on running the TGI server are available [here](https://github.com/huggingface/tgi-gaudi/blob/habana-main/README.md#running-tgi-on-gaudi).

### Run benchmark

To run the benchmark, use the command below:

```
python run_generation.py --model_id MODEL_ID
```

where `MODEL_ID` should be set to the same value as in the TGI server instance.

> For gated models such as [Llama](https://huggingface.co/meta-llama) or [StarCoder](https://huggingface.co/bigcode/starcoder), you will have to set the environment variable `HUGGING_FACE_HUB_TOKEN=<token>` with a valid Hugging Face Hub read token.
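If you drive the scripts from a Python session instead of the shell, one way to supply that token is to set it in the process environment before any gated model files are fetched. A minimal sketch; the placeholder value is hypothetical:

```python
import os

# Hypothetical placeholder: replace with a valid Hugging Face Hub read token.
# Must be set before tokenizers/models for gated repos are downloaded.
os.environ["HUGGING_FACE_HUB_TOKEN"] = "<your_hf_read_token>"
```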
All available parameters are described in the table below:

<div align="left">

| Name | Default value | Description |
| ------------------------- | :---------------------------- | :------------------------------------------------------------ |
| SERVER_ADDRESS | http://localhost:8080 | The address and port at which the TGI server is available. |
| MODEL_ID | meta-llama/Llama-2-7b-chat-hf | Model ID used in the TGI server instance. |
| MAX_INPUT_LENGTH | 1024 | Maximum input length supported by the TGI server. |
| MAX_OUTPUT_LENGTH | 1024 | Maximum output length supported by the TGI server. |
| TOTAL_SAMPLE_COUNT | 2048 | Number of samples to run. |
| MAX_CONCURRENT_REQUESTS | 256 | The number of requests sent simultaneously to the TGI server. |

</div>
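The parameters above map one-to-one to the `argparse` options of `run_generation.py`, shown later in this commit. For readers who prefer to drive the benchmark from Python rather than the command line, the script's flow reduces to roughly the following sketch (defaults taken from the table; it assumes the TGI server is reachable at the listed address and that the example modules are on the import path):

```python
# Rough sketch of what run_generation.py does, using the table's defaults.
import time

from run_generation import read_dataset
from tgi_client import TgiClient

prompts = read_dataset(
    max_input_length=1024,            # MAX_INPUT_LENGTH
    total_sample_count=2048,          # TOTAL_SAMPLE_COUNT
    model_id="meta-llama/Llama-2-7b-chat-hf",
)
client = TgiClient("http://localhost:8080", max_num_threads=256)   # MAX_CONCURRENT_REQUESTS

start = time.perf_counter_ns()
client.run_generation(prompts, max_new_tokens=1024)                # MAX_OUTPUT_LENGTH
client.print_performance_metrics((time.perf_counter_ns() - start) * 1e-9)
```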
examples/requirements.txt (new file, 4 lines)
@@ -0,0 +1,4 @@
```
huggingface_hub==0.20.3
requests==2.31.0
datasets==2.18.0
transformers>=4.37.0
```
examples/run_generation.py (new file, 90 lines)
@@ -0,0 +1,90 @@
```python
import argparse
import requests
import time
from typing import List

from datasets import load_dataset
from transformers import AutoTokenizer

from tgi_client import TgiClient


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--server_address", type=str, default="http://localhost:8080", help="Address of the TGI server"
    )
    parser.add_argument(
        "--model_id", type=str, default="meta-llama/Llama-2-7b-chat-hf", help="Model id used in TGI server"
    )
    parser.add_argument(
        "--max_input_length", type=int, default=1024, help="Max input length for TGI model"
    )
    parser.add_argument(
        "--max_output_length", type=int, default=1024, help="Max output length for TGI model"
    )
    parser.add_argument(
        "--total_sample_count", type=int, default=2048, help="Total number of samples to generate"
    )
    parser.add_argument(
        "--max_concurrent_requests", type=int, default=256, help="Max number of concurrent requests"
    )
    return parser.parse_args()


def read_dataset(
    max_input_length: int,
    total_sample_count: int,
    model_id: str
) -> List[str]:
    """
    Loads the public DIBT/10k_prompts_ranked dataset from the Hugging Face Hub
    (https://huggingface.co/datasets/DIBT/10k_prompts_ranked) and filters out
    samples that are too long.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    dataset = load_dataset("DIBT/10k_prompts_ranked", split="train", trust_remote_code=True)
    dataset = dataset.filter(
        lambda x: len(tokenizer(x["prompt"])["input_ids"]) < max_input_length
    )
    if len(dataset) > total_sample_count:
        dataset = dataset.select(range(total_sample_count))
    dataset = dataset.shuffle()
    return [sample["prompt"] for sample in dataset]


def is_tgi_available(
    server_address: str
) -> bool:
    """
    Checks if a TGI server is available at the specified address.
    """
    try:
        info = requests.get(f"{server_address}/info")
        return info.status_code == 200
    except requests.exceptions.RequestException:
        return False


def main():
    args = get_args()
    dataset = read_dataset(
        args.max_input_length, args.total_sample_count, args.model_id
    )

    if not is_tgi_available(args.server_address):
        raise RuntimeError("Cannot connect to TGI server!")

    tgi_client = TgiClient(
        args.server_address, args.max_concurrent_requests
    )
    # Measure wall-clock time for the whole benchmark run.
    timestamp = time.perf_counter_ns()
    tgi_client.run_generation(
        dataset, args.max_output_length
    )
    duration_s = (time.perf_counter_ns() - timestamp) * 1e-9
    tgi_client.print_performance_metrics(duration_s)


if __name__ == '__main__':
    main()
```
examples/tgi_client.py (new file, 100 lines)
@@ -0,0 +1,100 @@
```python
import os
import statistics
import threading
import time
from typing import List

import tqdm
from huggingface_hub import InferenceClient


def except_hook(args):
    # Abort the whole benchmark if any worker thread fails.
    print(f"Thread failed with error: {args.exc_value}")
    os._exit(1)


threading.excepthook = except_hook


class TgiClient:
    def __init__(
        self,
        server_address: str,
        max_num_threads: int
    ) -> None:
        self._lock = threading.Lock()
        self._semaphore = threading.Semaphore(max_num_threads)
        self._client = InferenceClient(server_address)

        # Collected metrics, all timings in nanoseconds:
        # time to first token (TTFT), time per output token (TPOT), tokens per request.
        self._ttft = []
        self._tpot = []
        self._generated_tokens = []

    def run_generation(
        self,
        samples: List[str],
        max_new_tokens: int
    ) -> None:
        """
        Runs generation for every sample in the dataset.
        Creates a separate thread per sample; the semaphore caps concurrency.
        """
        threads: List[threading.Thread] = []
        for sample in tqdm.tqdm(samples):
            self._semaphore.acquire()
            threads.append(
                threading.Thread(
                    target=self._process_sample, args=[sample, max_new_tokens]
                )
            )
            threads[-1].start()
        for thread in threads:
            if thread is not None:
                thread.join()

    def _process_sample(
        self,
        sample: str,
        max_new_tokens: int
    ) -> None:
        """
        Generates a response stream for a single sample and collects performance metrics.
        """
        timestamp = time.perf_counter_ns()
        response_stream = self._client.text_generation(
            sample, max_new_tokens=max_new_tokens, stream=True, details=True
        )
        out = ''
        for id, response in enumerate(response_stream):
            if id == 0:
                # First streamed token: time to first token (TTFT).
                self._ttft.append(time.perf_counter_ns() - timestamp)
            else:
                # Subsequent tokens: time per output token (TPOT).
                self._tpot.append(time.perf_counter_ns() - timestamp)
            timestamp = time.perf_counter_ns()
            out += response.token.text
            if response.details:
                self._generated_tokens.append(response.details.generated_tokens)

        self._semaphore.release()

    def print_performance_metrics(
        self,
        duration_s: float
    ) -> None:
        def line():
            print(32 * "-")

        line()
        print("----- Performance summary -----")
        line()
        print(f"Throughput: {sum(self._generated_tokens) / duration_s:.1f} tokens/s")
        print(f"Throughput: {len(self._generated_tokens) / duration_s:.1f} queries/s")
        line()
        print("First token latency:")
        print(f"\tMedian: \t{statistics.median(self._ttft)*1e-6:.2f}ms")
        print(f"\tAverage: \t{statistics.fmean(self._ttft)*1e-6:.2f}ms")
        line()
        print("Output token latency:")
        print(f"\tMedian: \t{statistics.median(self._tpot)*1e-6:.2f}ms")
        print(f"\tAverage: \t{statistics.fmean(self._tpot)*1e-6:.2f}ms")
        line()
```
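For intuition about the metrics collected above: `_ttft` holds one time-to-first-token sample per request, while `_tpot` holds the latency of each subsequent streamed token. A tiny worked example with made-up timings (illustrative only, not measured numbers):

```python
import statistics

# Hypothetical per-token timings in nanoseconds, as TgiClient would collect them.
ttft_ns = [310e6, 295e6, 342e6]      # roughly 300 ms until the first token of each request
tpot_ns = [48e6, 51e6, 47e6, 50e6]   # roughly 50 ms for each following token

print(f"Median TTFT: {statistics.median(ttft_ns) * 1e-6:.2f} ms")  # 310.00 ms
print(f"Median TPOT: {statistics.median(tpot_ns) * 1e-6:.2f} ms")  # 49.00 ms
```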