# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.

import os
import psutil
import signal
import sys
import typer

from pathlib import Path
from loguru import logger
from typing import Optional
from enum import Enum
from huggingface_hub import hf_hub_download

app = typer.Typer()


class Quantization(str, Enum):
    bitsandbytes = "bitsandbytes"
    gptq = "gptq"


class Dtype(str, Enum):
    float16 = "float16"
    bfloat16 = "bfloat16"


@app.command()
def serve(
    model_id: str,
    revision: Optional[str] = None,
    sharded: bool = False,
    quantize: Optional[Quantization] = None,
    speculate: Optional[int] = None,
    dtype: Optional[Dtype] = None,
    trust_remote_code: bool = False,
    uds_path: Path = "/tmp/text-generation-server",
    logger_level: str = "INFO",
    json_output: bool = False,
    otlp_endpoint: Optional[str] = None,
):
    if sharded:
        assert (
            os.getenv("WORLD_SIZE", None) is not None
        ), "WORLD_SIZE must be set when sharded is True"
        assert (
            os.getenv("MASTER_ADDR", None) is not None
        ), "MASTER_ADDR must be set when sharded is True"
        assert (
            os.getenv("MASTER_PORT", None) is not None
        ), "MASTER_PORT must be set when sharded is True"

    # Remove default handler
    logger.remove()
    logger.add(
        sys.stdout,
        format="{message}",
        filter="text_generation_server",
        level=logger_level,
        serialize=json_output,
        backtrace=True,
        diagnose=False,
    )

    # Import here after the logger is added to log potential import exceptions
    from text_generation_server import server
    from text_generation_server.tracing import setup_tracing

    # Setup OpenTelemetry distributed tracing
    if otlp_endpoint is not None:
        setup_tracing(shard=os.getenv("RANK", 0), otlp_endpoint=otlp_endpoint)

    # Downgrade enum into str for easier management later on
    quantize = None if quantize is None else quantize.value
    dtype = "bfloat16" if dtype is None else dtype.value

    logger.info("CLI SHARDED = {} DTYPE = {}".format(sharded, dtype))

    if sharded:
        tgi_file = Path(__file__).resolve().parent / "tgi_service.py"
        num_shard = int(os.getenv("WORLD_SIZE", "1"))
        logger.info("CLI SHARDED = {}".format(num_shard))
        import subprocess

        cmd = f"deepspeed --num_nodes 1 --num_gpus {num_shard} --no_local_rank {tgi_file}"
        cmd += f" --model_id {model_id} --revision {revision} --sharded {sharded}"
        cmd += f" --dtype {dtype} --trust_remote_code {trust_remote_code} --uds_path {uds_path}"
        if speculate is not None:
            cmd += f" --speculate {speculate}"
        logger.info("CLI server start deepspeed = {}".format(cmd))
        sys.stdout.flush()
        sys.stderr.flush()
        with subprocess.Popen(cmd, shell=True, executable="/bin/bash") as proc:
            do_terminate = False
            current_handler = signal.getsignal(signal.SIGTERM)

            def terminate_handler(sig, frame):
                nonlocal do_terminate
                do_terminate = True
                if callable(current_handler):
                    current_handler(sig, frame)

            signal.signal(signal.SIGTERM, terminate_handler)

            finished = False
            while not finished:
                try:
                    if do_terminate:
                        parent = psutil.Process(proc.pid)
                        all_procs = parent.children(recursive=True) + [parent]
                        for p in all_procs:
                            try:
                                p.terminate()
                            except psutil.NoSuchProcess:
                                pass
                        _, alive = psutil.wait_procs(all_procs, timeout=30)
                        for p in alive:
                            p.kill()
                        do_terminate = False

                    proc.wait(timeout=3)
                except subprocess.TimeoutExpired:
                    pass
                else:
                    finished = True

            sys.stdout.flush()
            sys.stderr.flush()
            if proc.returncode != 0:
                logger.error(f"{cmd} exited with status = {proc.returncode}")
                return proc.returncode
    else:
        server.serve(
            model_id, revision, sharded, speculate, dtype, trust_remote_code, uds_path
        )


@app.command()
def download_weights(
    model_id: str,
    revision: Optional[str] = None,
    extension: str = ".safetensors",
    auto_convert: bool = True,
    logger_level: str = "INFO",
    json_output: bool = False,
    trust_remote_code: bool = False,
):
    # Remove default handler
    logger.remove()
    logger.add(
        sys.stdout,
        format="{message}",
        filter="text_generation_server",
        level=logger_level,
        serialize=json_output,
        backtrace=True,
        diagnose=False,
    )

    # Import here after the logger is added to log potential import exceptions
    from text_generation_server import utils

    # Test if files were already downloaded
    try:
        utils.weight_files(model_id, revision, extension)
        logger.info("Files are already present on the host. " "Skipping download.")
        return
    # Local files not found
    except (utils.LocalEntryNotFoundError, FileNotFoundError, utils.EntryNotFoundError):
        pass

    is_local_model = (Path(model_id).exists() and Path(model_id).is_dir()) or os.getenv(
        "WEIGHTS_CACHE_OVERRIDE", None
    ) is not None

    if not is_local_model:
        try:
            adapter_config_filename = hf_hub_download(
                model_id, revision=revision, filename="adapter_config.json"
            )
            utils.download_and_unload_peft(
                model_id, revision, trust_remote_code=trust_remote_code
            )
            is_local_model = True
            utils.weight_files(model_id, revision, extension)
            return
        except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
            pass

        try:
            import json

            medusa_head = hf_hub_download(
                model_id, revision=revision, filename="medusa_lm_head.pt"
            )
            if auto_convert:
                medusa_sf = Path(medusa_head[: -len(".pt")] + ".safetensors")
                if not medusa_sf.exists():
                    utils.convert_files([Path(medusa_head)], [medusa_sf], [])
            medusa_config = hf_hub_download(
                model_id, revision=revision, filename="config.json"
            )
            with open(medusa_config, "r") as f:
                config = json.load(f)

            model_id = config["base_model_name_or_path"]
            revision = "main"
            try:
                utils.weight_files(model_id, revision, extension)
                logger.info(
                    f"Files for parent {model_id} are already present on the host. "
                    "Skipping download."
                )
                return
            # Local files not found
            except (
                utils.LocalEntryNotFoundError,
                FileNotFoundError,
                utils.EntryNotFoundError,
            ):
                pass
        except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
            pass

        # Try to download weights from the hub
        try:
            filenames = utils.weight_hub_files(model_id, revision, extension)
            utils.download_weights(filenames, model_id, revision)
            # Successfully downloaded weights
            return

        # No weights found on the hub with this extension
        except utils.EntryNotFoundError as e:
            # Check if we want to automatically convert to safetensors or if we can use .bin weights instead
            if not extension == ".safetensors" or not auto_convert:
                raise e

    elif (Path(model_id) / "medusa_lm_head.pt").exists():
        # Try to load as a local Medusa model
        try:
            import json

            medusa_head = Path(model_id) / "medusa_lm_head.pt"
            if auto_convert:
                medusa_sf = Path(model_id) / "medusa_lm_head.safetensors"
                if not medusa_sf.exists():
                    utils.convert_files([Path(medusa_head)], [medusa_sf], [])
            medusa_config = Path(model_id) / "config.json"
            with open(medusa_config, "r") as f:
                config = json.load(f)

            model_id = config["base_model_name_or_path"]
            revision = "main"
            try:
                utils.weight_files(model_id, revision, extension)
                logger.info(
                    f"Files for parent {model_id} are already present on the host. "
                    "Skipping download."
                )
                return
            # Local files not found
            except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
                pass
        except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
            pass

    elif (Path(model_id) / "adapter_config.json").exists():
        # Try to load as a local PEFT model
        try:
            utils.download_and_unload_peft(
                model_id, revision, trust_remote_code=trust_remote_code
            )
            utils.weight_files(model_id, revision, extension)
            return
        except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
            pass

    # Try to see if there are local pytorch weights
    try:
        # Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE
        local_pt_files = utils.weight_files(model_id, revision, ".bin")

    # No local pytorch weights
    except utils.LocalEntryNotFoundError:
        if extension == ".safetensors":
            logger.warning(
                f"No safetensors weights found for model {model_id} at revision {revision}. "
                f"Downloading PyTorch weights."
            )

        # Try to see if there are pytorch weights on the hub
        pt_filenames = utils.weight_hub_files(model_id, revision, ".bin")
        # Download pytorch weights
        local_pt_files = utils.download_weights(pt_filenames, model_id, revision)

    if auto_convert:
        logger.warning(
            f"No safetensors weights found for model {model_id} at revision {revision}. "
            f"Converting PyTorch weights to safetensors."
        )

        # Safetensors final filenames
        local_st_files = [
            p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors"
            for p in local_pt_files
        ]
        try:
            import transformers
            from transformers import AutoConfig

            config = AutoConfig.from_pretrained(
                model_id,
                revision=revision,
            )
            architecture = config.architectures[0]

            class_ = getattr(transformers, architecture)

            # Name for this variable depends on transformers version.
            discard_names = getattr(class_, "_tied_weights_keys", [])
            discard_names.extend(getattr(class_, "_keys_to_ignore_on_load_missing", []))

        except Exception:
            discard_names = []
        # Convert pytorch weights to safetensors
        utils.convert_files(local_pt_files, local_st_files, discard_names)


@app.command()
def quantize(
    model_id: str,
    output_dir: str,
    revision: Optional[str] = None,
    logger_level: str = "INFO",
    json_output: bool = False,
    trust_remote_code: bool = False,
    upload_to_model_id: Optional[str] = None,
    percdamp: float = 0.01,
    act_order: bool = False,
):
    download_weights(
        model_id=model_id,
        revision=revision,
        logger_level=logger_level,
        json_output=json_output,
    )
    from text_generation_server.utils.gptq.quantize import quantize

    quantize(
        model_id=model_id,
        bits=4,
        groupsize=128,
        output_dir=output_dir,
        trust_remote_code=trust_remote_code,
        upload_to_model_id=upload_to_model_id,
        percdamp=percdamp,
        act_order=act_order,
    )


if __name__ == "__main__":
    app()
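
# Example invocations (a sketch; the `text-generation-server` entry-point name and
# the model id are assumptions, adjust them to your installation):
#
#   # Fetch weights and, if needed, convert them to safetensors
#   text-generation-server download-weights bigscience/bloom-560m
#
#   # Serve on a single device (dtype defaults to bfloat16)
#   text-generation-server serve bigscience/bloom-560m
#
#   # Sharded serving launches DeepSpeed; WORLD_SIZE, MASTER_ADDR and MASTER_PORT must be set
#   WORLD_SIZE=8 MASTER_ADDR=localhost MASTER_PORT=29500 \
#       text-generation-server serve bigscience/bloom-560m --sharded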