Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-04-22 15:32:08 +00:00
Co-authored-by: mrs303 <54661797+mrs303@users.noreply.github.com>
Co-authored-by: Adam Stachowicz <105052242+astachowiczhabana@users.noreply.github.com>
Co-authored-by: Grzegorz Morys <gmorys@habana.ai>

This commit is contained in:
parent a490847702
commit c3bd8ef445
server/text_generation_server/habana_quantization_env.py (new file, 15 lines)
@@ -0,0 +1,15 @@
import os
import sys

assert "habana_frameworks" not in sys.modules

is_quantization_enabled = os.getenv("QUANT_CONFIG", "") != ""

if is_quantization_enabled:
    os.environ.setdefault("ENABLE_EXPERIMENTAL_FLAGS", "true")
    os.environ.setdefault("USE_DEFAULT_QUANT_PARAM", "true")
    os.environ.setdefault("UPDATE_GRAPH_OUTPUT_MME", "false")
    os.environ.setdefault("ENABLE_CALC_DYNAMIC_RANGE", "false")
    os.environ.setdefault(
        "UPDATE_MME_OUTPUT_PRECISION_FILTER", "v_proj,matmul_av")
    os.environ.setdefault("EXPERIMENTAL_WEIGHT_SHARING", "FALSE")
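Note on usage: the assert above only holds when this module is imported before anything that pulls in habana_frameworks, and the flags are only set when QUANT_CONFIG is non-empty. A minimal sketch of that ordering (the config path below is a placeholder, not part of this commit):

import os
os.environ["QUANT_CONFIG"] = "/path/to/quant_config.json"  # placeholder; any non-empty value enables the flags

# Must be imported before the first habana_frameworks import, or the assert fires.
import text_generation_server.habana_quantization_env as hq_env
import habana_frameworks.torch as htorch  # safe to import afterwards

print(hq_env.is_quantization_enabled)  # True whenever QUANT_CONFIG is set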
@@ -4,15 +4,16 @@ import itertools
import time
import glob

from text_generation_server.utils.tokens import batch_top_tokens
import torch

from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase, AutoConfig
from typing import Optional, Tuple, List, Type, Dict
from habana_frameworks.torch.hpu import wrap_in_hpu_graph

import text_generation_server.habana_quantization_env as hq_env
import habana_frameworks.torch as htorch
from habana_frameworks.torch.hpu import wrap_in_hpu_graph
from contextlib import nullcontext
from optimum.habana.utils import HabanaProfile, to_gb_rounded

@@ -23,6 +24,7 @@ from optimum.habana.checkpoint_utils import (
    write_checkpoints_json,
)

from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.models import Model
from text_generation_server.models.types import (
    Batch,
@@ -491,6 +493,8 @@ class CausalLM(Model):
        dtype: Optional[torch.dtype] = None,
    ):
        device = torch.device("hpu")
        if hq_env.is_quantization_enabled:
            htorch.core.hpu_set_env()

        dtype = torch.bfloat16 if dtype is None else dtype

@@ -555,6 +559,7 @@ class CausalLM(Model):
            ds_inference_kwargs["checkpoint"] = checkpoints_json.name
            model = deepspeed.init_inference(model, **ds_inference_kwargs)
            model = model.module
            model = self.prepare_model_for_quantization(model)
            model = remove_kv_cache_from_output(model)
            if self.enable_hpu_graph:
                model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
@@ -566,11 +571,13 @@ class CausalLM(Model):
                revision=revision,
                torch_dtype=dtype,
            )
            model = self.prepare_model_for_quantization(model)
            model = model.eval().to(device)
            # wrap in hpu_graph only if self.enable_hpu_graph is set
            model = remove_kv_cache_from_output(model)
            if self.enable_hpu_graph:
                model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
            model = self.setup_quantization(model)

        if model.config.model_type in MODELS_OPTIMIZED_WITH_STATIC_SHAPES:
            self.is_optimized_for_gaudi = True
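The ordering in this hunk is the point to notice: prepare_model_for_quantization runs before the model is wrapped in HPU graphs, and setup_quantization runs only after wrapping. A compressed illustration of that sequence (this standalone helper is not part of the commit, and remove_kv_cache_from_output is omitted for brevity):

# Illustrative only: mirrors the non-sharded path's call order shown above.
from habana_frameworks.torch.hpu import wrap_in_hpu_graph

def load_quantized_for_inference(causal_lm, model, device):
    model = causal_lm.prepare_model_for_quantization(model)  # 1. prep modules first
    model = model.eval().to(device)
    if causal_lm.enable_hpu_graph:
        model = wrap_in_hpu_graph(model, disable_tensor_cache=True)  # 2. wrap the prepared module
    return causal_lm.setup_quantization(model)                # 3. const-mark params last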
@@ -621,6 +628,36 @@ class CausalLM(Model):
        self.hb_profer_started = False
        self.step = 0

    def setup_quantization(self, model):
        if hq_env.is_quantization_enabled:
            htorch.core.quantization._mark_params_as_const(model)
            htorch.core.quantization._check_params_as_const(model)
            htorch.core.hpu_initialize(model)
        return model

    def prepare_model_for_quantization(self, model):
        if hq_env.is_quantization_enabled:
            if model.config.model_type == "llama":
                self.patch_scoped_linear_all_reduce(model)
            import habana_quantization_toolkit
            habana_quantization_toolkit.prep_model(model)
        return model

    def finish_quantization_measurements(self, model):
        if hq_env.is_quantization_enabled:
            import habana_quantization_toolkit
            habana_quantization_toolkit.finish_measurements(self.model)
        return model

    def patch_scoped_linear_all_reduce(self, model):
        from deepspeed.module_inject.layers import LinearAllreduce
        from optimum.habana.transformers.models.modeling_all_models import ScopedLinearAllReduce
        for name, module in model.named_children():
            if type(module) is LinearAllreduce:
                SL = ScopedLinearAllReduce(mod=module)
                setattr(model, name, SL)
            self.patch_scoped_linear_all_reduce(module)

    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        return CausalLMBatch

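Taken together, the new helpers suggest a measurement-driven workflow controlled by the habana_quantization_toolkit config referenced by QUANT_CONFIG: the model is prepped and set up during CausalLM construction, serving proceeds as usual, and collected statistics are flushed at shutdown. A hedged sketch of that lifecycle (the driver function is illustrative and not part of the commit):

def run_until_shutdown(causal_lm):
    # prepare_model_for_quantization() and setup_quantization() have already run
    # inside CausalLM.__init__ (see the hunks above), so nothing else is needed
    # before serving.
    ...  # serve requests as usual
    # On shutdown, flush the toolkit's measurement statistics; the method body
    # uses self.model, so passing the model explicitly satisfies the signature.
    if hasattr(causal_lm, "finish_quantization_measurements"):
        causal_lm.finish_quantization_measurements(causal_lm.model)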
@@ -204,6 +204,9 @@ def serve(
        except KeyboardInterrupt:
            logger.info("Signal received. Shutting down")
            await server.stop(0)
        finally:
            if hasattr(model,'finish_quantization_measurements'):
                model.finish_quantization_measurements()

    logger.info(
        "Starting Server : model_id= {}, revision = {} dtype = {} sharded = {} ".format(
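The finally block above is a no-op unless the process was started with QUANT_CONFIG set, since that is what flips is_quantization_enabled in the environment module added by this commit. A hedged sketch of launching the server that way from Python (text-generation-launcher and its --model-id flag are TGI's standard entry point; the config path and model id are placeholders):

import os
import subprocess

env = dict(os.environ)
env["QUANT_CONFIG"] = "/path/to/quant_config.json"  # placeholder path

subprocess.run(
    ["text-generation-launcher", "--model-id", "meta-llama/Llama-2-7b-hf"],
    env=env,
    check=True,
)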