diff --git a/backends/gaudi/server/text_generation_server/cli.py b/backends/gaudi/server/text_generation_server/cli.py
index d4445a13..dc31ab2f 100644
--- a/backends/gaudi/server/text_generation_server/cli.py
+++ b/backends/gaudi/server/text_generation_server/cli.py
@@ -1,6 +1,4 @@
 import os
-import psutil
-import signal
 import sys
 import typer
 
@@ -115,80 +113,19 @@ def serve(
         raise RuntimeError(
             "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
         )
-
-    logger.info("CLI SHARDED = {} DTYPE = {}".format(sharded, dtype))
-
-    if sharded and os.getenv("ATTENTION", "default") not in {"paged"}:
-        tgi_file = Path(__file__).resolve().parent / "tgi_service.py"
-        num_shard = int(os.getenv("WORLD_SIZE", "1"))
-        logger.info("CLI SHARDED = {}".format(num_shard))
-        import subprocess
-
-        cmd = (
-            f"deepspeed --num_nodes 1 --num_gpus {num_shard} --no_local_rank {tgi_file}"
-        )
-        cmd += f" --model_id {model_id} --revision {revision} --sharded {sharded}"
-        cmd += f" --dtype {dtype} --trust_remote_code {trust_remote_code} --uds_path {uds_path}"
-        cmd += f" --quantize {quantize} --max_input_tokens {max_input_tokens}"
-        if speculate is not None:
-            cmd += f"--speculate {speculate}"
-        logger.info("CLI server start deepspeed ={} ".format(cmd))
-        sys.stdout.flush()
-        sys.stderr.flush()
-        with subprocess.Popen(cmd, shell=True, executable="/bin/bash") as proc:
-            do_terminate = False
-            current_handler = signal.getsignal(signal.SIGTERM)
-
-            def terminate_handler(sig, frame):
-                nonlocal do_terminate
-                do_terminate = True
-                if callable(current_handler):
-                    current_handler(sig, frame)
-
-            signal.signal(signal.SIGTERM, terminate_handler)
-
-            finished = False
-            while not finished:
-                try:
-                    if do_terminate:
-                        parent = psutil.Process(proc.pid)
-                        all_procs = parent.children(recursive=True) + [parent]
-                        for p in all_procs:
-                            try:
-                                p.terminate()
-                            except psutil.NoSuchProcess:
-                                pass
-                        _, alive = psutil.wait_procs(all_procs, timeout=30)
-                        for p in alive:
-                            p.kill()
-
-                        do_terminate = False
-
-                    proc.wait(timeout=3)
-                except subprocess.TimeoutExpired:
-                    pass
-                else:
-                    finished = True
-
-        sys.stdout.flush()
-        sys.stderr.flush()
-        if proc.returncode != 0:
-            logger.error(f"{cmd} exited with status = {proc.returncode}")
-            return proc.returncode
-    else:
-        server.serve(
-            model_id,
-            lora_adapters,
-            revision,
-            sharded,
-            quantize,
-            speculate,
-            dtype,
-            kv_cache_dtype,
-            trust_remote_code,
-            uds_path,
-            max_input_tokens,
-        )
+    server.serve(
+        model_id,
+        lora_adapters,
+        revision,
+        sharded,
+        quantize,
+        speculate,
+        dtype,
+        kv_cache_dtype,
+        trust_remote_code,
+        uds_path,
+        max_input_tokens,
+    )
 
 
 @app.command()
diff --git a/backends/gaudi/server/text_generation_server/habana_quantization_env.py b/backends/gaudi/server/text_generation_server/habana_quantization_env.py
deleted file mode 100644
index b03b7e26..00000000
--- a/backends/gaudi/server/text_generation_server/habana_quantization_env.py
+++ /dev/null
@@ -1,53 +0,0 @@
-# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
-
-import os
-import habana_frameworks.torch as htorch
-
-quant_config = os.getenv("QUANT_CONFIG", "")
-is_quantization_enabled = quant_config != ""
-
-if is_quantization_enabled:
-    os.environ.setdefault("ENABLE_EXPERIMENTAL_FLAGS", "true")
-    os.environ.setdefault("USE_DEFAULT_QUANT_PARAM", "true")
-    os.environ.setdefault("UPDATE_GRAPH_OUTPUT_MME", "false")
-    os.environ.setdefault("ENABLE_CALC_DYNAMIC_RANGE", "false")
-    os.environ.setdefault("UPDATE_MME_OUTPUT_PRECISION_FILTER", "v_proj,matmul_av")
-    os.environ.setdefault("EXPERIMENTAL_WEIGHT_SHARING", "FALSE")
-
-
-def patch_scoped_linear_all_reduce(model):
-    from deepspeed.module_inject.layers import LinearAllreduce
-    from optimum.habana.transformers.models.modeling_all_models import (
-        ScopedLinearAllReduce,
-    )
-
-    for name, module in model.named_children():
-        if type(module) is LinearAllreduce:
-            SL = ScopedLinearAllReduce(mod=module)
-            setattr(model, name, SL)
-        patch_scoped_linear_all_reduce(module)
-
-
-def setup_quantization(model):
-    if is_quantization_enabled:
-        htorch.core.quantization._mark_params_as_const(model)
-        htorch.core.quantization._check_params_as_const(model)
-        htorch.core.hpu_initialize(model)
-    return model
-
-
-def prepare_model_for_quantization(model):
-    if is_quantization_enabled:
-        if model.config.model_type in [
-            "llama",
-            "falcon",
-            "qwen2",
-            "starcoder2",
-            "gemma",
-        ]:
-            patch_scoped_linear_all_reduce(model)
-        from neural_compressor.torch.quantization import FP8Config, convert
-
-        config = FP8Config.from_json_file(quant_config)
-        model = convert(model, config)
-    return model
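For reference, the deleted habana_quantization_env.py gated everything on a QUANT_CONFIG environment variable and, when set, converted the model to FP8 with Intel Neural Compressor. A condensed sketch of that removed flow, using the names from the deleted file (the environment-flag setup and the ScopedLinearAllReduce patching are omitted here):

    import os
    from neural_compressor.torch.quantization import FP8Config, convert

    # Path to the FP8 quantization JSON; an empty string disables quantization.
    quant_config = os.getenv("QUANT_CONFIG", "")
    is_quantization_enabled = quant_config != ""

    def prepare_model_for_quantization(model):
        # Only convert when a QUANT_CONFIG file was provided.
        if is_quantization_enabled:
            config = FP8Config.from_json_file(quant_config)
            model = convert(model, config)
        return model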