mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-16 22:32:07 +00:00
Remove all references to habana_quantization_toolkit for 1.18 (#229)
This commit is contained in:
parent
21c13ff3a6
commit
46b14e6b28
@ -285,7 +285,7 @@ curl -N 127.0.0.1:8080/generate_stream \
|
|||||||
|
|
||||||
## Running TGI with FP8 Precision
|
## Running TGI with FP8 Precision
|
||||||
|
|
||||||
TGI-Gaudi supports FP8 precision inference with INC (Intel Neural Compressor) and HQT (Habana Quantization Toolkit). FP8 inference can be run by setting QUANT_CONFIG environment variable in the docker command. From TGI-Gaudi 2.0.4 release, INC is used by default for quantization. HQT will be removed in future releases. To use HQT, disable INC by setting `-e USE_INC=0` in docker command.
|
TGI-Gaudi supports FP8 precision inference with [Intel Neural Compressor (INC)](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_FP8.html). FP8 inference can be run by setting QUANT_CONFIG environment variable in the docker command.
|
||||||
|
|
||||||
To run FP8 Inference:
|
To run FP8 Inference:
|
||||||
|
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
|
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import habana_frameworks.torch as htorch
|
||||||
|
|
||||||
quant_config = os.getenv("QUANT_CONFIG", "")
|
quant_config = os.getenv("QUANT_CONFIG", "")
|
||||||
is_quantization_enabled = quant_config != ""
|
is_quantization_enabled = quant_config != ""
|
||||||
@ -10,18 +11,35 @@ if is_quantization_enabled:
|
|||||||
os.environ.setdefault("USE_DEFAULT_QUANT_PARAM", "true")
|
os.environ.setdefault("USE_DEFAULT_QUANT_PARAM", "true")
|
||||||
os.environ.setdefault("UPDATE_GRAPH_OUTPUT_MME", "false")
|
os.environ.setdefault("UPDATE_GRAPH_OUTPUT_MME", "false")
|
||||||
os.environ.setdefault("ENABLE_CALC_DYNAMIC_RANGE", "false")
|
os.environ.setdefault("ENABLE_CALC_DYNAMIC_RANGE", "false")
|
||||||
os.environ.setdefault(
|
os.environ.setdefault("UPDATE_MME_OUTPUT_PRECISION_FILTER", "v_proj,matmul_av")
|
||||||
"UPDATE_MME_OUTPUT_PRECISION_FILTER", "v_proj,matmul_av")
|
|
||||||
os.environ.setdefault("EXPERIMENTAL_WEIGHT_SHARING", "FALSE")
|
os.environ.setdefault("EXPERIMENTAL_WEIGHT_SHARING", "FALSE")
|
||||||
|
|
||||||
|
|
||||||
|
def patch_scoped_linear_all_reduce(model):
|
||||||
|
from deepspeed.module_inject.layers import LinearAllreduce
|
||||||
|
from optimum.habana.transformers.models.modeling_all_models import ScopedLinearAllReduce
|
||||||
|
|
||||||
|
for name, module in model.named_children():
|
||||||
|
if type(module) is LinearAllreduce:
|
||||||
|
SL = ScopedLinearAllReduce(mod=module)
|
||||||
|
setattr(model, name, SL)
|
||||||
|
patch_scoped_linear_all_reduce(module)
|
||||||
|
|
||||||
|
|
||||||
|
def setup_quantization(model):
|
||||||
|
if is_quantization_enabled:
|
||||||
|
htorch.core.quantization._mark_params_as_const(model)
|
||||||
|
htorch.core.quantization._check_params_as_const(model)
|
||||||
|
htorch.core.hpu_initialize(model)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
def prepare_model_for_quantization(model):
|
def prepare_model_for_quantization(model):
|
||||||
if is_quantization_enabled:
|
if is_quantization_enabled:
|
||||||
if os.getenv("USE_INC", "1") != "0":
|
if model.config.model_type in ["llama", "falcon", "qwen2", "starcoder2", "gemma"]:
|
||||||
from neural_compressor.torch.quantization import FP8Config, convert
|
patch_scoped_linear_all_reduce(model)
|
||||||
config = FP8Config.from_json_file(quant_config)
|
from neural_compressor.torch.quantization import FP8Config, convert
|
||||||
model = convert(model, config)
|
|
||||||
else:
|
config = FP8Config.from_json_file(quant_config)
|
||||||
import habana_quantization_toolkit
|
model = convert(model, config)
|
||||||
habana_quantization_toolkit.prep_model(model)
|
return model
|
||||||
return model
|
|
||||||
|
@ -629,7 +629,7 @@ class CausalLM(Model):
|
|||||||
model = self.get_deepspeed_model(
|
model = self.get_deepspeed_model(
|
||||||
model_id, dtype, revision
|
model_id, dtype, revision
|
||||||
)
|
)
|
||||||
model = self.prepare_model_for_quantization(model)
|
model = hq_env.prepare_model_for_quantization(model)
|
||||||
else:
|
else:
|
||||||
get_repo_root(model_id)
|
get_repo_root(model_id)
|
||||||
|
|
||||||
@ -648,7 +648,7 @@ class CausalLM(Model):
|
|||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
**model_kwargs
|
**model_kwargs
|
||||||
)
|
)
|
||||||
model = self.prepare_model_for_quantization(model)
|
model = hq_env.prepare_model_for_quantization(model)
|
||||||
model = model.eval().to(device)
|
model = model.eval().to(device)
|
||||||
|
|
||||||
self.enable_hpu_graph = os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true" and LAZY_MODE == 1
|
self.enable_hpu_graph = os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true" and LAZY_MODE == 1
|
||||||
@ -667,7 +667,7 @@ class CausalLM(Model):
|
|||||||
"TORCH COMPILE", f'Torch compiling of model')
|
"TORCH COMPILE", f'Torch compiling of model')
|
||||||
model.model = torch.compile(model.model, backend="hpu_backend", options={"keep_input_mutations": True})
|
model.model = torch.compile(model.model, backend="hpu_backend", options={"keep_input_mutations": True})
|
||||||
|
|
||||||
model = self.setup_quantization(model)
|
model = hq_env.setup_quantization(model)
|
||||||
|
|
||||||
if model.config.model_type not in MODELS_OPTIMIZED_WITH_STATIC_SHAPES:
|
if model.config.model_type not in MODELS_OPTIMIZED_WITH_STATIC_SHAPES:
|
||||||
raise ValueError(f"Model type {model.config.model_type} is not supported!")
|
raise ValueError(f"Model type {model.config.model_type} is not supported!")
|
||||||
@ -799,29 +799,6 @@ class CausalLM(Model):
|
|||||||
'type': rope_scaling, 'factor': float(rope_factor)
|
'type': rope_scaling, 'factor': float(rope_factor)
|
||||||
}
|
}
|
||||||
|
|
||||||
def setup_quantization(self, model):
|
|
||||||
if hq_env.is_quantization_enabled:
|
|
||||||
htorch.core.quantization._mark_params_as_const(model)
|
|
||||||
htorch.core.quantization._check_params_as_const(model)
|
|
||||||
htorch.core.hpu_initialize(model)
|
|
||||||
return model
|
|
||||||
|
|
||||||
def prepare_model_for_quantization(self, model):
|
|
||||||
if hq_env.is_quantization_enabled:
|
|
||||||
if model.config.model_type == "llama":
|
|
||||||
self.patch_scoped_linear_all_reduce(model)
|
|
||||||
model = hq_env.prepare_model_for_quantization(model)
|
|
||||||
return model
|
|
||||||
|
|
||||||
def patch_scoped_linear_all_reduce(self, model):
|
|
||||||
from deepspeed.module_inject.layers import LinearAllreduce
|
|
||||||
from optimum.habana.transformers.models.modeling_all_models import ScopedLinearAllReduce
|
|
||||||
for name, module in model.named_children():
|
|
||||||
if type(module) is LinearAllreduce:
|
|
||||||
SL = ScopedLinearAllReduce(mod=module)
|
|
||||||
setattr(model, name, SL)
|
|
||||||
self.patch_scoped_linear_all_reduce(module)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def batch_type(self) -> Type[CausalLMBatch]:
|
def batch_type(self) -> Type[CausalLMBatch]:
|
||||||
return CausalLMBatch
|
return CausalLMBatch
|
||||||
|
@ -193,7 +193,7 @@ class VlmCausalLMBatch(CausalLMBatch):
|
|||||||
device: torch.device,
|
device: torch.device,
|
||||||
is_warmup: bool = False,
|
is_warmup: bool = False,
|
||||||
) -> "VlmCausalLMBatch":
|
) -> "VlmCausalLMBatch":
|
||||||
|
|
||||||
dbg_trace('FROM_PB', f'num_reqs:{len(pb.requests)}')
|
dbg_trace('FROM_PB', f'num_reqs:{len(pb.requests)}')
|
||||||
requests = [CausalLMRequest.from_pb(idx, req, tokenizer) for idx, req in enumerate(pb.requests)]
|
requests = [CausalLMRequest.from_pb(idx, req, tokenizer) for idx, req in enumerate(pb.requests)]
|
||||||
|
|
||||||
@ -536,7 +536,7 @@ class VlmCausalLM(Model):
|
|||||||
model = self.get_deepspeed_model(
|
model = self.get_deepspeed_model(
|
||||||
model_class, model_id, dtype, revision
|
model_class, model_id, dtype, revision
|
||||||
)
|
)
|
||||||
model = self.prepare_model_for_quantization(model)
|
model = hq_env.prepare_model_for_quantization(model)
|
||||||
else:
|
else:
|
||||||
get_repo_root(model_id)
|
get_repo_root(model_id)
|
||||||
|
|
||||||
@ -555,7 +555,7 @@ class VlmCausalLM(Model):
|
|||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
**model_kwargs
|
**model_kwargs
|
||||||
)
|
)
|
||||||
model = self.prepare_model_for_quantization(model)
|
model = hq_env.prepare_model_for_quantization(model)
|
||||||
model = model.eval().to(device)
|
model = model.eval().to(device)
|
||||||
|
|
||||||
self.enable_hpu_graph = os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true" and LAZY_MODE == 1
|
self.enable_hpu_graph = os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true" and LAZY_MODE == 1
|
||||||
@ -565,13 +565,13 @@ class VlmCausalLM(Model):
|
|||||||
from habana_frameworks.torch.hpu import wrap_in_hpu_graph
|
from habana_frameworks.torch.hpu import wrap_in_hpu_graph
|
||||||
model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
|
model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
|
||||||
else:
|
else:
|
||||||
if LAZY_MODE == 0:
|
if LAZY_MODE == 0:
|
||||||
# It is said that "keep_input_mutations" is safe for inference to be done
|
# It is said that "keep_input_mutations" is safe for inference to be done
|
||||||
dbg_trace(
|
dbg_trace(
|
||||||
"TORCH COMPILE", f'Torch compiling of model')
|
"TORCH COMPILE", f'Torch compiling of model')
|
||||||
model.model = torch.compile(model.model, backend="hpu_backend", options={"keep_input_mutations": True})
|
model.model = torch.compile(model.model, backend="hpu_backend", options={"keep_input_mutations": True})
|
||||||
|
|
||||||
model = self.setup_quantization(model)
|
model = hq_env.setup_quantization(model)
|
||||||
|
|
||||||
if model.config.model_type not in MODELS_OPTIMIZED_WITH_STATIC_SHAPES:
|
if model.config.model_type not in MODELS_OPTIMIZED_WITH_STATIC_SHAPES:
|
||||||
raise ValueError(f"Model type {model.config.model_type} is not supported!")
|
raise ValueError(f"Model type {model.config.model_type} is not supported!")
|
||||||
@ -703,36 +703,6 @@ class VlmCausalLM(Model):
|
|||||||
'type': rope_scaling, 'factor': float(rope_factor)
|
'type': rope_scaling, 'factor': float(rope_factor)
|
||||||
}
|
}
|
||||||
|
|
||||||
def setup_quantization(self, model):
|
|
||||||
if hq_env.is_quantization_enabled:
|
|
||||||
htorch.core.quantization._mark_params_as_const(model)
|
|
||||||
htorch.core.quantization._check_params_as_const(model)
|
|
||||||
htorch.core.hpu_initialize(model)
|
|
||||||
return model
|
|
||||||
|
|
||||||
def prepare_model_for_quantization(self, model):
|
|
||||||
if hq_env.is_quantization_enabled:
|
|
||||||
if model.config.model_type == "llama":
|
|
||||||
self.patch_scoped_linear_all_reduce(model)
|
|
||||||
import habana_quantization_toolkit
|
|
||||||
habana_quantization_toolkit.prep_model(model)
|
|
||||||
return model
|
|
||||||
|
|
||||||
def finish_quantization_measurements(self, model):
|
|
||||||
if hq_env.is_quantization_enabled:
|
|
||||||
import habana_quantization_toolkit
|
|
||||||
habana_quantization_toolkit.finish_measurements(self.model)
|
|
||||||
return model
|
|
||||||
|
|
||||||
def patch_scoped_linear_all_reduce(self, model):
|
|
||||||
from deepspeed.module_inject.layers import LinearAllreduce
|
|
||||||
from optimum.habana.transformers.models.modeling_all_models import ScopedLinearAllReduce
|
|
||||||
for name, module in model.named_children():
|
|
||||||
if type(module) is LinearAllreduce:
|
|
||||||
SL = ScopedLinearAllReduce(mod=module)
|
|
||||||
setattr(model, name, SL)
|
|
||||||
self.patch_scoped_linear_all_reduce(module)
|
|
||||||
|
|
||||||
def decode(self, generated_ids: List[int]) -> str:
|
def decode(self, generated_ids: List[int]) -> str:
|
||||||
return self.tokenizer.decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
return self.tokenizer.decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||||
|
|
||||||
@ -906,7 +876,7 @@ class VlmCausalLM(Model):
|
|||||||
bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None,
|
bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None,
|
||||||
)
|
)
|
||||||
elif all([req.stopping_criteria.max_new_tokens == 1 for req in batch.requests]):
|
elif all([req.stopping_criteria.max_new_tokens == 1 for req in batch.requests]):
|
||||||
# Don't schedule next forward if max_new_tokens for all requests equals 1
|
# Don't schedule next forward if max_new_tokens for all requests equals 1
|
||||||
# - we've already generated the first and only needed token in the prefill phase
|
# - we've already generated the first and only needed token in the prefill phase
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
|
Loading…
Reference in New Issue
Block a user