text-generation-inference/examples/run_generation.py
yuanwu2017 8d84ffabf2
Upgrade to SynapseAI 1.18 (#227)
Signed-off-by: yuanwu <yuan.wu@intel.com>
Co-authored-by: Thanaji Rao Thakkalapelli <tthakkalapelli@habana.ai>
2024-10-31 20:14:44 +01:00

98 lines
2.8 KiB
Python

# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
import argparse
import requests
import time
from typing import List
from datasets import load_dataset
from transformers import AutoTokenizer
from tgi_client import TgiClient
def get_args():
    """Parse the command-line options of this generation example.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``, carrying
        ``server_address``, ``model_id``, ``max_input_length``,
        ``max_output_length``, ``total_sample_count``,
        ``max_concurrent_requests`` and ``seed``.
    """
    # One row per option: (flag, type, default, help text).
    # The row order is the order in which options show up in ``--help``.
    option_table = (
        ("--server_address", str, "http://localhost:8080", "Address of the TGI server"),
        ("--model_id", str, "meta-llama/Llama-2-7b-chat-hf", "Model id used in TGI server"),
        ("--max_input_length", int, 1024, "Max input length for TGI model"),
        ("--max_output_length", int, 1024, "Max output length for TGI model"),
        ("--total_sample_count", int, 2048, "Total number of samples to generate"),
        ("--max_concurrent_requests", int, 256, "Max number of concurrent requests"),
        ("--seed", int, 42, "Random seed for datasets"),
    )

    cli = argparse.ArgumentParser()
    for flag, value_type, fallback, description in option_table:
        cli.add_argument(flag, type=value_type, default=fallback, help=description)
    return cli.parse_args()
def read_dataset(
    max_input_length: int,
    total_sample_count: int,
    model_id: str,
    seed: int,
) -> List[str]:
    """
    Build the list of prompts from the public HF dataset
    https://huggingface.co/datasets/DIBT/10k_prompts_ranked.

    Prompts whose token count (using the tokenizer of ``model_id``) is not
    strictly below ``max_input_length`` are dropped. If more than
    ``total_sample_count`` prompts remain, only the first
    ``total_sample_count`` are kept. The kept prompts are then shuffled
    with ``seed`` and returned as plain strings.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    def fits_input_budget(sample) -> bool:
        # Count tokens of the raw prompt exactly as the tokenizer emits them.
        token_ids = tokenizer(sample["prompt"])["input_ids"]
        return len(token_ids) < max_input_length

    prompts = load_dataset(
        "DIBT/10k_prompts_ranked", split="train", trust_remote_code=True
    ).filter(fits_input_budget)

    # Truncation happens before shuffling, so the seed only affects order.
    if total_sample_count < len(prompts):
        prompts = prompts.select(range(total_sample_count))

    shuffled = prompts.shuffle(seed=seed)
    return [row["prompt"] for row in shuffled]
def is_tgi_available(
    server_address: str,
    timeout: float = 10.0,
) -> bool:
    """
    Checks if TGI server is available under the specified address.

    Sends a GET request to ``{server_address}/info`` and reports whether it
    answered with HTTP 200.

    Args:
        server_address: Base URL of the server, e.g. ``http://localhost:8080``.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` may block indefinitely on an unresponsive
            host.

    Returns:
        True when the endpoint replied with status code 200; False when it
        replied with another status or the request failed (connection
        error, timeout, malformed URL, ...).
    """
    try:
        info = requests.get(f"{server_address}/info", timeout=timeout)
    except requests.RequestException:
        # Only request-level failures mean "not available"; KeyboardInterrupt
        # and programming errors must still propagate to the caller.
        return False
    return info.status_code == 200
def main():
    """Run the generation example end to end.

    Parses the command-line options, loads the prompts, verifies that the
    TGI server answers, sends all prompts through ``TgiClient`` and hands
    the elapsed time of the generation run (in seconds) to
    ``print_performance_metrics``.

    Raises:
        RuntimeError: If the TGI server is not available.
    """
    args = get_args()
    prompts = read_dataset(
        max_input_length=args.max_input_length,
        total_sample_count=args.total_sample_count,
        model_id=args.model_id,
        seed=args.seed,
    )

    server_reachable = is_tgi_available(args.server_address)
    if not server_reachable:
        raise RuntimeError("Cannot connect with TGI server!")

    client = TgiClient(args.server_address, args.max_concurrent_requests)

    # Time only the generation itself, not dataset loading or client setup.
    started_ns = time.perf_counter_ns()
    client.run_generation(prompts, args.max_output_length)
    elapsed_ns = time.perf_counter_ns() - started_ns

    client.print_performance_metrics(elapsed_ns * 1e-9)
# Script entry point: run the example only when executed directly.
if __name__ == "__main__":
    main()