Add simple continuous batching benchmark (#108)

Co-authored-by: Karol Damaszke <kdamaszke@habana.ai>
Co-authored-by: regisss <15324346+regisss@users.noreply.github.com>
Karol Damaszke 2024-03-26 09:17:55 +01:00 committed by GitHub
parent 7f58680999
commit 9796b0e10d
5 changed files with 237 additions and 2 deletions

README.md

@@ -55,20 +55,22 @@ To use [🤗 text-generation-inference](https://github.com/huggingface/text-gene
   docker run -p 8080:80 -v $volume:/data --runtime=habana -e PT_HPU_ENABLE_LAZY_COLLECTIVES=true -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --ipc=host ghcr.io/huggingface/tgi-gaudi:1.2.1 --model-id $model --sharded true --num-shard 8
   ```
3. You can then send a simple request:
   ```bash
   curl 127.0.0.1:8080/generate \
     -X POST \
     -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":32}}' \
     -H 'Content-Type: application/json'
   ```
4. To run a static benchmark test, please refer to [TGI's benchmark tool](https://github.com/huggingface/text-generation-inference/tree/main/benchmark).
   To run it on the same machine, you can do the following:
   * `docker exec -it <docker name> bash`, pick the container started in step 2 using `docker ps`
   * `text-generation-benchmark -t <model-id>`, pass the model-id from the docker run command
   * after the tests complete, hit Ctrl+C to see the performance data summary.
5. To run a continuous batching test, please refer to [examples](https://github.com/huggingface/tgi-gaudi/tree/habana-main/examples).

## Adjusting TGI parameters

Maximum sequence length is controlled by two arguments:
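As a side note to step 3 above, a minimal Python equivalent of the curl request (not part of this commit) might look as follows; it assumes the TGI server is reachable at 127.0.0.1:8080 and that the response follows TGI's documented `/generate` schema (`generated_text` field):

```python
import requests

# Same payload as the curl example above; adjust the address and max_new_tokens as needed.
response = requests.post(
    "http://127.0.0.1:8080/generate",
    json={"inputs": "What is Deep Learning?", "parameters": {"max_new_tokens": 32}},
    headers={"Content-Type": "application/json"},
    timeout=60,
)
response.raise_for_status()
print(response.json()["generated_text"])
```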

examples/README.md Normal file

@@ -0,0 +1,39 @@
# TGI-Gaudi example
This example provides a simple way to use `tgi-gaudi` with continuous batching. It runs over a small dataset, [DIBT/10k_prompts_ranked](https://huggingface.co/datasets/DIBT/10k_prompts_ranked), and reports basic performance numbers.
## Get started
### Install
```bash
pip install -r requirements.txt
```
### Setup TGI server
More details on running the TGI server are available [here](https://github.com/huggingface/tgi-gaudi/blob/habana-main/README.md#running-tgi-on-gaudi).
### Run benchmark
To run the benchmark, use the command below:
```bash
python run_generation.py --model_id MODEL_ID
```
where `MODEL_ID` should be set to the same value as used in the TGI server instance.
> For gated models such as [Llama](https://huggingface.co/meta-llama) or [StarCoder](https://huggingface.co/bigcode/starcoder), you will have to set the environment variable `HUGGING_FACE_HUB_TOKEN=<token>` with a valid Hugging Face Hub read token.
All parameters are described in the table below. They are passed to `run_generation.py` as lower-case command-line arguments (e.g. `--server_address`):
<div align="left">
| Name | Default value | Description |
| ------------------------- | :---------------------------- | :------------------------------------------------------------ |
| SERVER_ADDRESS | http://localhost:8080 | The address and port at which the TGI server is available. |
| MODEL_ID | meta-llama/Llama-2-7b-chat-hf | Model ID used in the TGI server instance. |
| MAX_INPUT_LENGTH | 1024 | Maximum input length supported by the TGI server. |
| MAX_OUTPUT_LENGTH | 1024 | Maximum output length supported by the TGI server. |
| TOTAL_SAMPLE_COUNT | 2048 | Number of samples to run. |
| MAX_CONCURRENT_REQUESTS | 256 | The number of requests sent simultaneously to the TGI server. |
</div>
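Before launching a long benchmark run, it can help to confirm that the server actually matches these settings. The sketch below is not part of the example; it queries the same `/info` endpoint that `run_generation.py` uses for its availability check, and the response field names (`model_id`, `max_input_length`, ...) are assumptions based on TGI's info route:

```python
import requests

SERVER_ADDRESS = "http://localhost:8080"  # same default as the benchmark script

# Query TGI's /info endpoint and print the fields that correspond to the
# parameters in the table above. Field names are assumed from TGI's /info schema.
info = requests.get(f"{SERVER_ADDRESS}/info", timeout=10)
info.raise_for_status()
payload = info.json()
print("model_id:               ", payload.get("model_id"))
print("max_input_length:       ", payload.get("max_input_length"))
print("max_total_tokens:       ", payload.get("max_total_tokens"))
print("max_concurrent_requests:", payload.get("max_concurrent_requests"))
```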

examples/requirements.txt Normal file

@@ -0,0 +1,4 @@
huggingface_hub==0.20.3
requests==2.31.0
datasets==2.18.0
transformers>=4.37.0

examples/run_generation.py Normal file

@@ -0,0 +1,90 @@
import argparse
import requests
import time
from typing import List
from datasets import load_dataset
from transformers import AutoTokenizer
from tgi_client import TgiClient


def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--server_address", type=str, default="http://localhost:8080", help="Address of the TGI server"
)
parser.add_argument(
"--model_id", type=str, default="meta-llama/Llama-2-7b-chat-hf", help="Model id used in TGI server"
)
parser.add_argument(
"--max_input_length", type=int, default=1024, help="Max input length for TGI model"
)
parser.add_argument(
"--max_output_length", type=int, default=1024, help="Max output length for TGI model"
)
parser.add_argument(
"--total_sample_count", type=int, default=2048, help="Total number of samples to generate"
)
parser.add_argument(
"--max_concurrent_requests", type=int, default=256, help="Max number of concurrent requests"
)
    return parser.parse_args()


def read_dataset(
max_input_length: int,
total_sample_count: int,
model_id: str
) -> List[str]:
"""
    Loads the public dataset from the HF Hub: https://huggingface.co/datasets/DIBT/10k_prompts_ranked
    and filters out samples that are too long.
"""
tokenizer = AutoTokenizer.from_pretrained(model_id)
dataset = load_dataset("DIBT/10k_prompts_ranked", split="train", trust_remote_code=True)
dataset = dataset.filter(
lambda x: len(tokenizer(x["prompt"])["input_ids"]) < max_input_length
)
if len(dataset) > total_sample_count:
dataset = dataset.select(range(total_sample_count))
dataset = dataset.shuffle()
    return [sample["prompt"] for sample in dataset]


def is_tgi_available(
server_address: str
) -> bool:
"""
Checks if TGI server is available under the specified address.
"""
try:
info = requests.get(f"{server_address}/info")
return info.status_code == 200
    except requests.exceptions.RequestException:
        return False


def main():
args = get_args()
dataset = read_dataset(
args.max_input_length, args.total_sample_count, args.model_id
)
if not is_tgi_available(args.server_address):
        raise RuntimeError("Cannot connect to TGI server!")
tgi_client = TgiClient(
args.server_address, args.max_concurrent_requests
)
timestamp = time.perf_counter_ns()
tgi_client.run_generation(
dataset, args.max_output_length
)
duration_s = (time.perf_counter_ns() - timestamp) * 1e-9
    tgi_client.print_performance_metrics(duration_s)


if __name__ == '__main__':
main()

examples/tgi_client.py Normal file

@@ -0,0 +1,100 @@
import os
import statistics
import threading
import time
import tqdm
from typing import List
from huggingface_hub import InferenceClient


def except_hook(args):
print(f"Thread failed with error: {args.exc_value}")
    os._exit(1)


threading.excepthook = except_hook


class TgiClient:
def __init__(
self,
server_address: str,
max_num_threads: int
) -> None:
self._lock = threading.Lock()
self._semaphore = threading.Semaphore(max_num_threads)
self._client = InferenceClient(server_address)
self._ttft = []
self._tpot = []
        self._generated_tokens = []

    def run_generation(
self,
samples: List[str],
max_new_tokens: int
) -> None:
"""
        Runs generation for every sample in the dataset.
        Creates a separate thread for each sample, limited by the semaphore.
"""
threads: List[threading.Thread] = []
for sample in tqdm.tqdm(samples):
self._semaphore.acquire()
threads.append(
threading.Thread(
target=self._process_sample, args=[sample, max_new_tokens]
)
)
threads[-1].start()
for thread in threads:
if thread is not None:
                thread.join()

    def _process_sample(
self,
sample: str,
max_new_tokens: int
) -> None:
"""
Generates response stream for a single sample.
Collects performance metrics.
"""
timestamp = time.perf_counter_ns()
response_stream = self._client.text_generation(
sample, max_new_tokens=max_new_tokens, stream=True, details=True
)
        out = ''
        for i, response in enumerate(response_stream):
            if i == 0:
                # First streamed token: record time-to-first-token (TTFT).
                self._ttft.append(time.perf_counter_ns() - timestamp)
            else:
                # Subsequent tokens: record the per-output-token gap (TPOT).
                self._tpot.append(time.perf_counter_ns() - timestamp)
            timestamp = time.perf_counter_ns()
            out += response.token.text
if response.details:
self._generated_tokens.append(response.details.generated_tokens)
        self._semaphore.release()

    def print_performance_metrics(
self,
duration_s: float
) -> None:
def line():
print(32*"-")
line()
print("----- Performance summary -----")
line()
print(f"Throughput: {sum(self._generated_tokens) / duration_s:.1f} tokens/s")
print(f"Throughput: {len(self._generated_tokens) / duration_s:.1f} queries/s")
line()
print(f"First token latency:")
print(f"\tMedian: \t{statistics.median(self._ttft)*1e-6:.2f}ms")
print(f"\tAverage: \t{statistics.fmean(self._ttft)*1e-6:.2f}ms")
line()
print(f"Output token latency:")
print(f"\tMedian: \t{statistics.median(self._tpot)*1e-6:.2f}ms")
print(f"\tAverage: \t{statistics.fmean(self._tpot)*1e-6:.2f}ms")
line()
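To make the metric definitions concrete, here is a small worked example (not part of the commit) that reproduces the arithmetic of `print_performance_metrics` on made-up measurements; the numbers are purely illustrative, not benchmark results:

```python
import statistics

# Illustrative nanosecond timings, in the same units TgiClient collects.
ttft_ns = [180e6, 210e6, 195e6]      # time to first token, one entry per request
tpot_ns = [32e6, 35e6, 31e6, 33e6]   # gaps between consecutive output tokens
generated_tokens = [32, 32, 32]      # tokens generated per request
duration_s = 2.0                     # wall-clock duration of the whole run

print(f"Throughput: {sum(generated_tokens) / duration_s:.1f} tokens/s")   # 48.0 tokens/s
print(f"Throughput: {len(generated_tokens) / duration_s:.1f} queries/s")  # 1.5 queries/s
print(f"TTFT median:  {statistics.median(ttft_ns) * 1e-6:.2f} ms")        # 195.00 ms
print(f"TPOT average: {statistics.fmean(tpot_ns) * 1e-6:.2f} ms")         # 32.75 ms
```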