(review_comments) fix typo and add comments

Mohit Sharma 2024-10-15 12:01:12 +00:00
parent b2b5024ec8
commit 1de96279e3
2 changed files with 19 additions and 2 deletions

View File

@ -165,7 +165,7 @@ class HybridFP8UnquantLoader(WeightsLoader):
     scale = scale.reshape(-1).expand(w.shape[0])
     input_scale = None
-    if weights.get_tensor(f"{prefix}.input_scale"):
+    if weights.has_tensor(f"{prefix}.input_scale"):
         input_scale = weights.get_tensor(
             f"{prefix}.input_scale", to_dtype=False
         )
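
The one-character fix above is the "typo" from the commit message: an optional tensor was probed through the getter instead of an existence check. A minimal toy sketch of why that matters (ToyWeights is an illustrative stand-in, not TGI's actual Weights class, which per this diff exposes has_tensor() and get_tensor()):

import torch


class ToyWeights:
    def __init__(self, tensors):
        self._tensors = tensors

    def get_tensor(self, name):
        return self._tensors[name]      # raises KeyError when the key is missing

    def has_tensor(self, name):
        return name in self._tensors    # safe existence check


weights = ToyWeights({"q_proj.weight_scale": torch.tensor(0.02)})
prefix = "q_proj"

# Old pattern: `if weights.get_tensor(f"{prefix}.input_scale"):` would raise
# KeyError here, because many fp8 checkpoints ship no input_scale at all.
input_scale = None
if weights.has_tensor(f"{prefix}.input_scale"):
    input_scale = weights.get_tensor(f"{prefix}.input_scale")

print(input_scale)  # -> None, the optional scale is simply skipped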
@ -213,6 +213,7 @@ class HybridFP8UnquantLoader(WeightsLoader):
         for p, shape in zip(prefixes, shapes)
         if weights.has_tensor(f"{p}.input_scale")
     ]
+    assert len(input_scale) == 0 or len(input_scale) == len(prefixes)
     input_scale = (
         torch.cat(input_scale, dim=0).reshape(-1).max()
         if len(input_scale) != 0
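
The added assert enforces an all-or-nothing rule when several column prefixes are fused into one packed weight: either every prefix ships an input_scale or none does, and the fused layer then uses the max of the per-prefix scales. A hedged, self-contained sketch of that aggregation (prefix names and values are illustrative):

import torch

per_prefix_scales = {
    "q_proj": torch.tensor([0.011]),
    "k_proj": torch.tensor([0.013]),
    "v_proj": torch.tensor([0.009]),
}
prefixes = ["q_proj", "k_proj", "v_proj"]

input_scale = [
    per_prefix_scales[p].reshape(-1) for p in prefixes if p in per_prefix_scales
]
# Mixing prefixes with and without an input_scale would silently mis-scale the
# fused layer, hence the all-or-nothing assertion added in this commit.
assert len(input_scale) == 0 or len(input_scale) == len(prefixes)

input_scale = (
    torch.cat(input_scale, dim=0).reshape(-1).max()
    if len(input_scale) != 0
    else None
)
print(input_scale)  # -> tensor(0.0130)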

View File

@ -342,6 +342,7 @@ def get_model(
     model_type = config_dict.get("model_type", None)
     quantization_config = config_dict.get("quantization_config", None)
+    compression_config = config_dict.get("compression_config", None)
     if quantization_config is not None and quantize is None:
         method = quantization_config.get("quant_method", None)
         config_groups = quantization_config.get("config_groups", None)
@ -352,7 +353,6 @@ def get_model(
             log_master(logger.info, "Auto selecting quantization method fp8")
             quantize = "fp8"
         elif config_groups is not None:
-            # Compression config renamed to quantization_config
             # TODO: at some point we should probably fully parse the compression
             # configuration to know which parameters are compressed.
             for _, group in config_groups.items():
@ -369,6 +369,22 @@ def get_model(
                         break
         else:
             log_master(logger.warning, f"Unknown quantization method {method}")
+    elif compression_config is not None:
+        # For backward compatibility: compression_config is renamed to quantization_config
+        config_groups = compression_config.get("config_groups")
+        if config_groups is not None:
+            for _, group in config_groups.items():
+                weights_config = group.get("weights")
+                if weights_config is not None:
+                    if (
+                        weights_config["type"] == "float"
+                        and weights_config["num_bits"] == 8
+                    ):
+                        log_master(
+                            logger.info, "Auto selecting quantization method fp8"
+                        )
+                        quantize = "fp8"
+                        break
     if dtype is None:
         if quantize in ["awq", "exl2", "gptq", "marlin"]:
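
The new elif branch restores support for checkpoints that still use the legacy compression_config key (since renamed to quantization_config): its config_groups are scanned for an 8-bit float weights entry, in which case fp8 is auto-selected. A standalone sketch under those assumptions (the sample config and function name are illustrative, not TGI's API):

from typing import Optional


def auto_select_quantize(config_dict: dict) -> Optional[str]:
    quantization_config = config_dict.get("quantization_config", None)
    compression_config = config_dict.get("compression_config", None)

    if quantization_config is not None:
        # The primary path in get_model() inspects quant_method and
        # config_groups here; omitted in this sketch.
        return None

    if compression_config is not None:
        config_groups = compression_config.get("config_groups")
        if config_groups is not None:
            for _, group in config_groups.items():
                weights_config = group.get("weights")
                if weights_config is not None:
                    if (
                        weights_config["type"] == "float"
                        and weights_config["num_bits"] == 8
                    ):
                        return "fp8"
    return None


# Illustrative legacy config: compressed-tensors style config_groups under the
# old compression_config key.
legacy_config = {
    "compression_config": {
        "config_groups": {
            "group_0": {"weights": {"type": "float", "num_bits": 8}},
        }
    }
}
print(auto_select_quantize(legacy_config))  # -> "fp8"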