2024-04-08 16:06:21 +00:00
|
|
|
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
|
|
|
|
|
2024-03-26 08:17:55 +00:00
|
|
|
import argparse
|
|
|
|
import requests
|
|
|
|
import time
|
|
|
|
from typing import List
|
|
|
|
|
|
|
|
from datasets import load_dataset
|
|
|
|
from transformers import AutoTokenizer
|
|
|
|
|
|
|
|
from tgi_client import TgiClient
|
|
|
|
|
|
|
|
|
|
|
|
def get_args():
    """Build the CLI parser and return the parsed benchmark arguments.

    Returns:
        argparse.Namespace with server_address, model_id, max_input_length,
        max_output_length, total_sample_count, max_concurrent_requests, seed.
    """
    parser = argparse.ArgumentParser()
    # (flag, type, default, help) — registered in a single loop to keep the
    # option table easy to scan and extend.
    options = [
        ("--server_address", str, "http://localhost:8083", "Address of the TGI server"),
        ("--model_id", str, "meta-llama/Llama-2-7b-chat-hf", "Model id used in TGI server"),
        ("--max_input_length", int, 1024, "Max input length for TGI model"),
        ("--max_output_length", int, 1024, "Max output length for TGI model"),
        ("--total_sample_count", int, 2048, "Total number of samples to generate"),
        ("--max_concurrent_requests", int, 256, "Max number of concurrent requests"),
        ("--seed", int, 42, "Random seed for datasets"),
    ]
    for flag, value_type, default, help_text in options:
        parser.add_argument(flag, type=value_type, default=default, help=help_text)
    return parser.parse_args()
|
|
|
|
|
|
|
|
|
|
|
|
def read_dataset(
    max_input_length: int,
    total_sample_count: int,
    model_id: str,
    seed: int,
) -> List[str]:
    """
    Loads public dataset from HF: https://huggingface.co/datasets/DIBT/10k_prompts_ranked
    and filters out too long samples.

    Args:
        max_input_length: prompts must tokenize to strictly fewer tokens than this.
        total_sample_count: cap on the number of prompts returned.
        model_id: HF model id whose tokenizer measures prompt length.
        seed: shuffle seed applied after filtering/truncation.

    Returns:
        A list of prompt strings.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    def fits_input_budget(sample):
        # Keep only prompts strictly shorter than the model's input limit.
        return len(tokenizer(sample["prompt"])["input_ids"]) < max_input_length

    dataset = load_dataset("DIBT/10k_prompts_ranked", split="train", trust_remote_code=True)
    dataset = dataset.filter(fits_input_budget)
    if len(dataset) > total_sample_count:
        dataset = dataset.select(range(total_sample_count))
    dataset = dataset.shuffle(seed=seed)
    return [sample["prompt"] for sample in dataset]
|
|
|
|
|
|
|
|
|
|
|
|
def is_tgi_available(
    server_address: str
) -> bool:
    """
    Checks if TGI server is available under the specified address.

    Args:
        server_address: Base URL of the TGI server, e.g. "http://localhost:8083".

    Returns:
        True if GET {server_address}/info responds with HTTP 200, False otherwise.
    """
    try:
        # Bound the probe: requests.get without a timeout can hang indefinitely
        # when the host is unreachable.
        info = requests.get(f"{server_address}/info", timeout=10)
        return info.status_code == 200
    except requests.exceptions.RequestException:
        # Narrow except: a bare `except:` would also swallow KeyboardInterrupt
        # and SystemExit. Only network/HTTP failures mean "not available".
        return False
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Entry point: load prompts, verify the server, run generation, report timing."""
    args = get_args()
    prompts = read_dataset(
        args.max_input_length, args.total_sample_count, args.model_id, args.seed
    )

    # Fail fast before constructing the client if the server is not reachable.
    if not is_tgi_available(args.server_address):
        raise RuntimeError("Cannot connect with TGI server!")

    client = TgiClient(args.server_address, args.max_concurrent_requests)

    start_ns = time.perf_counter_ns()
    client.run_generation(prompts, args.max_output_length)
    elapsed_s = (time.perf_counter_ns() - start_ns) * 1e-9
    client.print_performance_metrics(elapsed_s)
|
|
|
|
|
|
|
|
|
|
|
|
# Run the benchmark only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|