# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.

import argparse
import time
from typing import List

import requests
from datasets import load_dataset
from transformers import AutoTokenizer

from tgi_client import TgiClient


def get_args():
    """Parse the command-line options of the benchmark script.

    Returns:
        The ``argparse.Namespace`` holding ``server_address``, ``model_id``,
        ``max_input_length``, ``max_output_length``, ``total_sample_count``,
        ``max_concurrent_requests`` and ``seed``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--server_address",
        type=str,
        default="http://localhost:8083",
        help="Address of the TGI server",
    )
    parser.add_argument(
        "--model_id",
        type=str,
        default="meta-llama/Llama-2-7b-chat-hf",
        help="Model id used in TGI server",
    )
    parser.add_argument(
        "--max_input_length",
        type=int,
        default=1024,
        help="Max input length for TGI model",
    )
    parser.add_argument(
        "--max_output_length",
        type=int,
        default=1024,
        help="Max output length for TGI model",
    )
    parser.add_argument(
        "--total_sample_count",
        type=int,
        default=2048,
        help="Total number of samples to generate",
    )
    parser.add_argument(
        "--max_concurrent_requests",
        type=int,
        default=256,
        help="Max number of concurrent requests",
    )
    parser.add_argument(
        "--seed", type=int, default=42, help="Random seed for datasets"
    )
    return parser.parse_args()


def read_dataset(
    max_input_length: int,
    total_sample_count: int,
    model_id: str,
    seed: int,
) -> List[str]:
    """Load the prompts used as generation inputs.

    Loads public dataset from HF:
    https://huggingface.co/datasets/DIBT/10k_prompts_ranked
    and filters out too long samples.

    Args:
        max_input_length: Prompts are kept only when their tokenized length
            is strictly below this value.
        total_sample_count: Upper bound on the number of prompts returned.
            Fewer are returned when the filtered dataset is smaller.
        model_id: Model id whose tokenizer is used to measure prompt length.
        seed: Seed passed to the dataset shuffle.

    Returns:
        The selected prompts, in shuffled order.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # NOTE(review): trust_remote_code=True lets the dataset repository run
    # its own loading code - confirm this dataset actually requires it.
    dataset = load_dataset(
        "DIBT/10k_prompts_ranked", split="train", trust_remote_code=True
    )
    dataset = dataset.filter(
        lambda x: len(tokenizer(x["prompt"])["input_ids"]) < max_input_length
    )
    # Truncation happens before shuffling, so the result is always drawn from
    # the first `total_sample_count` rows that passed the length filter.
    if len(dataset) > total_sample_count:
        dataset = dataset.select(range(total_sample_count))
    dataset = dataset.shuffle(seed=seed)
    return [sample["prompt"] for sample in dataset]


def is_tgi_available(server_address: str, timeout: float = 10.0) -> bool:
    """Check if TGI server is available under the specified address.

    Args:
        server_address: Base URL of the server; ``/info`` is appended to it.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout the probe could block indefinitely on an address that
            accepts the connection but never answers.

    Returns:
        True when ``GET {server_address}/info`` answers with HTTP 200, False
        when it answers with another status or the request itself fails.
    """
    try:
        info = requests.get(f"{server_address}/info", timeout=timeout)
    except requests.RequestException:
        # Connection errors, timeouts and malformed URLs all mean "not
        # available"; anything else is a real bug and should surface.
        return False
    return info.status_code == 200


def main():
    """Run the benchmark: load prompts, generate, and print the metrics.

    Raises:
        RuntimeError: If the TGI server does not respond at the given address.
    """
    args = get_args()
    dataset = read_dataset(
        args.max_input_length, args.total_sample_count, args.model_id, args.seed
    )

    if not is_tgi_available(args.server_address):
        raise RuntimeError("Cannot connect with TGI server!")

    tgi_client = TgiClient(args.server_address, args.max_concurrent_requests)
    # Only the generation call is timed; dataset loading and the availability
    # probe are excluded from the reported duration.
    timestamp = time.perf_counter_ns()
    tgi_client.run_generation(dataset, args.max_output_length)
    duration_s = (time.perf_counter_ns() - timestamp) * 1e-9
    tgi_client.print_performance_metrics(duration_s)


if __name__ == "__main__":
    main()