diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py
index f7710145..c626ddb8 100644
--- a/server/text_generation_server/layers/fp8.py
+++ b/server/text_generation_server/layers/fp8.py
@@ -165,7 +165,7 @@ class HybridFP8UnquantLoader(WeightsLoader):
             scale = scale.reshape(-1).expand(w.shape[0])
 
             input_scale = None
-            if weights.get_tensor(f"{prefix}.input_scale"):
+            if weights.has_tensor(f"{prefix}.input_scale"):
                 input_scale = weights.get_tensor(
                     f"{prefix}.input_scale", to_dtype=False
                 )
@@ -213,6 +213,7 @@ class HybridFP8UnquantLoader(WeightsLoader):
                 for p, shape in zip(prefixes, shapes)
                 if weights.has_tensor(f"{p}.input_scale")
             ]
+            assert len(input_scale) == 0 or len(input_scale) == len(prefixes)
             input_scale = (
                 torch.cat(input_scale, dim=0).reshape(-1).max()
                 if len(input_scale) != 0
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index c9886092..7891a15d 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -342,6 +342,7 @@ def get_model(
     model_type = config_dict.get("model_type", None)
 
     quantization_config = config_dict.get("quantization_config", None)
+    compression_config = config_dict.get("compression_config", None)
     if quantization_config is not None and quantize is None:
         method = quantization_config.get("quant_method", None)
         config_groups = quantization_config.get("config_groups", None)
@@ -352,7 +353,6 @@ def get_model(
             log_master(logger.info, "Auto selecting quantization method fp8")
             quantize = "fp8"
         elif config_groups is not None:
-            # Compression config renamed to quantization_config
             # TODO: at some point we should probably fully parse the compression
             # configuration to know which parameters are compressed.
             for _, group in config_groups.items():
@@ -369,6 +369,22 @@ def get_model(
                         break
         else:
             log_master(logger.warning, f"Unknown quantization method {method}")
+    elif compression_config is not None:
+        # For backward compatibility: compression_config was renamed to quantization_config
+        config_groups = compression_config.get("config_groups")
+        if config_groups is not None:
+            for _, group in config_groups.items():
+                weights_config = group.get("weights")
+                if weights_config is not None:
+                    if (
+                        weights_config["type"] == "float"
+                        and weights_config["num_bits"] == 8
+                    ):
+                        log_master(
+                            logger.info, "Auto selecting quantization method fp8"
+                        )
+                        quantize = "fp8"
+                        break
 
     if dtype is None:
         if quantize in ["awq", "exl2", "gptq", "marlin"]:
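
For reference, a minimal sketch of the kind of checkpoint config the new `compression_config` fallback branch matches. The key names (`compression_config`, `config_groups`, `weights`, `type`, `num_bits`) and the float/8-bit values come from the diff above; the group name `"group_0"` and the surrounding fields are illustrative assumptions, not taken from a real checkpoint:

    # Hypothetical excerpt of a checkpoint's config.json as parsed into
    # config_dict in get_model(). Older compressed checkpoints ship a
    # "compression_config" section; newer ones call it "quantization_config".
    config_dict = {
        "model_type": "llama",
        "compression_config": {
            "config_groups": {
                # "group_0" is an illustrative name; get_model() iterates all groups.
                "group_0": {
                    # FP8 weights ("float" with 8 bits) auto-select quantize = "fp8".
                    "weights": {"type": "float", "num_bits": 8},
                },
            },
        },
    }

Separately, the first hunk switches the `input_scale` existence check from `weights.get_tensor(...)` to `weights.has_tensor(...)`, matching the presence checks already used elsewhere in the loader, presumably because `get_tensor` is a loader rather than a presence test and fails when the tensor is absent.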