Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-22 15:32:08 +00:00)
Co-authored-by: mrs303 <54661797+mrs303@users.noreply.github.com>
Co-authored-by: Adam Stachowicz <105052242+astachowiczhabana@users.noreply.github.com>
Co-authored-by: Grzegorz Morys <gmorys@habana.ai>
Commit: c3bd8ef445
Parent: a490847702
server/text_generation_server/habana_quantization_env.py (new file, 15 lines added)
@@ -0,0 +1,15 @@
+import os
+import sys
+
+assert "habana_frameworks" not in sys.modules
+
+is_quantization_enabled = os.getenv("QUANT_CONFIG", "") != ""
+
+if is_quantization_enabled:
+    os.environ.setdefault("ENABLE_EXPERIMENTAL_FLAGS", "true")
+    os.environ.setdefault("USE_DEFAULT_QUANT_PARAM", "true")
+    os.environ.setdefault("UPDATE_GRAPH_OUTPUT_MME", "false")
+    os.environ.setdefault("ENABLE_CALC_DYNAMIC_RANGE", "false")
+    os.environ.setdefault(
+        "UPDATE_MME_OUTPUT_PRECISION_FILTER", "v_proj,matmul_av")
+    os.environ.setdefault("EXPERIMENTAL_WEIGHT_SHARING", "FALSE")
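A minimal usage sketch (not part of the commit): the module has to be imported before anything from habana_frameworks, which is what the assert enforces, and it only exports the quantization flags when QUANT_CONFIG is non-empty. The config path below is a hypothetical placeholder.

import os

os.environ["QUANT_CONFIG"] = "/path/to/quant_config.json"  # hypothetical path; any non-empty value enables quantization

# Importing the env module first lets it seed the flags before habana_frameworks loads.
import text_generation_server.habana_quantization_env as hq_env

assert hq_env.is_quantization_enabled
print(os.environ["ENABLE_EXPERIMENTAL_FLAGS"])  # "true"
print(os.environ["USE_DEFAULT_QUANT_PARAM"])    # "true"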
@@ -4,15 +4,16 @@ import itertools
 import time
 import glob
 
-from text_generation_server.utils.tokens import batch_top_tokens
 import torch
 
 from dataclasses import dataclass
 from opentelemetry import trace
 from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase, AutoConfig
 from typing import Optional, Tuple, List, Type, Dict
-from habana_frameworks.torch.hpu import wrap_in_hpu_graph
+
+import text_generation_server.habana_quantization_env as hq_env
 import habana_frameworks.torch as htorch
+from habana_frameworks.torch.hpu import wrap_in_hpu_graph
 from contextlib import nullcontext
 from optimum.habana.utils import HabanaProfile, to_gb_rounded
 
@@ -23,6 +24,7 @@ from optimum.habana.checkpoint_utils import (
     write_checkpoints_json,
 )
 
+from text_generation_server.utils.tokens import batch_top_tokens
 from text_generation_server.models import Model
 from text_generation_server.models.types import (
     Batch,
@@ -491,6 +493,8 @@ class CausalLM(Model):
         dtype: Optional[torch.dtype] = None,
     ):
         device = torch.device("hpu")
+        if hq_env.is_quantization_enabled:
+            htorch.core.hpu_set_env()
 
         dtype = torch.bfloat16 if dtype is None else dtype
 
@@ -555,6 +559,7 @@ class CausalLM(Model):
             ds_inference_kwargs["checkpoint"] = checkpoints_json.name
             model = deepspeed.init_inference(model, **ds_inference_kwargs)
             model = model.module
+            model = self.prepare_model_for_quantization(model)
             model = remove_kv_cache_from_output(model)
             if self.enable_hpu_graph:
                 model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
@@ -566,11 +571,13 @@ class CausalLM(Model):
                 revision=revision,
                 torch_dtype=dtype,
             )
+            model = self.prepare_model_for_quantization(model)
             model = model.eval().to(device)
             # wrap in hpu_graph only if self.enable_hpu_graph is set
             model = remove_kv_cache_from_output(model)
             if self.enable_hpu_graph:
                 model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
+        model = self.setup_quantization(model)
 
         if model.config.model_type in MODELS_OPTIMIZED_WITH_STATIC_SHAPES:
             self.is_optimized_for_gaudi = True
@@ -621,6 +628,36 @@ class CausalLM(Model):
         self.hb_profer_started = False
         self.step = 0
 
+    def setup_quantization(self, model):
+        if hq_env.is_quantization_enabled:
+            htorch.core.quantization._mark_params_as_const(model)
+            htorch.core.quantization._check_params_as_const(model)
+            htorch.core.hpu_initialize(model)
+        return model
+
+    def prepare_model_for_quantization(self, model):
+        if hq_env.is_quantization_enabled:
+            if model.config.model_type == "llama":
+                self.patch_scoped_linear_all_reduce(model)
+            import habana_quantization_toolkit
+            habana_quantization_toolkit.prep_model(model)
+        return model
+
+    def finish_quantization_measurements(self, model):
+        if hq_env.is_quantization_enabled:
+            import habana_quantization_toolkit
+            habana_quantization_toolkit.finish_measurements(self.model)
+        return model
+
+    def patch_scoped_linear_all_reduce(self, model):
+        from deepspeed.module_inject.layers import LinearAllreduce
+        from optimum.habana.transformers.models.modeling_all_models import ScopedLinearAllReduce
+        for name, module in model.named_children():
+            if type(module) is LinearAllreduce:
+                SL = ScopedLinearAllReduce(mod=module)
+                setattr(model, name, SL)
+            self.patch_scoped_linear_all_reduce(module)
+
     @property
     def batch_type(self) -> Type[CausalLMBatch]:
         return CausalLMBatch
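A sketch (not from the commit) of how these helpers are meant to chain, shown against a plain transformers model rather than the CausalLM wrapper; it assumes a Gaudi host with habana_frameworks, optimum-habana and habana_quantization_toolkit installed and QUANT_CONFIG pointing at a toolkit config file. The calls mirror the hunks above: prep_model before the model is moved and wrapped, the const-marking and hpu_initialize afterwards.

import torch
import text_generation_server.habana_quantization_env as hq_env  # must come before habana_frameworks
import habana_frameworks.torch as htorch
from habana_frameworks.torch.hpu import wrap_in_hpu_graph
from transformers import AutoModelForCausalLM

# Stand-in model; the commit applies the same steps to the server's loaded checkpoint.
model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.bfloat16)

if hq_env.is_quantization_enabled:
    import habana_quantization_toolkit
    habana_quantization_toolkit.prep_model(model)  # mirrors prepare_model_for_quantization

model = model.eval().to(torch.device("hpu"))
model = wrap_in_hpu_graph(model, disable_tensor_cache=True)

if hq_env.is_quantization_enabled:
    htorch.core.quantization._mark_params_as_const(model)  # mirrors setup_quantization
    htorch.core.quantization._check_params_as_const(model)
    htorch.core.hpu_initialize(model)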
@@ -204,6 +204,9 @@ def serve(
         except KeyboardInterrupt:
             logger.info("Signal received. Shutting down")
             await server.stop(0)
+        finally:
+            if hasattr(model,'finish_quantization_measurements'):
+                model.finish_quantization_measurements()
 
     logger.info(
         "Starting Server : model_id= {}, revision = {} dtype = {} sharded = {} ".format(
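A tiny self-contained sketch (not from the commit) of the hasattr guard added above: the shutdown hook is optional, so the server only invokes it on model types that define it. PlainModel and QuantizedModel are made-up stand-ins.

class PlainModel:
    pass

class QuantizedModel:
    def finish_quantization_measurements(self):
        print("flushing quantization measurement statistics")

for model in (PlainModel(), QuantizedModel()):
    if hasattr(model, "finish_quantization_measurements"):
        model.finish_quantization_measurements()  # only runs for QuantizedModel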