import os
import statistics
import threading
import time
from typing import List

import tqdm
from huggingface_hub import InferenceClient


def except_hook(args):
    """Abort the whole process when any worker thread raises.

    Installed as ``threading.excepthook`` below, so a failed request stops the
    benchmark instead of letting it finish with silently missing samples.
    """
    # os._exit() skips normal interpreter shutdown and does NOT flush stdio
    # buffers; flush explicitly or the message is lost when stdout is piped.
    print(f"Thread failed with error: {args.exc_value}", flush=True)
    os._exit(1)


threading.excepthook = except_hook


class TgiClient:
    """Concurrent streaming client used to benchmark a text-generation server.

    Metrics collected (timings in nanoseconds from ``time.perf_counter_ns``):

    * ``_ttft`` -- per request, time from issuing the request to receiving the
      first streamed response.
    * ``_tpot`` -- time between consecutive streamed responses.
    * ``_generated_tokens`` -- ``details.generated_tokens`` from each stream
      message that carries details.
    """

    def __init__(self, server_address: str, max_num_threads: int) -> None:
        # Guards the metric lists below, which worker threads write to.
        self._lock = threading.Lock()
        # Caps the number of requests in flight; each worker releases its slot.
        self._semaphore = threading.Semaphore(max_num_threads)
        self._client = InferenceClient(server_address)
        self._ttft: List[int] = []
        self._tpot: List[int] = []
        self._generated_tokens: List[int] = []

    def run_generation(self, samples: List[str], max_new_tokens: int) -> None:
        """Run generation for every sample, one worker thread per sample.

        At most ``max_num_threads`` workers run concurrently: the loop blocks on
        the semaphore until a running worker finishes. Returns once every
        worker has been joined.

        Args:
            samples: Prompts to send to the server.
            max_new_tokens: Generation limit forwarded with every request.
        """
        threads: List[threading.Thread] = []
        # The progress bar advances as requests are dispatched, not completed.
        for sample in tqdm.tqdm(samples):
            self._semaphore.acquire()
            thread = threading.Thread(
                target=self._process_sample, args=[sample, max_new_tokens]
            )
            threads.append(thread)
            thread.start()
        for thread in threads:
            thread.join()

    def _process_sample(self, sample: str, max_new_tokens: int) -> None:
        """Stream the response for a single sample and record its metrics.

        Metrics are gathered locally and published to the shared lists under
        ``self._lock`` once the stream is exhausted. The semaphore slot taken
        by ``run_generation`` is always released, even if the request fails.
        """
        try:
            ttft: List[int] = []
            tpot: List[int] = []
            generated_tokens: List[int] = []

            timestamp = time.perf_counter_ns()
            response_stream = self._client.text_generation(
                sample, max_new_tokens=max_new_tokens, stream=True, details=True
            )
            for index, response in enumerate(response_stream):
                # Read the clock once so bookkeeping below is not counted
                # towards the next inter-token gap.
                now = time.perf_counter_ns()
                if index == 0:
                    ttft.append(now - timestamp)
                else:
                    tpot.append(now - timestamp)
                timestamp = now
                if response.details:
                    generated_tokens.append(response.details.generated_tokens)

            with self._lock:
                self._ttft.extend(ttft)
                self._tpot.extend(tpot)
                self._generated_tokens.extend(generated_tokens)
        finally:
            self._semaphore.release()

    def print_performance_metrics(self, duration_s: float) -> None:
        """Print throughput and latency statistics collected so far.

        Args:
            duration_s: Wall-clock duration of the run, used for throughput.
                Throughput is reported as unavailable if it is not positive.

        Latency sections with no recorded samples print "No data" instead of
        raising ``statistics.StatisticsError``.
        """

        def line():
            print(32 * "-")

        def print_latency(title: str, samples_ns: List[int]) -> None:
            print(title)
            if not samples_ns:
                print("\tNo data")
                return
            print(f"\tMedian: \t{statistics.median(samples_ns) * 1e-6:.2f}ms")
            print(f"\tAverage: \t{statistics.fmean(samples_ns) * 1e-6:.2f}ms")

        line()
        print("----- Performance summary -----")
        line()
        if duration_s > 0:
            print(f"Throughput: {sum(self._generated_tokens) / duration_s:.1f} tokens/s")
            print(f"Throughput: {len(self._generated_tokens) / duration_s:.1f} queries/s")
        else:
            print("Throughput: n/a (duration must be positive)")
        line()
        print_latency("First token latency:", self._ttft)
        line()
        print_latency("Output token latency:", self._tpot)
        line()