text-generation-inference/examples/tgi_client.py

# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
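#
# Simple multi-threaded benchmarking client for a Text Generation
# Inference (TGI) server: each sample is streamed in its own thread,
# and per-token latencies (TTFT/TPOT) plus throughput are recorded.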
import os
import statistics
import threading
import time
from typing import List

import tqdm
from huggingface_hub import InferenceClient


def except_hook(args):
    """Abort the whole process as soon as any worker thread fails."""
    print(f"Thread failed with error: {args.exc_value}")
    os._exit(1)


threading.excepthook = except_hook

class TgiClient:
    def __init__(
        self,
        server_address: str,
        max_num_threads: int
    ) -> None:
        self._lock = threading.Lock()
        self._semaphore = threading.Semaphore(max_num_threads)
        self._client = InferenceClient(server_address)

        # Performance metrics, collected in nanoseconds:
        # time to first token (TTFT), time per output token (TPOT),
        # and the number of tokens generated per query.
        self._ttft = []
        self._tpot = []
        self._generated_tokens = []

    def run_generation(
        self,
        samples: List[str],
        max_new_tokens: int
    ) -> None:
        """
        Runs generation for every sample in the dataset.
        Creates a separate thread for every sample; the semaphore caps
        the number of requests in flight at max_num_threads.
        """
        threads: List[threading.Thread] = []
        for sample in tqdm.tqdm(samples):
            self._semaphore.acquire()
            threads.append(
                threading.Thread(
                    target=self._process_sample, args=[sample, max_new_tokens]
                )
            )
            threads[-1].start()
        for thread in threads:
            thread.join()

    def _process_sample(
        self,
        sample: str,
        max_new_tokens: int
    ) -> None:
        """
        Generates a response stream for a single sample
        and collects performance metrics.
        """
        timestamp = time.perf_counter_ns()
        response_stream = self._client.text_generation(
            sample, max_new_tokens=max_new_tokens, stream=True, details=True
        )
        out = ''
        for idx, response in enumerate(response_stream):
            # The first token measures time to first token (TTFT);
            # every subsequent token measures time per output token (TPOT).
            if idx == 0:
                self._ttft.append(time.perf_counter_ns() - timestamp)
            else:
                self._tpot.append(time.perf_counter_ns() - timestamp)
            timestamp = time.perf_counter_ns()
            out += response.token.text
            if response.details:
                self._generated_tokens.append(response.details.generated_tokens)
        self._semaphore.release()

    def print_performance_metrics(
        self,
        duration_s: float
    ) -> None:
        def line():
            print(32*"-")

        # Latencies are stored in nanoseconds; *1e-6 converts to milliseconds.
        line()
        print("----- Performance summary -----")
        line()
        print(f"Throughput: {sum(self._generated_tokens) / duration_s:.1f} tokens/s")
        print(f"Throughput: {len(self._generated_tokens) / duration_s:.1f} queries/s")
        line()
        print("First token latency:")
        print(f"\tMedian: \t{statistics.median(self._ttft)*1e-6:.2f}ms")
        print(f"\tAverage: \t{statistics.fmean(self._ttft)*1e-6:.2f}ms")
        line()
        print("Output token latency:")
        print(f"\tMedian: \t{statistics.median(self._tpot)*1e-6:.2f}ms")
        print(f"\tAverage: \t{statistics.fmean(self._tpot)*1e-6:.2f}ms")
        line()
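

# A minimal usage sketch (not part of the original file): the server
# address, prompt list, and thread count below are illustrative
# assumptions, not values shipped with this example.
if __name__ == "__main__":
    # Assumes a TGI server is already listening at this address.
    client = TgiClient("http://localhost:8080", max_num_threads=4)

    prompts = [
        "Explain speculative decoding in one sentence.",
        "Write a haiku about accelerators.",
    ]

    # Time the whole batch so throughput covers queueing and generation.
    start = time.perf_counter()
    client.run_generation(prompts, max_new_tokens=64)
    duration_s = time.perf_counter() - start

    client.print_performance_metrics(duration_s)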