diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py
index 2933aea2..61dd5115 100644
--- a/server/text_generation_server/layers/fp8.py
+++ b/server/text_generation_server/layers/fp8.py
@@ -239,6 +239,9 @@ class Fp8Linear(torch.nn.Module):
 
     @classmethod
     def from_fp8(cls, weight, scale, input_scale, bias, dtype):
+        if FBGEMM_DYN_AVAILABLE:
+            # fbgemm needs float32 scales.
+            scale = scale.float()
         return cls(
             qweight=weight,
             scale=scale,
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index fc530b38..0e323b4f 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -334,6 +334,7 @@ def get_model(
 
     model_type = config_dict.get("model_type", None)
     quantization_config = config_dict.get("quantization_config", None)
+    compression_config = config_dict.get("compression_config", None)
     if quantization_config is not None and quantize is None:
         method = quantization_config.get("quant_method", None)
         if method in {"gptq", "awq", "exl2"}:
@@ -344,6 +345,23 @@
             quantize = "fp8"
         else:
             log_master(logger.warning, f"Unknown quantization method {method}")
+    elif compression_config is not None:
+        # TODO: at some point we should probably fully parse the compression
+        # configuration to know which parameters are compressed.
+        config_groups = compression_config.get("config_groups")
+        if config_groups is not None:
+            for _, group in config_groups.items():
+                weights_config = group.get("weights")
+                if weights_config is not None:
+                    if (
+                        weights_config["type"] == "float"
+                        and weights_config["num_bits"] == 8
+                    ):
+                        log_master(
+                            logger.info, "Auto selecting quantization method fp8"
+                        )
+                        quantize = "fp8"
+                        break
 
     if dtype is None:
         if quantize in ["awq", "exl2", "gptq", "marlin"]:
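For context on the first hunk: fbgemm's dynamically quantized FP8 kernels take float32 scales, while an FP8 checkpoint may store its scales in the model dtype (e.g. bfloat16). A minimal standalone sketch of the upcast, with `FBGEMM_DYN_AVAILABLE` stubbed as a plain constant so the snippet runs on its own (in `layers/fp8.py` it is a module-level capability flag):

```python
import torch

# Stand-in for the capability flag from layers/fp8.py; a constant here so the
# snippet is self-contained.
FBGEMM_DYN_AVAILABLE = True

# A scale loaded from an FP8 checkpoint may arrive in the model dtype.
scale = torch.tensor([0.125], dtype=torch.bfloat16)

if FBGEMM_DYN_AVAILABLE:
    # fbgemm needs float32 scales.
    scale = scale.float()

assert scale.dtype == torch.float32
```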
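To illustrate what the new `elif` branch in `get_model()` matches: a hypothetical `config.json` fragment in the compressed-tensors style, run through the same scan (the `config_dict` contents are an assumption for illustration and show only the keys the diff actually reads):

```python
# Hypothetical checkpoint config; only the keys the new branch reads are shown.
config_dict = {
    "model_type": "llama",
    "compression_config": {
        "config_groups": {
            "group_0": {
                "weights": {"type": "float", "num_bits": 8},
            },
        },
    },
}

# Same detection logic as the new elif branch, minus logging: auto-select fp8
# when any config group declares an 8-bit float weight scheme.
quantize = None
compression_config = config_dict.get("compression_config", None)
if compression_config is not None:
    config_groups = compression_config.get("config_groups")
    if config_groups is not None:
        for _, group in config_groups.items():
            weights_config = group.get("weights")
            if weights_config is not None:
                if (
                    weights_config["type"] == "float"
                    and weights_config["num_bits"] == 8
                ):
                    quantize = "fp8"
                    break

assert quantize == "fp8"
```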