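"""Typer CLI for the text-generation-inference Python server.

Commands:
  * serve             -- start the model server, optionally sharded via DeepSpeed
  * download-weights  -- fetch model weights and convert PyTorch checkpoints to safetensors
  * quantize          -- produce 4-bit GPTQ weights for a model

Source: mirror of https://github.com/huggingface/text-generation-inference.git
"""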
import os
import psutil
import signal
import sys
import typer

from pathlib import Path
from loguru import logger
from typing import Optional
from enum import Enum
from huggingface_hub import hf_hub_download
from text_generation_server.utils.adapter import parse_lora_adapters


app = typer.Typer()


class Quantization(str, Enum):
    bitsandbytes = "bitsandbytes"
    bitsandbytes_nf4 = "bitsandbytes-nf4"
    bitsandbytes_fp4 = "bitsandbytes-fp4"
    gptq = "gptq"
    awq = "awq"
    eetq = "eetq"
    exl2 = "exl2"
    fp8 = "fp8"
    marlin = "marlin"


class Dtype(str, Enum):
    float16 = "float16"
    bfloat16 = "bfloat16"


@app.command()
def serve(
    model_id: str,
    revision: Optional[str] = None,
    sharded: bool = False,
    quantize: Optional[Quantization] = None,
    speculate: Optional[int] = None,
    dtype: Optional[Dtype] = None,
    trust_remote_code: bool = False,
    uds_path: Path = "/tmp/text-generation-server",
    logger_level: str = "INFO",
    json_output: bool = False,
    otlp_endpoint: Optional[str] = None,
    otlp_service_name: str = "text-generation-inference.server",
    max_input_tokens: Optional[int] = None,
):
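    """Start the Python model server for `model_id`, listening on a UNIX socket under `uds_path`.

    When `sharded` is true, WORLD_SIZE, MASTER_ADDR and MASTER_PORT must be set and the
    shards are launched locally through DeepSpeed (one process per shard); otherwise the
    server runs in-process via `server.serve`.
    """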
    if sharded:
        # assert (
        #     os.getenv("RANK", None) is not None
        # ), "RANK must be set when sharded is True"
        assert (
            os.getenv("WORLD_SIZE", None) is not None
        ), "WORLD_SIZE must be set when sharded is True"
        assert (
            os.getenv("MASTER_ADDR", None) is not None
        ), "MASTER_ADDR must be set when sharded is True"
        assert (
            os.getenv("MASTER_PORT", None) is not None
        ), "MASTER_PORT must be set when sharded is True"

    # Remove default handler
    logger.remove()
    logger.add(
        sys.stdout,
        format="{message}",
        filter="text_generation_server",
        level=logger_level,
        serialize=json_output,
        backtrace=True,
        diagnose=False,
    )

    # Import here after the logger is added to log potential import exceptions
    from text_generation_server import server
    from text_generation_server.tracing import setup_tracing

    # Setup OpenTelemetry distributed tracing
    if otlp_endpoint is not None:
        setup_tracing(otlp_service_name=otlp_service_name, otlp_endpoint=otlp_endpoint)

    lora_adapters = parse_lora_adapters(os.getenv("LORA_ADAPTERS"))

    # TODO: enable LoRA with CUDA graphs; for now, disable CUDA graphs if LoRA is enabled
    # and warn the user
    if lora_adapters:
        logger.warning("LoRA adapters enabled (experimental feature).")

        if "CUDA_GRAPHS" in os.environ:
            logger.warning(
                "LoRA adapters incompatible with CUDA Graphs. Disabling CUDA Graphs."
            )
            global CUDA_GRAPHS
            CUDA_GRAPHS = None

    # Downgrade enum into str for easier management later on
    quantize = None if quantize is None else quantize.value
    dtype = "bfloat16" if dtype is None else dtype.value
    logger.info(f"quantize={quantize}")
    if dtype is not None and quantize not in {
        None,
        "bitsandbytes",
        "bitsandbytes-nf4",
        "bitsandbytes-fp4",
    }:
        raise RuntimeError(
            "Only one of `dtype` and `quantize` can be set, as they both determine the final model's data type."
        )

    logger.info("CLI SHARDED = {} DTYPE = {}".format(sharded, dtype))

    if sharded:
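        # Sharded mode shells out to DeepSpeed: one tgi_service.py process per shard
        # (taken from WORLD_SIZE), launched on the local node via the command built below.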
        tgi_file = Path(__file__).resolve().parent / "tgi_service.py"
        num_shard = int(os.getenv("WORLD_SIZE", "1"))
        logger.info("CLI SHARDED = {}".format(num_shard))
        import subprocess

        cmd = (
            f"deepspeed --num_nodes 1 --num_gpus {num_shard} --no_local_rank {tgi_file}"
        )
        cmd += f" --model_id {model_id} --revision {revision} --sharded {sharded}"
        cmd += f" --dtype {dtype} --trust_remote_code {trust_remote_code} --uds_path {uds_path}"
        cmd += f" --quantize {quantize} --max_input_tokens {max_input_tokens}"
        if speculate is not None:
            cmd += f" --speculate {speculate}"
        logger.info("CLI server start deepspeed = {}".format(cmd))
        sys.stdout.flush()
        sys.stderr.flush()
        with subprocess.Popen(cmd, shell=True, executable="/bin/bash") as proc:
            do_terminate = False
            current_handler = signal.getsignal(signal.SIGTERM)

            def terminate_handler(sig, frame):
                nonlocal do_terminate
                do_terminate = True
                if callable(current_handler):
                    current_handler(sig, frame)

            signal.signal(signal.SIGTERM, terminate_handler)
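            # The loop below polls the DeepSpeed launcher; once a SIGTERM has been
            # received, the whole process tree is terminated (children first), given
            # 30 seconds to exit, and then force-killed.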

            finished = False
            while not finished:
                try:
                    if do_terminate:
                        parent = psutil.Process(proc.pid)
                        all_procs = parent.children(recursive=True) + [parent]
                        for p in all_procs:
                            try:
                                p.terminate()
                            except psutil.NoSuchProcess:
                                pass
                        _, alive = psutil.wait_procs(all_procs, timeout=30)
                        for p in alive:
                            p.kill()

                        do_terminate = False

                    proc.wait(timeout=3)
                except subprocess.TimeoutExpired:
                    pass
                else:
                    finished = True

            sys.stdout.flush()
            sys.stderr.flush()
            if proc.returncode != 0:
                logger.error(f"{cmd} exited with status = {proc.returncode}")
                return proc.returncode
    else:
        server.serve(
            model_id,
            lora_adapters,
            revision,
            sharded,
            quantize,
            speculate,
            dtype,
            trust_remote_code,
            uds_path,
            max_input_tokens,
        )


@app.command()
def download_weights(
    model_id: str,
    revision: Optional[str] = None,
    extension: str = ".safetensors",
    auto_convert: bool = True,
    logger_level: str = "INFO",
    json_output: bool = False,
    trust_remote_code: bool = False,
    merge_lora: bool = False,
):
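    """Download the weights of `model_id`, preferring `extension` (safetensors) files.

    The command tries, in order: weights already present locally, PEFT adapter handling
    (controlled by `merge_lora`), the base model referenced by `base_model_name_or_path`,
    and finally hub `.safetensors` files. If only PyTorch `.bin`/`.pt` checkpoints are
    found and `auto_convert` is set, they are converted to safetensors.
    """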
    # Remove default handler
    logger.remove()
    logger.add(
        sys.stdout,
        format="{message}",
        filter="text_generation_server",
        level=logger_level,
        serialize=json_output,
        backtrace=True,
        diagnose=False,
    )

    # Import here after the logger is added to log potential import exceptions
    from text_generation_server import utils

    # Test if files were already downloaded
    try:
        utils.weight_files(model_id, revision, extension)
        logger.info("Files are already present on the host. Skipping download.")
        return
    # Local files not found
    except (utils.LocalEntryNotFoundError, FileNotFoundError, utils.EntryNotFoundError):
        pass

    is_local_model = (Path(model_id).exists() and Path(model_id).is_dir()) or os.getenv(
        "WEIGHTS_CACHE_OVERRIDE", None
    ) is not None
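    # The model is treated as local when `model_id` is an existing directory or when
    # WEIGHTS_CACHE_OVERRIDE is set; in that case the hub download attempts below are skipped.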

    if not is_local_model:
        # TODO: maybe reverse the default value of merge_lora?
        # currently by default we don't merge the weights with the base model
        if merge_lora:
            try:
                hf_hub_download(
                    model_id, revision=revision, filename="adapter_config.json"
                )
                utils.download_and_unload_peft(
                    model_id, revision, trust_remote_code=trust_remote_code
                )
                is_local_model = True
                utils.weight_files(model_id, revision, extension)
                return
            except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
                pass
        else:
            try:
                utils.peft.download_peft(
                    model_id, revision, trust_remote_code=trust_remote_code
                )
            except Exception:
                pass

        try:
            import json

            config = hf_hub_download(
                model_id, revision=revision, filename="config.json"
            )
            with open(config, "r") as f:
                config = json.load(f)

            base_model_id = config.get("base_model_name_or_path", None)
            if base_model_id and base_model_id != model_id:
                try:
                    logger.info(f"Downloading parent model {base_model_id}")
                    download_weights(
                        model_id=base_model_id,
                        revision="main",
                        extension=extension,
                        auto_convert=auto_convert,
                        logger_level=logger_level,
                        json_output=json_output,
                        trust_remote_code=trust_remote_code,
                    )
                except Exception:
                    pass
        except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
            pass

        # Try to download weights from the hub
        try:
            filenames = utils.weight_hub_files(model_id, revision, extension)
            utils.download_weights(filenames, model_id, revision)
            # Successfully downloaded weights
            return

        # No weights found on the hub with this extension
        except utils.EntryNotFoundError as e:
            # Check if we want to automatically convert to safetensors or if we can use .bin weights instead
            if not extension == ".safetensors" or not auto_convert:
                raise e

    elif (Path(model_id) / "adapter_config.json").exists():
        # Try to load as a local PEFT model
        try:
            utils.download_and_unload_peft(
                model_id, revision, trust_remote_code=trust_remote_code
            )
            utils.weight_files(model_id, revision, extension)
            return
        except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
            pass
    elif (Path(model_id) / "config.json").exists():
        # Try to load as a local Medusa model
        try:
            import json

            config = Path(model_id) / "config.json"
            with open(config, "r") as f:
                config = json.load(f)

            base_model_id = config.get("base_model_name_or_path", None)
            if base_model_id:
                try:
                    logger.info(f"Downloading parent model {base_model_id}")
                    download_weights(
                        model_id=base_model_id,
                        revision="main",
                        extension=extension,
                        auto_convert=auto_convert,
                        logger_level=logger_level,
                        json_output=json_output,
                        trust_remote_code=trust_remote_code,
                    )
                except Exception:
                    pass
        except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
            pass

    # Try to see if there are local pytorch weights
    try:
        # Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE
        try:
            local_pt_files = utils.weight_files(model_id, revision, ".bin")
        except Exception:
            local_pt_files = utils.weight_files(model_id, revision, ".pt")

    # No local pytorch weights
    except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
        if extension == ".safetensors":
            logger.warning(
                f"No safetensors weights found for model {model_id} at revision {revision}. "
                f"Downloading PyTorch weights."
            )

        # Try to see if there are pytorch weights on the hub
        pt_filenames = utils.weight_hub_files(model_id, revision, ".bin")
        # Download pytorch weights
        local_pt_files = utils.download_weights(pt_filenames, model_id, revision)

    if auto_convert:
        if not trust_remote_code:
            logger.warning(
                "🚨🚨BREAKING CHANGE in 2.0🚨🚨: Safetensors conversion is disabled without `--trust-remote-code` because "
                "Pickle files are unsafe and can essentially contain remote code execution! "
                "Please check for more information here: https://huggingface.co/docs/text-generation-inference/basic_tutorials/safety",
            )

        logger.warning(
            f"No safetensors weights found for model {model_id} at revision {revision}. "
            f"Converting PyTorch weights to safetensors."
        )

        # Safetensors final filenames
        local_st_files = [
            p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors"
            for p in local_pt_files
        ]
        try:
            import transformers
            import json

            if is_local_model:
                config_filename = os.path.join(model_id, "config.json")
            else:
                config_filename = hf_hub_download(
                    model_id, revision=revision, filename="config.json"
                )
            with open(config_filename, "r") as f:
                config = json.load(f)
            architecture = config["architectures"][0]

            class_ = getattr(transformers, architecture)

            # The name of this variable depends on the transformers version.
            discard_names = getattr(class_, "_tied_weights_keys", [])

        except Exception:
            discard_names = []
        # Convert pytorch weights to safetensors
        utils.convert_files(local_pt_files, local_st_files, discard_names)


@app.command()
def quantize(
    model_id: str,
    output_dir: str,
    revision: Optional[str] = None,
    logger_level: str = "INFO",
    json_output: bool = False,
    trust_remote_code: bool = False,
    upload_to_model_id: Optional[str] = None,
    percdamp: float = 0.01,
    act_order: bool = False,
    groupsize: int = 128,
):
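    """Quantize `model_id` to 4-bit GPTQ weights written to `output_dir`.

    Weights are downloaded first, then passed to the GPTQ quantizer with the given
    `groupsize`, `percdamp` and `act_order` settings (symmetric quantization), and
    optionally uploaded to `upload_to_model_id`.
    """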
    if revision is None:
        revision = "main"
    download_weights(
        model_id=model_id,
        revision=revision,
        logger_level=logger_level,
        json_output=json_output,
    )
    from text_generation_server.layers.gptq.quantize import quantize

    quantize(
        model_id=model_id,
        bits=4,
        groupsize=groupsize,
        output_dir=output_dir,
        revision=revision,
        trust_remote_code=trust_remote_code,
        upload_to_model_id=upload_to_model_id,
        percdamp=percdamp,
        act_order=act_order,
        sym=True,
    )


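# Example invocations (assuming the package installs its usual `text-generation-server`
# console script pointing at this Typer app; note Typer exposes command names with dashes):
#
#   text-generation-server download-weights bigscience/bloom-560m
#   text-generation-server serve bigscience/bloom-560m --dtype bfloat16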
if __name__ == "__main__":
    app()