From 7c05b0ba54ac2aa8a8fc2ef4823107121e27581a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Tue, 24 Sep 2024 08:27:52 +0000 Subject: [PATCH] Support LLM compressor FP8 checkpoints on H100 On H100, we use fbgemm-gpu, which requires bfloat16 as the input dtype. However, we wouldn't pick up fp8 quantization for models quantized with LLM compressor. This change adds enough parsing to detect if models have FP8-quantized weights. --- server/text_generation_server/layers/fp8.py | 3 +++ .../text_generation_server/models/__init__.py | 18 ++++++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py index 2933aea2..61dd5115 100644 --- a/server/text_generation_server/layers/fp8.py +++ b/server/text_generation_server/layers/fp8.py @@ -239,6 +239,9 @@ class Fp8Linear(torch.nn.Module): @classmethod def from_fp8(cls, weight, scale, input_scale, bias, dtype): + if FBGEMM_DYN_AVAILABLE: + # fbgemm needs float32 scales. + scale = scale.float() return cls( qweight=weight, scale=scale, diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index fc530b38..0e323b4f 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -334,6 +334,7 @@ def get_model( model_type = config_dict.get("model_type", None) quantization_config = config_dict.get("quantization_config", None) + compression_config = config_dict.get("compression_config", None) if quantization_config is not None and quantize is None: method = quantization_config.get("quant_method", None) if method in {"gptq", "awq", "exl2"}: @@ -344,6 +345,23 @@ def get_model( quantize = "fp8" else: log_master(logger.warning, f"Unknown quantization method {method}") + elif compression_config is not None: + # TODO: at some point we should probably fully parse the compression + # configuration to know which parameters are compressed. + config_groups = compression_config.get("config_groups") + if config_groups is not None: + for _, group in config_groups.items(): + weights_config = group.get("weights") + if weights_config is not None: + if ( + weights_config["type"] == "float" + and weights_config["num_bits"] == 8 + ): + log_master( + logger.info, "Auto selecting quantization method fp8" + ) + quantize = "fp8" + break if dtype is None: if quantize in ["awq", "exl2", "gptq", "marlin"]: