Enable quantization with INC (#203)

This commit is contained in:
Thanaji Rao Thakkalapelli 2024-08-26 01:55:37 -07:00 committed by GitHub
parent ea48ae169a
commit 0c3239e710
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 18 additions and 12 deletions

View File

@ -109,7 +109,7 @@ For more information and documentation about Text Generation Inference, checkout
## Running TGI with FP8 precision
TGI supports FP8 precision runs within the limits provided by [Habana Quantization Toolkit](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_FP8.html). Models with FP8 can be ran by properly setting QUANT_CONFIG environment variable. Detailed instruction on how to use that variable can be found in [Optimum Habana FP8 guide](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation#running-with-fp8). Summarising that instruction in TGI cases:
TGI supports FP8 precision runs within the limits provided by [Intel Neural Compressor (INC)](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_FP8.html). Models with FP8 can be ran by properly setting QUANT_CONFIG environment variable. Detailed instruction on how to use that variable can be found in [Optimum Habana FP8 guide](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation#running-with-fp8). From 2.0.4 release, Intel Neural Compressor (INC) is used by default for measuring and quantization. Habana Quantization Toolkit(HQT) will be removed in future releases. To use HQT, disable INC by setting `-e USE_INC=0`. Summarising that instruction in TGI cases:
1. Measure quantization statistics of requested model by using [Optimum Habana measurement script](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation#running-with-fp8:~:text=use_deepspeed%20%2D%2Dworld_size%208-,run_lm_eval.py,-%5C%0A%2Do%20acc_70b_bs1_measure.txt)
2. Run requested model in TGI with proper QUANT_CONFIG setting - e.g. `-e QUANT_CONFIG=./quantization_config/maxabs_quant.json`.

View File

@ -5,7 +5,8 @@ import sys
assert "habana_frameworks" not in sys.modules
is_quantization_enabled = os.getenv("QUANT_CONFIG", "") != ""
quant_config = os.getenv("QUANT_CONFIG", "")
is_quantization_enabled = quant_config != ""
if is_quantization_enabled:
os.environ.setdefault("ENABLE_EXPERIMENTAL_FLAGS", "true")
@ -15,3 +16,15 @@ if is_quantization_enabled:
os.environ.setdefault(
"UPDATE_MME_OUTPUT_PRECISION_FILTER", "v_proj,matmul_av")
os.environ.setdefault("EXPERIMENTAL_WEIGHT_SHARING", "FALSE")
def prepare_model_for_quantization(model):
if is_quantization_enabled:
if os.getenv("USE_INC", "1") != "0":
from neural_compressor.torch.quantization import FP8Config, convert
config = FP8Config.from_json_file(quant_config)
model = convert(model, config)
else:
import habana_quantization_toolkit
habana_quantization_toolkit.prep_model(model)
return model

View File

@ -658,7 +658,7 @@ class CausalLM(Model):
from habana_frameworks.torch.hpu import wrap_in_hpu_graph
model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
else:
if LAZY_MODE == 0:
if LAZY_MODE == 0:
# It is said that "keep_input_mutations" is safe for inference to be done
dbg_trace(
"TORCH COMPILE", f'Torch compiling of model')
@ -807,14 +807,7 @@ class CausalLM(Model):
if hq_env.is_quantization_enabled:
if model.config.model_type == "llama":
self.patch_scoped_linear_all_reduce(model)
import habana_quantization_toolkit
habana_quantization_toolkit.prep_model(model)
return model
def finish_quantization_measurements(self, model):
if hq_env.is_quantization_enabled:
import habana_quantization_toolkit
habana_quantization_toolkit.finish_measurements(self.model)
model = hq_env.prepare_model_for_quantization(model)
return model
def patch_scoped_linear_all_reduce(self, model):
@ -995,7 +988,7 @@ class CausalLM(Model):
bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None,
)
elif all([req.stopping_criteria.max_new_tokens == 1 for req in batch.requests]):
# Don't schedule next forward if max_new_tokens for all requests equals 1
# Don't schedule next forward if max_new_tokens for all requests equals 1
# - we've already generated the first and only needed token in the prefill phase
pass
else: