From c3bd8ef4452386a1b12c740838b9d35e8539ea0d Mon Sep 17 00:00:00 2001
From: jkaniecki <153085639+jkaniecki@users.noreply.github.com>
Date: Fri, 23 Feb 2024 11:52:28 +0100
Subject: [PATCH] Add Fp8 support (#42) (#71)

Co-authored-by: mrs303 <54661797+mrs303@users.noreply.github.com>
Co-authored-by: Adam Stachowicz <105052242+astachowiczhabana@users.noreply.github.com>
Co-authored-by: Grzegorz Morys
---
 .../habana_quantization_env.py           | 15 +++++++
 .../models/causal_lm.py                   | 41 ++++++++++++++++++-
 server/text_generation_server/server.py   |  3 ++
 3 files changed, 57 insertions(+), 2 deletions(-)
 create mode 100644 server/text_generation_server/habana_quantization_env.py

diff --git a/server/text_generation_server/habana_quantization_env.py b/server/text_generation_server/habana_quantization_env.py
new file mode 100644
index 00000000..351f88a3
--- /dev/null
+++ b/server/text_generation_server/habana_quantization_env.py
@@ -0,0 +1,15 @@
+import os
+import sys
+
+assert "habana_frameworks" not in sys.modules
+
+is_quantization_enabled = os.getenv("QUANT_CONFIG", "") != ""
+
+if is_quantization_enabled:
+    os.environ.setdefault("ENABLE_EXPERIMENTAL_FLAGS", "true")
+    os.environ.setdefault("USE_DEFAULT_QUANT_PARAM", "true")
+    os.environ.setdefault("UPDATE_GRAPH_OUTPUT_MME", "false")
+    os.environ.setdefault("ENABLE_CALC_DYNAMIC_RANGE", "false")
+    os.environ.setdefault(
+        "UPDATE_MME_OUTPUT_PRECISION_FILTER", "v_proj,matmul_av")
+    os.environ.setdefault("EXPERIMENTAL_WEIGHT_SHARING", "FALSE")
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index a0b1bd82..42c9356f 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -4,15 +4,16 @@
 import itertools
 import time
 import glob
 
-from text_generation_server.utils.tokens import batch_top_tokens
 import torch
 from dataclasses import dataclass
 from opentelemetry import trace
 from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase, AutoConfig
 from typing import Optional, Tuple, List, Type, Dict
-from habana_frameworks.torch.hpu import wrap_in_hpu_graph
+
+import text_generation_server.habana_quantization_env as hq_env
 import habana_frameworks.torch as htorch
+from habana_frameworks.torch.hpu import wrap_in_hpu_graph
 from contextlib import nullcontext
 from optimum.habana.utils import HabanaProfile, to_gb_rounded
 
@@ -23,6 +24,7 @@ from optimum.habana.checkpoint_utils import (
     write_checkpoints_json,
 )
 
+from text_generation_server.utils.tokens import batch_top_tokens
 from text_generation_server.models import Model
 from text_generation_server.models.types import (
     Batch,
@@ -491,6 +493,8 @@ class CausalLM(Model):
         dtype: Optional[torch.dtype] = None,
     ):
         device = torch.device("hpu")
+        if hq_env.is_quantization_enabled:
+            htorch.core.hpu_set_env()
 
         dtype = torch.bfloat16 if dtype is None else dtype
 
@@ -555,6 +559,7 @@ class CausalLM(Model):
             ds_inference_kwargs["checkpoint"] = checkpoints_json.name
             model = deepspeed.init_inference(model, **ds_inference_kwargs)
             model = model.module
+            model = self.prepare_model_for_quantization(model)
             model = remove_kv_cache_from_output(model)
             if self.enable_hpu_graph:
                 model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
@@ -566,11 +571,13 @@ class CausalLM(Model):
                 revision=revision,
                 torch_dtype=dtype,
             )
+            model = self.prepare_model_for_quantization(model)
             model = model.eval().to(device)
             # wrap in hpu_graph only if self.enable_hpu_graph is set
             model = remove_kv_cache_from_output(model)
             if self.enable_hpu_graph:
                 model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
+        model = self.setup_quantization(model)
 
         if model.config.model_type in MODELS_OPTIMIZED_WITH_STATIC_SHAPES:
             self.is_optimized_for_gaudi = True
@@ -621,6 +628,36 @@ class CausalLM(Model):
         self.hb_profer_started = False
         self.step = 0
 
+    def setup_quantization(self, model):
+        if hq_env.is_quantization_enabled:
+            htorch.core.quantization._mark_params_as_const(model)
+            htorch.core.quantization._check_params_as_const(model)
+            htorch.core.hpu_initialize(model)
+        return model
+
+    def prepare_model_for_quantization(self, model):
+        if hq_env.is_quantization_enabled:
+            if model.config.model_type == "llama":
+                self.patch_scoped_linear_all_reduce(model)
+            import habana_quantization_toolkit
+            habana_quantization_toolkit.prep_model(model)
+        return model
+
+    def finish_quantization_measurements(self, model):
+        if hq_env.is_quantization_enabled:
+            import habana_quantization_toolkit
+            habana_quantization_toolkit.finish_measurements(self.model)
+        return model
+
+    def patch_scoped_linear_all_reduce(self, model):
+        from deepspeed.module_inject.layers import LinearAllreduce
+        from optimum.habana.transformers.models.modeling_all_models import ScopedLinearAllReduce
+        for name, module in model.named_children():
+            if type(module) is LinearAllreduce:
+                SL = ScopedLinearAllReduce(mod=module)
+                setattr(model, name, SL)
+            self.patch_scoped_linear_all_reduce(module)
+
     @property
     def batch_type(self) -> Type[CausalLMBatch]:
         return CausalLMBatch
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index 41676486..2f41ec94 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -204,6 +204,9 @@ def serve(
         except KeyboardInterrupt:
             logger.info("Signal received. Shutting down")
             await server.stop(0)
+        finally:
+            if hasattr(model,'finish_quantization_measurements'):
+                model.finish_quantization_measurements()
 
     logger.info(
         "Starting Server : model_id= {}, revision = {} dtype = {} sharded = {} ".format(
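
Note (not part of the patch): the new habana_quantization_env.py gates everything on the QUANT_CONFIG environment variable and exports its Habana flags with os.environ.setdefault, so values a user has already exported win over the module's defaults. The short Python sketch below restates just that gating pattern in isolation so it can be run without the server package or Habana hardware; the config path, the override value, and the reduced flag set are hypothetical examples, not values taken from the patch.

import os

# Simulate a launch that requests FP8 measurement/quantization.
# The config path is a hypothetical example, not a value from the patch.
os.environ["QUANT_CONFIG"] = "/tmp/maxabs_quant.json"
# A value the user exported before the module runs survives setdefault.
os.environ["EXPERIMENTAL_WEIGHT_SHARING"] = "TRUE"

# Same gating logic as habana_quantization_env.py, restated here so the
# sketch runs standalone (only a subset of the flags, for brevity).
is_quantization_enabled = os.getenv("QUANT_CONFIG", "") != ""

if is_quantization_enabled:
    os.environ.setdefault("ENABLE_EXPERIMENTAL_FLAGS", "true")
    os.environ.setdefault("USE_DEFAULT_QUANT_PARAM", "true")
    os.environ.setdefault("EXPERIMENTAL_WEIGHT_SHARING", "FALSE")

print(is_quantization_enabled)                      # True
print(os.environ["ENABLE_EXPERIMENTAL_FLAGS"])      # true  (default applied)
print(os.environ["EXPERIMENTAL_WEIGHT_SHARING"])    # TRUE  (user override kept)

In the patch itself this flag handling runs once at import time, before habana_frameworks is loaded (the module asserts this), and the rest of the FP8 flow follows from it: prepare_model_for_quantization before the model is wrapped in HPU graphs, setup_quantization afterwards, and finish_quantization_measurements when the server shuts down.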