From 46b14e6b286959324a7d135ef9320287bfc0ca07 Mon Sep 17 00:00:00 2001
From: Thanaji Rao Thakkalapelli
Date: Fri, 18 Oct 2024 01:59:59 -0700
Subject: [PATCH] Remove all references to habana_quantization_toolkit for 1.18 (#229)

---
 README.md                      |  2 +-
 .../habana_quantization_env.py | 38 ++++++++++++-----
 .../models/causal_lm.py        | 29 ++-----------
 .../models/vlm_causal_lm.py    | 42 +++----------------
 4 files changed, 38 insertions(+), 73 deletions(-)

diff --git a/README.md b/README.md
index 38b3aa5a..2fa8836b 100644
--- a/README.md
+++ b/README.md
@@ -285,7 +285,7 @@ curl -N 127.0.0.1:8080/generate_stream \
 
 ## Running TGI with FP8 Precision
 
-TGI-Gaudi supports FP8 precision inference with INC (Intel Neural Compressor) and HQT (Habana Quantization Toolkit). FP8 inference can be run by setting QUANT_CONFIG environment variable in the docker command. From TGI-Gaudi 2.0.4 release, INC is used by default for quantization. HQT will be removed in future releases. To use HQT, disable INC by setting `-e USE_INC=0` in docker command.
+TGI-Gaudi supports FP8 precision inference with [Intel Neural Compressor (INC)](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_FP8.html). FP8 inference can be run by setting QUANT_CONFIG environment variable in the docker command.
 
 To run FP8 Inference:
diff --git a/server/text_generation_server/habana_quantization_env.py b/server/text_generation_server/habana_quantization_env.py
index 3c06fd09..e942fdcf 100644
--- a/server/text_generation_server/habana_quantization_env.py
+++ b/server/text_generation_server/habana_quantization_env.py
@@ -1,6 +1,7 @@
 # Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
 
 import os
+import habana_frameworks.torch as htorch
 
 quant_config = os.getenv("QUANT_CONFIG", "")
 is_quantization_enabled = quant_config != ""
@@ -10,18 +11,35 @@ if is_quantization_enabled:
     os.environ.setdefault("USE_DEFAULT_QUANT_PARAM", "true")
     os.environ.setdefault("UPDATE_GRAPH_OUTPUT_MME", "false")
     os.environ.setdefault("ENABLE_CALC_DYNAMIC_RANGE", "false")
-    os.environ.setdefault(
-        "UPDATE_MME_OUTPUT_PRECISION_FILTER", "v_proj,matmul_av")
+    os.environ.setdefault("UPDATE_MME_OUTPUT_PRECISION_FILTER", "v_proj,matmul_av")
     os.environ.setdefault("EXPERIMENTAL_WEIGHT_SHARING", "FALSE")
 
 
+def patch_scoped_linear_all_reduce(model):
+    from deepspeed.module_inject.layers import LinearAllreduce
+    from optimum.habana.transformers.models.modeling_all_models import ScopedLinearAllReduce
+
+    for name, module in model.named_children():
+        if type(module) is LinearAllreduce:
+            SL = ScopedLinearAllReduce(mod=module)
+            setattr(model, name, SL)
+        patch_scoped_linear_all_reduce(module)
+
+
+def setup_quantization(model):
+    if is_quantization_enabled:
+        htorch.core.quantization._mark_params_as_const(model)
+        htorch.core.quantization._check_params_as_const(model)
+        htorch.core.hpu_initialize(model)
+    return model
+
+
 def prepare_model_for_quantization(model):
     if is_quantization_enabled:
-        if os.getenv("USE_INC", "1") != "0":
-            from neural_compressor.torch.quantization import FP8Config, convert
-            config = FP8Config.from_json_file(quant_config)
-            model = convert(model, config)
-        else:
-            import habana_quantization_toolkit
-            habana_quantization_toolkit.prep_model(model)
-    return model
+        if model.config.model_type in ["llama", "falcon", "qwen2", "starcoder2", "gemma"]:
+            patch_scoped_linear_all_reduce(model)
+        from neural_compressor.torch.quantization import FP8Config, convert
+
+        config = FP8Config.from_json_file(quant_config)
+        model = convert(model, config)
+    return model
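For orientation, here is a minimal usage sketch (not part of the patch) of the helpers that now live at module level in habana_quantization_env.py. Everything is gated on the QUANT_CONFIG environment variable, so the calls below are no-ops when it is unset; the checkpoint name is a placeholder, and the snippet assumes the TGI-Gaudi server package plus a Gaudi/HPU software stack are installed.

# Sketch only -- placeholder checkpoint, requires the server package and an HPU stack.
import os
import text_generation_server.habana_quantization_env as hq_env
from transformers import AutoModelForCausalLM

print(os.getenv("QUANT_CONFIG", ""))   # path to an INC FP8 JSON config, or "" (disabled)
print(hq_env.is_quantization_enabled)  # True only when QUANT_CONFIG is non-empty

model = AutoModelForCausalLM.from_pretrained("my-org/my-model")  # placeholder model id
# With QUANT_CONFIG set: patches LinearAllreduce layers for the listed decoder
# families, then runs INC's FP8 convert. With it unset: returns the model untouched.
model = hq_env.prepare_model_for_quantization(model)
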
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 3ef04cc9..7ba959ee 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -629,7 +629,7 @@ class CausalLM(Model):
             model = self.get_deepspeed_model(
                 model_id, dtype, revision
             )
-            model = self.prepare_model_for_quantization(model)
+            model = hq_env.prepare_model_for_quantization(model)
         else:
             get_repo_root(model_id)
 
@@ -648,7 +648,7 @@ class CausalLM(Model):
                 trust_remote_code=trust_remote_code,
                 **model_kwargs
             )
-            model = self.prepare_model_for_quantization(model)
+            model = hq_env.prepare_model_for_quantization(model)
 
         model = model.eval().to(device)
         self.enable_hpu_graph = os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true" and LAZY_MODE == 1
@@ -667,7 +667,7 @@ class CausalLM(Model):
                     "TORCH COMPILE", f'Torch compiling of model')
                 model.model = torch.compile(model.model, backend="hpu_backend", options={"keep_input_mutations": True})
 
-        model = self.setup_quantization(model)
+        model = hq_env.setup_quantization(model)
 
         if model.config.model_type not in MODELS_OPTIMIZED_WITH_STATIC_SHAPES:
             raise ValueError(f"Model type {model.config.model_type} is not supported!")
@@ -799,29 +799,6 @@ class CausalLM(Model):
                 'type': rope_scaling, 'factor': float(rope_factor)
             }
 
-    def setup_quantization(self, model):
-        if hq_env.is_quantization_enabled:
-            htorch.core.quantization._mark_params_as_const(model)
-            htorch.core.quantization._check_params_as_const(model)
-            htorch.core.hpu_initialize(model)
-        return model
-
-    def prepare_model_for_quantization(self, model):
-        if hq_env.is_quantization_enabled:
-            if model.config.model_type == "llama":
-                self.patch_scoped_linear_all_reduce(model)
-            model = hq_env.prepare_model_for_quantization(model)
-        return model
-
-    def patch_scoped_linear_all_reduce(self, model):
-        from deepspeed.module_inject.layers import LinearAllreduce
-        from optimum.habana.transformers.models.modeling_all_models import ScopedLinearAllReduce
-        for name, module in model.named_children():
-            if type(module) is LinearAllreduce:
-                SL = ScopedLinearAllReduce(mod=module)
-                setattr(model, name, SL)
-            self.patch_scoped_linear_all_reduce(module)
-
     @property
     def batch_type(self) -> Type[CausalLMBatch]:
         return CausalLMBatch
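The methods deleted above are replaced by the shared module-level helpers shown earlier. The recursive child-swap pattern they rely on is illustrated below with plain torch.nn modules so it runs without DeepSpeed or optimum-habana; Target and Wrapped are stand-ins for LinearAllreduce and ScopedLinearAllReduce.

# Illustration of the recursive child-swap pattern; Target/Wrapped are stand-ins,
# not the real DeepSpeed / optimum-habana classes.
import torch.nn as nn


class Target(nn.Linear):
    """Stand-in for deepspeed's LinearAllreduce."""


class Wrapped(nn.Module):
    """Stand-in for optimum-habana's ScopedLinearAllReduce."""

    def __init__(self, mod):
        super().__init__()
        self.mod = mod

    def forward(self, x):
        return self.mod(x)


def patch_children(model):
    for name, module in model.named_children():
        if type(module) is Target:
            setattr(model, name, Wrapped(module))  # swap the child in place
        patch_children(module)                     # recurse, patched or not


m = nn.Sequential(Target(4, 4), nn.Sequential(Target(4, 4), nn.ReLU()))
patch_children(m)
print(m)  # both Target layers now appear as Wrapped(...)
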
diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py
index 93cb179a..55dc7c07 100644
--- a/server/text_generation_server/models/vlm_causal_lm.py
+++ b/server/text_generation_server/models/vlm_causal_lm.py
@@ -193,7 +193,7 @@ class VlmCausalLMBatch(CausalLMBatch):
         device: torch.device,
         is_warmup: bool = False,
     ) -> "VlmCausalLMBatch":
-
+
         dbg_trace('FROM_PB', f'num_reqs:{len(pb.requests)}')
         requests = [CausalLMRequest.from_pb(idx, req, tokenizer) for idx, req in enumerate(pb.requests)]
 
@@ -536,7 +536,7 @@ class VlmCausalLM(Model):
             model = self.get_deepspeed_model(
                 model_class, model_id, dtype, revision
             )
-            model = self.prepare_model_for_quantization(model)
+            model = hq_env.prepare_model_for_quantization(model)
         else:
             get_repo_root(model_id)
 
@@ -555,7 +555,7 @@ class VlmCausalLM(Model):
                 trust_remote_code=trust_remote_code,
                 **model_kwargs
             )
-            model = self.prepare_model_for_quantization(model)
+            model = hq_env.prepare_model_for_quantization(model)
 
         model = model.eval().to(device)
         self.enable_hpu_graph = os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true" and LAZY_MODE == 1
@@ -565,13 +565,13 @@ class VlmCausalLM(Model):
             from habana_frameworks.torch.hpu import wrap_in_hpu_graph
             model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
         else:
-            if LAZY_MODE == 0:
+            if LAZY_MODE == 0: # It is said that "keep_input_mutations" is safe for inference to be done
                 dbg_trace(
                     "TORCH COMPILE", f'Torch compiling of model')
                 model.model = torch.compile(model.model, backend="hpu_backend", options={"keep_input_mutations": True})
 
 
-        model = self.setup_quantization(model)
+        model = hq_env.setup_quantization(model)
 
         if model.config.model_type not in MODELS_OPTIMIZED_WITH_STATIC_SHAPES:
             raise ValueError(f"Model type {model.config.model_type} is not supported!")
@@ -703,36 +703,6 @@ class VlmCausalLM(Model):
                 'type': rope_scaling, 'factor': float(rope_factor)
             }
 
-    def setup_quantization(self, model):
-        if hq_env.is_quantization_enabled:
-            htorch.core.quantization._mark_params_as_const(model)
-            htorch.core.quantization._check_params_as_const(model)
-            htorch.core.hpu_initialize(model)
-        return model
-
-    def prepare_model_for_quantization(self, model):
-        if hq_env.is_quantization_enabled:
-            if model.config.model_type == "llama":
-                self.patch_scoped_linear_all_reduce(model)
-            import habana_quantization_toolkit
-            habana_quantization_toolkit.prep_model(model)
-        return model
-
-    def finish_quantization_measurements(self, model):
-        if hq_env.is_quantization_enabled:
-            import habana_quantization_toolkit
-            habana_quantization_toolkit.finish_measurements(self.model)
-        return model
-
-    def patch_scoped_linear_all_reduce(self, model):
-        from deepspeed.module_inject.layers import LinearAllreduce
-        from optimum.habana.transformers.models.modeling_all_models import ScopedLinearAllReduce
-        for name, module in model.named_children():
-            if type(module) is LinearAllreduce:
-                SL = ScopedLinearAllReduce(mod=module)
-                setattr(model, name, SL)
-            self.patch_scoped_linear_all_reduce(module)
-
     def decode(self, generated_ids: List[int]) -> str:
         return self.tokenizer.decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
 
@@ -906,7 +876,7 @@ class VlmCausalLM(Model):
                 bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None,
             )
         elif all([req.stopping_criteria.max_new_tokens == 1 for req in batch.requests]):
-            # Don't schedule next forward if max_new_tokens for all requests equals 1
+            # Don't schedule next forward if max_new_tokens for all requests equals 1
             # - we've already generated the first and only needed token in the prefill phase
             pass
         else:
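Taken together, the call sites in causal_lm.py and vlm_causal_lm.py now follow one ordering: INC conversion via hq_env.prepare_model_for_quantization, then eval, then HPU-graph wrapping or torch.compile, and finally hq_env.setup_quantization. The sketch below restates that ordering under simplifying assumptions: finalize_model is a hypothetical helper rather than a function from the repository, LAZY_MODE is hard-coded, and device placement is omitted.

# Sketch only: finalize_model is hypothetical; LAZY_MODE and device handling are simplified.
import os
import torch
import text_generation_server.habana_quantization_env as hq_env

LAZY_MODE = 1  # stand-in; the server derives this from its own configuration


def finalize_model(model):
    model = hq_env.prepare_model_for_quantization(model)  # INC FP8 convert, or no-op
    model = model.eval()

    enable_hpu_graph = os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true" and LAZY_MODE == 1
    if enable_hpu_graph:
        from habana_frameworks.torch.hpu import wrap_in_hpu_graph
        model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
    elif LAZY_MODE == 0:
        model.model = torch.compile(model.model, backend="hpu_backend",
                                    options={"keep_input_mutations": True})

    # Final step in both files: mark params as const and run hpu_initialize.
    return hq_env.setup_quantization(model)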