Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-11 04:14:52 +00:00
(ADD): client benchmarking utility & omegaconf
This commit is contained in:
parent
20ee71dcf5
commit
1455830909
clients/python/utils/client.py  (new file, 90 lines)
@@ -0,0 +1,90 @@
import asyncio
import logging
import time

from text_generation import AsyncClient

from typing import List, Dict
from torch import IntTensor
from transformers import AutoTokenizer
import numpy as np

from hydra import compose, initialize
from omegaconf import DictConfig, OmegaConf


log = logging.getLogger(__name__)

# initialize the Hydra subsystem and compose the experiment configuration
initialize(version_base=None, config_path="conf")
cfg: DictConfig = compose("config.yaml")
print(f"Experiment configuration:\n{OmegaConf.to_yaml(cfg)}")

# serving system deployment
max_concurrent_requests: int = cfg.deployment.max_concurrent_requests
if cfg.deployment.local:
    endpoint: str = f"{cfg.deployment.addr}:{cfg.deployment.port}"
else:
    endpoint = cfg.deployment.endpoint
client = AsyncClient(endpoint)

# uncomment to benchmark with a natural-language prompt instead of dummy tokens
# prompt: str = "Hello" * 100

# output length (number of decoded tokens requested per generation)
max_new_tokens: int = cfg.GenerationConfig.max_new_tokens
# generation strategy
repetition_penalty: float = cfg.GenerationConfig.repetition_penalty
do_sample: bool = cfg.GenerationConfig.do_sample
# generation finish reason
stop_sequences: List[str] = list(cfg.GenerationConfig.stop_sequences)

tokenizer = AutoTokenizer.from_pretrained(cfg.macros.tokenizer_name)

# token id 21820 corresponds to "Hello"
dummy_input: IntTensor = IntTensor([[21820]])
# inverse tokenization (sanity check)
# token: str = tokenizer.decode(dummy_input[0], skip_special_tokens=True)

# build a prompt of exactly `sequence_length` repeated dummy tokens
sequence_length: int = cfg.GenerationConfig.sequence_length
multiple_dummy_repeat = dummy_input.repeat(1, sequence_length)
prompt: str = tokenizer.decode(multiple_dummy_repeat[0], skip_special_tokens=True)
assert np.shape(multiple_dummy_repeat)[-1] == sequence_length


generate_kwargs: Dict = {
    "do_sample": do_sample,
    "max_new_tokens": max_new_tokens,
    "repetition_penalty": repetition_penalty,
    "stop_sequences": stop_sequences,
    "decoder_input_details": True,
}

# one identical prompt per concurrent request
prompts: List = [prompt] * max_concurrent_requests
# create one generation coroutine per request
coros = [client.generate(prompt, **generate_kwargs) for prompt in prompts]


async def batch():
    return await asyncio.gather(*coros)


st: float = time.perf_counter_ns()
results = asyncio.run(batch())
et = (time.perf_counter_ns() - st) * 1e-9
print(f"Serving elapsed time: {et:0.4f} seconds")
# check the last response
print(results[-1].details)


total_input_sequence_tokens: int = 0
total_decoded_tokens: int = 0

for prompt, response in zip(prompts, results):
    # uncomment to inspect the generations
    # print(prompt + response.generated_text)
    # assert np.shape(tokenizer(response.generated_text, return_tensors="pt").input_ids)[-1] - 1 == max_new_tokens
    # assert response.details.generated_tokens == max_new_tokens
    total_input_sequence_tokens += len(response.details.prefill)
    total_decoded_tokens += response.details.generated_tokens

# every request should have been prefilled with exactly `sequence_length` tokens
assert total_input_sequence_tokens == max_concurrent_requests * sequence_length

# stats
print(f"Serving elapsed time: {et:0.4f} seconds")
print(f"Total requests: {max_concurrent_requests}")
print(f"Total input sequence tokens: {total_input_sequence_tokens}")
print(f"Sequence length: {sequence_length}")
print(f"Total decoded tokens: {total_decoded_tokens}")
print(f"Throughput: {total_decoded_tokens / et} tokens/sec")
print(f"Total processed tokens: {total_decoded_tokens + total_input_sequence_tokens}")
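A possible extension (not part of this commit): before launching the `max_concurrent_requests` concurrent coroutines, a single blocking request could be sent through the synchronous `Client` from the same `text_generation` package, so that one-time server warm-up does not skew the measured elapsed time. The snippet below is a minimal sketch under that assumption; the `warm_up` helper name is hypothetical.

# Hypothetical warm-up sketch: one short blocking generation whose result is discarded.
from text_generation import Client

def warm_up(endpoint: str, prompt: str, max_new_tokens: int = 4) -> None:
    # single synchronous request against the same endpoint used by AsyncClient
    response = Client(endpoint).generate(prompt, max_new_tokens=max_new_tokens)
    print(f"warm-up finished: {response.details.finish_reason}")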
clients/python/utils/conf/config.yaml  (new file, 17 lines)
@@ -0,0 +1,17 @@
deployment:
  addr: http://127.0.0.1
  port: 8080
  max_concurrent_requests: 64
  local: True
  endpoint: http://guanaco-65b-merged.inference.takomo.internal

macros:
  tokenizer_name: bigscience/bloom-560m

GenerationConfig:
  max_new_tokens: 100
  repetition_penalty: 1.5
  do_sample: True
  stop_sequences:
    - "length"
  sequence_length: 100
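For quick experiments these values can also be changed without editing the file. Since client.py composes the configuration through Hydra's `initialize`/`compose` API, overrides can be passed at compose time. The snippet below is a small sketch assuming the same `conf/` layout; the override values are examples only.

# Sketch: composing config.yaml with ad-hoc overrides (example values)
from hydra import compose, initialize
from omegaconf import OmegaConf

with initialize(version_base=None, config_path="conf"):
    cfg = compose(
        config_name="config.yaml",
        overrides=[
            "deployment.max_concurrent_requests=128",  # example override
            "GenerationConfig.max_new_tokens=32",      # example override
        ],
    )
print(OmegaConf.to_yaml(cfg))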
server/text_generation_server/debug_backend.py  (new file, 10 lines)
@@ -0,0 +1,10 @@
# SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve distilgpt2
import subprocess
import os
from pathlib import Path

# resolve the server CLI entry point relative to this file
wd_dir: Path = Path(__file__).parent.absolute()
cli_path: str = os.path.join(wd_dir, "cli.py")
os.environ["SAFETENSORS_FAST_GPU"] = "1"
command: str = f"python -m torch.distributed.run --nproc_per_node=1 {cli_path} serve bigscience/bloom-560m"
subprocess.run(command.split())
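A variant of the same launcher, sketched below on the assumption that the behavior should stay identical: the environment is passed explicitly to the child process instead of mutating `os.environ`, and `check=True` raises if the server process exits with a non-zero status.

# Sketch: equivalent launch with an explicit environment and error checking
import os
import subprocess
from pathlib import Path

cli_path = Path(__file__).parent.absolute() / "cli.py"
subprocess.run(
    ["python", "-m", "torch.distributed.run", "--nproc_per_node=1",
     str(cli_path), "serve", "bigscience/bloom-560m"],
    env={**os.environ, "SAFETENSORS_FAST_GPU": "1"},
    check=True,
)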