Add Fp8 support (#42) (#71)

Co-authored-by: mrs303 <54661797+mrs303@users.noreply.github.com>
Co-authored-by: Adam Stachowicz <105052242+astachowiczhabana@users.noreply.github.com>
Co-authored-by: Grzegorz Morys <gmorys@habana.ai>
Authored by jkaniecki on 2024-02-23 11:52:28 +01:00; committed by GitHub
parent a490847702
commit c3bd8ef445
3 changed files with 57 additions and 2 deletions

server/text_generation_server/habana_quantization_env.py (new file)

@@ -0,0 +1,15 @@
+import os
+import sys
+
+assert "habana_frameworks" not in sys.modules
+
+is_quantization_enabled = os.getenv("QUANT_CONFIG", "") != ""
+
+if is_quantization_enabled:
+    os.environ.setdefault("ENABLE_EXPERIMENTAL_FLAGS", "true")
+    os.environ.setdefault("USE_DEFAULT_QUANT_PARAM", "true")
+    os.environ.setdefault("UPDATE_GRAPH_OUTPUT_MME", "false")
+    os.environ.setdefault("ENABLE_CALC_DYNAMIC_RANGE", "false")
+    os.environ.setdefault(
+        "UPDATE_MME_OUTPUT_PRECISION_FILTER", "v_proj,matmul_av")
+    os.environ.setdefault("EXPERIMENTAL_WEIGHT_SHARING", "FALSE")

server/text_generation_server/models/causal_lm.py

@@ -4,15 +4,16 @@ import itertools
 import time
 import glob
-from text_generation_server.utils.tokens import batch_top_tokens
 import torch
 from dataclasses import dataclass
 from opentelemetry import trace
 from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase, AutoConfig
 from typing import Optional, Tuple, List, Type, Dict
-from habana_frameworks.torch.hpu import wrap_in_hpu_graph
+import text_generation_server.habana_quantization_env as hq_env
 import habana_frameworks.torch as htorch
+from habana_frameworks.torch.hpu import wrap_in_hpu_graph
 from contextlib import nullcontext
 from optimum.habana.utils import HabanaProfile, to_gb_rounded
@@ -23,6 +24,7 @@ from optimum.habana.checkpoint_utils import (
     write_checkpoints_json,
 )
+from text_generation_server.utils.tokens import batch_top_tokens
 from text_generation_server.models import Model
 from text_generation_server.models.types import (
     Batch,
@@ -491,6 +493,8 @@ class CausalLM(Model):
         dtype: Optional[torch.dtype] = None,
     ):
         device = torch.device("hpu")
+        if hq_env.is_quantization_enabled:
+            htorch.core.hpu_set_env()
         dtype = torch.bfloat16 if dtype is None else dtype
@@ -555,6 +559,7 @@ class CausalLM(Model):
                 ds_inference_kwargs["checkpoint"] = checkpoints_json.name
             model = deepspeed.init_inference(model, **ds_inference_kwargs)
             model = model.module
+            model = self.prepare_model_for_quantization(model)
             model = remove_kv_cache_from_output(model)
             if self.enable_hpu_graph:
                 model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
@@ -566,11 +571,13 @@
                 revision=revision,
                 torch_dtype=dtype,
             )
+            model = self.prepare_model_for_quantization(model)
             model = model.eval().to(device)
             # wrap in hpu_graph only if self.enable_hpu_graph is set
             model = remove_kv_cache_from_output(model)
             if self.enable_hpu_graph:
                 model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
+            model = self.setup_quantization(model)

         if model.config.model_type in MODELS_OPTIMIZED_WITH_STATIC_SHAPES:
             self.is_optimized_for_gaudi = True
@@ -621,6 +628,36 @@ class CausalLM(Model):
         self.hb_profer_started = False
         self.step = 0

+    def setup_quantization(self, model):
+        if hq_env.is_quantization_enabled:
+            htorch.core.quantization._mark_params_as_const(model)
+            htorch.core.quantization._check_params_as_const(model)
+            htorch.core.hpu_initialize(model)
+        return model
+
+    def prepare_model_for_quantization(self, model):
+        if hq_env.is_quantization_enabled:
+            if model.config.model_type == "llama":
+                self.patch_scoped_linear_all_reduce(model)
+            import habana_quantization_toolkit
+            habana_quantization_toolkit.prep_model(model)
+        return model
+
+    def finish_quantization_measurements(self, model):
+        if hq_env.is_quantization_enabled:
+            import habana_quantization_toolkit
+            habana_quantization_toolkit.finish_measurements(self.model)
+        return model
+
+    def patch_scoped_linear_all_reduce(self, model):
+        from deepspeed.module_inject.layers import LinearAllreduce
+        from optimum.habana.transformers.models.modeling_all_models import ScopedLinearAllReduce
+        for name, module in model.named_children():
+            if type(module) is LinearAllreduce:
+                SL = ScopedLinearAllReduce(mod=module)
+                setattr(model, name, SL)
+            self.patch_scoped_linear_all_reduce(module)
+
     @property
     def batch_type(self) -> Type[CausalLMBatch]:
         return CausalLMBatch
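
For orientation (not part of the diff): the new hooks are applied in a fixed order around HPU graph wrapping, and all of them are no-ops unless QUANT_CONFIG is set. A rough lifecycle sketch using only the methods added above; the driver function itself is hypothetical.

from habana_frameworks.torch.hpu import wrap_in_hpu_graph

# Hypothetical driver showing the order in which CausalLM uses the new hooks.
def fp8_flow(causal_lm, model):
    # 1. Instrument the model (and, for llama, swap LinearAllreduce for
    #    ScopedLinearAllReduce) before any HPU graph capture.
    model = causal_lm.prepare_model_for_quantization(model)

    # 2. CausalLM.__init__ then wraps the model in HPU graphs when enabled.
    if causal_lm.enable_hpu_graph:
        model = wrap_in_hpu_graph(model, disable_tensor_cache=True)

    # 3. Mark parameters as const and run hpu_initialize on the wrapped model.
    model = causal_lm.setup_quantization(model)

    # 4. ... serve requests ...

    # 5. On shutdown the server calls finish_quantization_measurements so that
    #    statistics gathered under a measurement config can be finalized.
    causal_lm.finish_quantization_measurements(model)
    return model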

server/text_generation_server/server.py

@@ -204,6 +204,9 @@ def serve(
         except KeyboardInterrupt:
             logger.info("Signal received. Shutting down")
             await server.stop(0)
+        finally:
+            if hasattr(model, 'finish_quantization_measurements'):
+                model.finish_quantization_measurements()

     logger.info(
         "Starting Server : model_id= {}, revision = {} dtype = {} sharded = {} ".format(