From 9796b0e10d3b7e2279bc0c88ef40bcad4d1c0efb Mon Sep 17 00:00:00 2001
From: Karol Damaszke
Date: Tue, 26 Mar 2024 09:17:55 +0100
Subject: [PATCH] Add simple continuous batching benchmark (#108)

Co-authored-by: Karol Damaszke
Co-authored-by: regisss <15324346+regisss@users.noreply.github.com>
---
 README.md                  |   6 ++-
 examples/README.md         |  39 +++++++++++++++
 examples/requirements.txt  |   4 ++
 examples/run_generation.py |  90 +++++++++++++++++++++++++++++++++
 examples/tgi_client.py     | 100 +++++++++++++++++++++++++++++++++++++
 5 files changed, 237 insertions(+), 2 deletions(-)
 create mode 100644 examples/README.md
 create mode 100644 examples/requirements.txt
 create mode 100644 examples/run_generation.py
 create mode 100644 examples/tgi_client.py

diff --git a/README.md b/README.md
index a51f485d..843117d8 100644
--- a/README.md
+++ b/README.md
@@ -55,20 +55,22 @@ To use [🤗 text-generation-inference](https://github.com/huggingface/text-gene
 docker run -p 8080:80 -v $volume:/data --runtime=habana -e PT_HPU_ENABLE_LAZY_COLLECTIVES=true -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --ipc=host ghcr.io/huggingface/tgi-gaudi:1.2.1 --model-id $model --sharded true --num-shard 8
 ```
-4. You can then send a simple request:
+3. You can then send a simple request:
 ```bash
 curl 127.0.0.1:8080/generate \
   -X POST \
   -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":32}}' \
   -H 'Content-Type: application/json'
 ```
-5. To run static benchmark test, please refer to [TGI's benchmark tool](https://github.com/huggingface/text-generation-inference/tree/main/benchmark).
+4. To run a static benchmark test, please refer to [TGI's benchmark tool](https://github.com/huggingface/text-generation-inference/tree/main/benchmark).
    To run it on the same machine, you can do the following:
    * `docker exec -it <container ID> bash` , pick the docker started from step 2 using docker ps
    * `text-generation-benchmark -t <model-id>` , pass the model-id from docker run command
    * after the completion of tests, hit ctrl+c to see the performance data summary.
+5. To run a continuous batching test, please refer to [examples](https://github.com/huggingface/tgi-gaudi/tree/habana-main/examples).
+
 ## Adjusting TGI parameters
 
 Maximum sequence length is controlled by two arguments:
diff --git a/examples/README.md b/examples/README.md
new file mode 100644
index 00000000..93f391ec
--- /dev/null
+++ b/examples/README.md
@@ -0,0 +1,39 @@
+# TGI-Gaudi example
+
+This example provides a simple way to use `tgi-gaudi` with continuous batching. It runs on a small dataset, [DIBT/10k_prompts_ranked](https://huggingface.co/datasets/DIBT/10k_prompts_ranked), and reports basic performance numbers.
+
+## Get started
+
+### Install
+
+```
+pip install -r requirements.txt
+```
+
+### Setup TGI server
+
+More details on running the TGI server are available [here](https://github.com/huggingface/tgi-gaudi/blob/habana-main/README.md#running-tgi-on-gaudi).
+
+### Run benchmark
+
+To run the benchmark, use the command below:
+
+```
+python run_generation.py --model_id MODEL_ID
+```
+where `MODEL_ID` should be set to the same value as in the TGI server instance.
+> For gated models such as [Llama](https://huggingface.co/meta-llama) or [StarCoder](https://huggingface.co/bigcode/starcoder), you will have to set the environment variable `HUGGING_FACE_HUB_TOKEN=<token>` with a valid Hugging Face Hub read token.
+
+All possible parameters are described in the table below:
+
+| Name                    | Default value                 | Description                                                    |
+| ----------------------- | :---------------------------- | :------------------------------------------------------------- |
+| SERVER_ADDRESS          | http://localhost:8080         | The address and port at which the TGI server is available.     |
+| MODEL_ID                | meta-llama/Llama-2-7b-chat-hf | Model ID used in the TGI server instance.                      |
+| MAX_INPUT_LENGTH        | 1024                          | Maximum input length supported by the TGI server.              |
+| MAX_OUTPUT_LENGTH       | 1024                          | Maximum output length supported by the TGI server.             |
+| TOTAL_SAMPLE_COUNT      | 2048                          | Number of samples to run.                                      |
+| MAX_CONCURRENT_REQUESTS | 256                           | The number of requests sent simultaneously to the TGI server.  |
+
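+For example, a run that sets every parameter explicitly (using the default values from the table above, which mirror the defaults in `run_generation.py`) could look like this:
+
+```
+python run_generation.py \
+    --server_address http://localhost:8080 \
+    --model_id meta-llama/Llama-2-7b-chat-hf \
+    --max_input_length 1024 \
+    --max_output_length 1024 \
+    --total_sample_count 2048 \
+    --max_concurrent_requests 256
+```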
\ No newline at end of file
diff --git a/examples/requirements.txt b/examples/requirements.txt
new file mode 100644
index 00000000..e98dbf64
--- /dev/null
+++ b/examples/requirements.txt
@@ -0,0 +1,4 @@
+huggingface_hub==0.20.3
+requests==2.31.0
+datasets==2.18.0
+transformers>=4.37.0
\ No newline at end of file
diff --git a/examples/run_generation.py b/examples/run_generation.py
new file mode 100644
index 00000000..c31ebeef
--- /dev/null
+++ b/examples/run_generation.py
@@ -0,0 +1,90 @@
+import argparse
+import requests
+import time
+from typing import List
+
+from datasets import load_dataset
+from transformers import AutoTokenizer
+
+from tgi_client import TgiClient
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--server_address", type=str, default="http://localhost:8080", help="Address of the TGI server"
+    )
+    parser.add_argument(
+        "--model_id", type=str, default="meta-llama/Llama-2-7b-chat-hf", help="Model id used in TGI server"
+    )
+    parser.add_argument(
+        "--max_input_length", type=int, default=1024, help="Max input length for TGI model"
+    )
+    parser.add_argument(
+        "--max_output_length", type=int, default=1024, help="Max output length for TGI model"
+    )
+    parser.add_argument(
+        "--total_sample_count", type=int, default=2048, help="Total number of samples to generate"
+    )
+    parser.add_argument(
+        "--max_concurrent_requests", type=int, default=256, help="Max number of concurrent requests"
+    )
+    return parser.parse_args()
+
+
+def read_dataset(
+    max_input_length: int,
+    total_sample_count: int,
+    model_id: str
+) -> List[str]:
+    """
+    Loads the public DIBT/10k_prompts_ranked dataset from the Hugging Face Hub
+    (https://huggingface.co/datasets/DIBT/10k_prompts_ranked) and filters out samples that are too long.
+    """
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+    dataset = load_dataset("DIBT/10k_prompts_ranked", split="train", trust_remote_code=True)
+    dataset = dataset.filter(
+        lambda x: len(tokenizer(x["prompt"])["input_ids"]) < max_input_length
+    )
+    if len(dataset) > total_sample_count:
+        dataset = dataset.select(range(total_sample_count))
+    dataset = dataset.shuffle()
+    return [sample["prompt"] for sample in dataset]
+
+
+def is_tgi_available(
+    server_address: str
+) -> bool:
+    """
+    Checks if the TGI server is available at the specified address.
+    """
+    try:
+        info = requests.get(f"{server_address}/info")
+        return info.status_code == 200
+    except requests.exceptions.RequestException:
+        return False
+
+
+def main():
+    args = get_args()
+    dataset = read_dataset(
+        args.max_input_length, args.total_sample_count, args.model_id
+    )
+
+    if not is_tgi_available(args.server_address):
+        raise RuntimeError("Cannot connect to TGI server!")
+
+    tgi_client = TgiClient(
+        args.server_address, args.max_concurrent_requests
+    )
+    timestamp = time.perf_counter_ns()
+    tgi_client.run_generation(
+        dataset, args.max_output_length
+    )
+    duration_s = (time.perf_counter_ns() - timestamp) * 1e-9
+    tgi_client.print_performance_metrics(duration_s)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/examples/tgi_client.py b/examples/tgi_client.py
new file mode 100644
index 00000000..1f627068
--- /dev/null
+++ b/examples/tgi_client.py
@@ -0,0 +1,100 @@
+import os
+import statistics
+import threading
+import time
+import tqdm
+from typing import List
+
+from huggingface_hub import InferenceClient
+
+
+def except_hook(args):
+    print(f"Thread failed with error: {args.exc_value}")
+    os._exit(1)
+
+threading.excepthook = except_hook
+
+
+class TgiClient:
+    def __init__(
+        self,
+        server_address: str,
+        max_num_threads: int
+    ) -> None:
+        self._lock = threading.Lock()
+        self._semaphore = threading.Semaphore(max_num_threads)
+        self._client = InferenceClient(server_address)
+
+        self._ttft = []
+        self._tpot = []
+        self._generated_tokens = []
+
+    def run_generation(
+        self,
+        samples: List[str],
+        max_new_tokens: int
+    ) -> None:
+        """
+        Runs generation for every sample in the dataset.
+        Creates a separate thread for every sample.
+        """
+        threads: List[threading.Thread] = []
+        for sample in tqdm.tqdm(samples):
+            self._semaphore.acquire()
+            threads.append(
+                threading.Thread(
+                    target=self._process_sample, args=[sample, max_new_tokens]
+                )
+            )
+            threads[-1].start()
+        for thread in threads:
+            if thread is not None:
+                thread.join()
+
+    def _process_sample(
+        self,
+        sample: str,
+        max_new_tokens: int
+    ) -> None:
+        """
+        Generates a response stream for a single sample.
+        Collects performance metrics.
+        """
+        timestamp = time.perf_counter_ns()
+        response_stream = self._client.text_generation(
+            sample, max_new_tokens=max_new_tokens, stream=True, details=True
+        )
+        out = ''
+        for i, response in enumerate(response_stream):
+            if i == 0:
+                # First token: time-to-first-token (TTFT)
+                self._ttft.append(time.perf_counter_ns() - timestamp)
+            else:
+                # Subsequent tokens: time-per-output-token (TPOT)
+                self._tpot.append(time.perf_counter_ns() - timestamp)
+            timestamp = time.perf_counter_ns()
+            out += response.token.text
+            if response.details:
+                self._generated_tokens.append(response.details.generated_tokens)
+
+        self._semaphore.release()
+
+    def print_performance_metrics(
+        self,
+        duration_s: float
+    ) -> None:
+        def line():
+            print(32*"-")
+
+        line()
+        print("----- Performance summary -----")
+        line()
+        print(f"Throughput: {sum(self._generated_tokens) / duration_s:.1f} tokens/s")
+        print(f"Throughput: {len(self._generated_tokens) / duration_s:.1f} queries/s")
+        line()
+        print("First token latency:")
+        print(f"\tMedian: \t{statistics.median(self._ttft)*1e-6:.2f}ms")
+        print(f"\tAverage: \t{statistics.fmean(self._ttft)*1e-6:.2f}ms")
+        line()
+        print("Output token latency:")
+        print(f"\tMedian: \t{statistics.median(self._tpot)*1e-6:.2f}ms")
+        print(f"\tAverage: \t{statistics.fmean(self._tpot)*1e-6:.2f}ms")
+        line()
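
For a quicker sanity check than the full benchmark, the `TgiClient` class added in `examples/tgi_client.py` can also be driven directly from a short script. The sketch below is not part of the patch; it assumes a TGI server is already running locally at the default `http://localhost:8080` address (see "Setup TGI server" in `examples/README.md`) and uses a couple of ad-hoc prompts instead of the dataset.

```python
# Minimal smoke test for TgiClient (not part of the patch).
# Assumes a TGI server is reachable at http://localhost:8080.
import time

from tgi_client import TgiClient

if __name__ == "__main__":
    client = TgiClient("http://localhost:8080", max_num_threads=4)

    prompts = [
        "What is Deep Learning?",
        "Explain continuous batching in one sentence.",
    ]

    # Time the whole run, mirroring how run_generation.py measures duration.
    start = time.perf_counter_ns()
    client.run_generation(prompts, max_new_tokens=32)
    duration_s = (time.perf_counter_ns() - start) * 1e-9

    client.print_performance_metrics(duration_s)
```

Because `run_generation` starts one thread per prompt and bounds concurrency with a semaphore, the same script scales to larger prompt lists simply by raising `max_num_threads`.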