Support LLM Compressor FP8 checkpoints on H100

On H100, we use fbgemm-gpu, which requires bfloat16 as the input dtype.
However, FP8 quantization was not picked up for models quantized with
LLM Compressor, since those checkpoints describe quantization in a
compression_config section rather than in quantization_config. This
change adds just enough parsing of that section to detect whether a
model has FP8-quantized weights.
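
The fragment below is a hand-written sketch of the keys the new parsing
inspects (config_groups -> weights -> type / num_bits), written as the
equivalent Python dict; real checkpoint configs carry more fields:

    # Hypothetical sketch of the relevant part of a config.json, as a
    # Python dict; only the keys the detection code reads are shown.
    compression_config = {
        "config_groups": {
            "group_0": {
                "weights": {
                    "type": "float",  # FP8 weights
                    "num_bits": 8,
                },
            },
        },
    }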
Author: Daniël de Kok
Date:   2024-09-24 08:27:52 +00:00
Commit: 7c05b0ba54 (parent: ccaf9ff507)

2 changed files with 21 additions and 0 deletions


@@ -239,6 +239,9 @@ class Fp8Linear(torch.nn.Module):
     @classmethod
     def from_fp8(cls, weight, scale, input_scale, bias, dtype):
+        if FBGEMM_DYN_AVAILABLE:
+            # fbgemm needs float32 scales.
+            scale = scale.float()
         return cls(
             qweight=weight,
             scale=scale,

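The upcast above matters because fbgemm needs float32 scales, while
checkpoints may store scales in a half-precision dtype. A minimal
standalone sketch of the same behavior, using a hypothetical helper
rather than the real Fp8Linear class:

    import torch

    FBGEMM_DYN_AVAILABLE = True  # assumption: stand-in for the real feature flag

    def prepare_fp8_scale(scale: torch.Tensor) -> torch.Tensor:
        # fbgemm needs float32 scales; upcast if the checkpoint stored
        # them in a lower-precision dtype.
        if FBGEMM_DYN_AVAILABLE:
            return scale.float()
        return scale

    scale = torch.tensor([0.0123], dtype=torch.bfloat16)
    assert prepare_fp8_scale(scale).dtype == torch.float32
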

@@ -334,6 +334,7 @@ def get_model(
     model_type = config_dict.get("model_type", None)
     quantization_config = config_dict.get("quantization_config", None)
+    compression_config = config_dict.get("compression_config", None)
     if quantization_config is not None and quantize is None:
         method = quantization_config.get("quant_method", None)
         if method in {"gptq", "awq", "exl2"}:
@@ -344,6 +345,23 @@ def get_model(
             quantize = "fp8"
         else:
             log_master(logger.warning, f"Unknown quantization method {method}")
+    elif compression_config is not None:
+        # TODO: at some point we should probably fully parse the compression
+        # configuration to know which parameters are compressed.
+        config_groups = compression_config.get("config_groups")
+        if config_groups is not None:
+            for _, group in config_groups.items():
+                weights_config = group.get("weights")
+                if weights_config is not None:
+                    if (
+                        weights_config["type"] == "float"
+                        and weights_config["num_bits"] == 8
+                    ):
+                        log_master(
+                            logger.info, "Auto selecting quantization method fp8"
+                        )
+                        quantize = "fp8"
+                        break
     if dtype is None:
         if quantize in ["awq", "exl2", "gptq", "marlin"]:
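
The new elif branch can be exercised on its own; this is a minimal
sketch mirroring its logic with the logging stripped out, run against a
dict shaped like the earlier sketch (is_fp8_compressed is a hypothetical
helper, not part of the change):

    def is_fp8_compressed(compression_config: dict) -> bool:
        # Mirrors the new branch: scan config_groups for a weights entry
        # declaring 8-bit float quantization.
        config_groups = compression_config.get("config_groups")
        if config_groups is None:
            return False
        for group in config_groups.values():
            weights_config = group.get("weights")
            if (
                weights_config is not None
                and weights_config["type"] == "float"
                and weights_config["num_bits"] == 8
            ):
                return True
        return False

    assert is_fp8_compressed(
        {"config_groups": {"group_0": {"weights": {"type": "float", "num_bits": 8}}}}
    )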