Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-12 04:44:52 +00:00)
Support LLM compressor FP8 checkpoints on H100
On H100, we use fbgemm-gpu, which requires bfloat16 as the input dtype. However, we wouldn't pick up fp8 quantization for models quantized with LLM compressor. This change adds enough parsing to detect if models have FP8-quantized weights.
This commit is contained in:
parent ccaf9ff507
commit 7c05b0ba54
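For context, a minimal sketch of the compression_config section that an LLM compressor checkpoint carries in its config.json. Only the config_groups -> group -> weights -> {type, num_bits} path matters for the detection added in this commit; the group name and any other fields are illustrative assumptions, not taken from a real checkpoint.

# Hypothetical excerpt of a config.json from an LLM compressor FP8 checkpoint,
# shown as a Python dict. The new parsing only inspects "type" and "num_bits"
# under each group's "weights" entry.
compression_config = {
    "config_groups": {
        "group_0": {                     # group name is an assumption
            "weights": {
                "type": "float",         # FP8 weights
                "num_bits": 8,
            },
        },
    },
}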
@@ -239,6 +239,9 @@ class Fp8Linear(torch.nn.Module):
 
     @classmethod
     def from_fp8(cls, weight, scale, input_scale, bias, dtype):
+        if FBGEMM_DYN_AVAILABLE:
+            # fbgemm needs float32 scales.
+            scale = scale.float()
         return cls(
             qweight=weight,
             scale=scale,
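A minimal standalone sketch of why the cast above is needed: fbgemm's FP8 path expects float32 scales, while checkpoints may store the scale in another dtype such as bfloat16. The tensors and the FBGEMM_DYN_AVAILABLE stand-in below are hypothetical, not TGI internals.

import torch

# Hypothetical per-tensor FP8 weight and scale as they might come out of a checkpoint.
weight = torch.randn(256, 128).to(torch.float8_e4m3fn)
scale = torch.tensor(0.02, dtype=torch.bfloat16)

FBGEMM_DYN_AVAILABLE = True  # assumed stand-in for the availability flag used in the diff
if FBGEMM_DYN_AVAILABLE:
    # fbgemm needs float32 scales.
    scale = scale.float()

assert scale.dtype == torch.float32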
@@ -334,6 +334,7 @@ def get_model(
     model_type = config_dict.get("model_type", None)
 
     quantization_config = config_dict.get("quantization_config", None)
+    compression_config = config_dict.get("compression_config", None)
     if quantization_config is not None and quantize is None:
         method = quantization_config.get("quant_method", None)
         if method in {"gptq", "awq", "exl2"}:
@@ -344,6 +345,23 @@ def get_model(
             quantize = "fp8"
         else:
             log_master(logger.warning, f"Unknown quantization method {method}")
+    elif compression_config is not None:
+        # TODO: at some point we should probably fully parse the compression
+        # configuration to know which parameters are compressed.
+        config_groups = compression_config.get("config_groups")
+        if config_groups is not None:
+            for _, group in config_groups.items():
+                weights_config = group.get("weights")
+                if weights_config is not None:
+                    if (
+                        weights_config["type"] == "float"
+                        and weights_config["num_bits"] == 8
+                    ):
+                        log_master(
+                            logger.info, "Auto selecting quantization method fp8"
+                        )
+                        quantize = "fp8"
+                        break
 
     if dtype is None:
         if quantize in ["awq", "exl2", "gptq", "marlin"]:
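A self-contained sketch of the detection logic added above, applied to a hypothetical config_dict. The function name detect_fp8_quantize and the sample config are illustrative only; TGI performs this check inline inside get_model rather than in a helper.

from typing import Optional


def detect_fp8_quantize(config_dict: dict) -> Optional[str]:
    # Mirrors the compression_config parsing added in this commit.
    compression_config = config_dict.get("compression_config", None)
    if compression_config is None:
        return None
    config_groups = compression_config.get("config_groups")
    if config_groups is None:
        return None
    for group in config_groups.values():
        weights_config = group.get("weights")
        if (
            weights_config is not None
            and weights_config["type"] == "float"
            and weights_config["num_bits"] == 8
        ):
            return "fp8"
    return None


config_dict = {
    "model_type": "llama",  # hypothetical model
    "compression_config": {
        "config_groups": {
            "group_0": {"weights": {"type": "float", "num_bits": 8}},
        },
    },
}
print(detect_fp8_quantize(config_dict))  # prints: fp8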