Add Fp8 support (#42) (#71)

Co-authored-by: mrs303 <54661797+mrs303@users.noreply.github.com>
Co-authored-by: Adam Stachowicz <105052242+astachowiczhabana@users.noreply.github.com>
Co-authored-by: Grzegorz Morys <gmorys@habana.ai>
This commit is contained in:
jkaniecki 2024-02-23 11:52:28 +01:00 committed by GitHub
parent a490847702
commit c3bd8ef445
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 57 additions and 2 deletions

View File

@ -0,0 +1,15 @@
import os
import sys
assert "habana_frameworks" not in sys.modules
is_quantization_enabled = os.getenv("QUANT_CONFIG", "") != ""
if is_quantization_enabled:
os.environ.setdefault("ENABLE_EXPERIMENTAL_FLAGS", "true")
os.environ.setdefault("USE_DEFAULT_QUANT_PARAM", "true")
os.environ.setdefault("UPDATE_GRAPH_OUTPUT_MME", "false")
os.environ.setdefault("ENABLE_CALC_DYNAMIC_RANGE", "false")
os.environ.setdefault(
"UPDATE_MME_OUTPUT_PRECISION_FILTER", "v_proj,matmul_av")
os.environ.setdefault("EXPERIMENTAL_WEIGHT_SHARING", "FALSE")

View File

@ -4,15 +4,16 @@ import itertools
import time
import glob
from text_generation_server.utils.tokens import batch_top_tokens
import torch
from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase, AutoConfig
from typing import Optional, Tuple, List, Type, Dict
from habana_frameworks.torch.hpu import wrap_in_hpu_graph
import text_generation_server.habana_quantization_env as hq_env
import habana_frameworks.torch as htorch
from habana_frameworks.torch.hpu import wrap_in_hpu_graph
from contextlib import nullcontext
from optimum.habana.utils import HabanaProfile, to_gb_rounded
@ -23,6 +24,7 @@ from optimum.habana.checkpoint_utils import (
write_checkpoints_json,
)
from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.models import Model
from text_generation_server.models.types import (
Batch,
@ -491,6 +493,8 @@ class CausalLM(Model):
dtype: Optional[torch.dtype] = None,
):
device = torch.device("hpu")
if hq_env.is_quantization_enabled:
htorch.core.hpu_set_env()
dtype = torch.bfloat16 if dtype is None else dtype
@ -555,6 +559,7 @@ class CausalLM(Model):
ds_inference_kwargs["checkpoint"] = checkpoints_json.name
model = deepspeed.init_inference(model, **ds_inference_kwargs)
model = model.module
model = self.prepare_model_for_quantization(model)
model = remove_kv_cache_from_output(model)
if self.enable_hpu_graph:
model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
@ -566,11 +571,13 @@ class CausalLM(Model):
revision=revision,
torch_dtype=dtype,
)
model = self.prepare_model_for_quantization(model)
model = model.eval().to(device)
# wrap in hpu_graph only if self.enable_hpu_graph is set
model = remove_kv_cache_from_output(model)
if self.enable_hpu_graph:
model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
model = self.setup_quantization(model)
if model.config.model_type in MODELS_OPTIMIZED_WITH_STATIC_SHAPES:
self.is_optimized_for_gaudi = True
@ -621,6 +628,36 @@ class CausalLM(Model):
self.hb_profer_started = False
self.step = 0
def setup_quantization(self, model):
if hq_env.is_quantization_enabled:
htorch.core.quantization._mark_params_as_const(model)
htorch.core.quantization._check_params_as_const(model)
htorch.core.hpu_initialize(model)
return model
def prepare_model_for_quantization(self, model):
if hq_env.is_quantization_enabled:
if model.config.model_type == "llama":
self.patch_scoped_linear_all_reduce(model)
import habana_quantization_toolkit
habana_quantization_toolkit.prep_model(model)
return model
def finish_quantization_measurements(self, model):
if hq_env.is_quantization_enabled:
import habana_quantization_toolkit
habana_quantization_toolkit.finish_measurements(self.model)
return model
def patch_scoped_linear_all_reduce(self, model):
from deepspeed.module_inject.layers import LinearAllreduce
from optimum.habana.transformers.models.modeling_all_models import ScopedLinearAllReduce
for name, module in model.named_children():
if type(module) is LinearAllreduce:
SL = ScopedLinearAllReduce(mod=module)
setattr(model, name, SL)
self.patch_scoped_linear_all_reduce(module)
@property
def batch_type(self) -> Type[CausalLMBatch]:
return CausalLMBatch

View File

@ -204,6 +204,9 @@ def serve(
except KeyboardInterrupt:
logger.info("Signal received. Shutting down")
await server.stop(0)
finally:
if hasattr(model,'finish_quantization_measurements'):
model.finish_quantization_measurements()
logger.info(
"Starting Server : model_id= {}, revision = {} dtype = {} sharded = {} ".format(