Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-12 04:44:52 +00:00)
(review_comments) fix typo and added comments
This commit is contained in:
parent b2b5024ec8
commit 1de96279e3
@@ -165,7 +165,7 @@ class HybridFP8UnquantLoader(WeightsLoader):
         scale = scale.reshape(-1).expand(w.shape[0])

         input_scale = None
-        if weights.get_tensor(f"{prefix}.input_scale"):
+        if weights.has_tensor(f"{prefix}.input_scale"):
             input_scale = weights.get_tensor(
                 f"{prefix}.input_scale", to_dtype=False
             )
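The typo being fixed is subtle: get_tensor loads the tensor and raises for a missing name, so it cannot serve as an existence check; has_tensor is the membership test. A minimal sketch of the distinction, using a hypothetical dict-backed stand-in for the real Weights loader:

# Hypothetical dict-backed stand-in for the real Weights loader, for
# illustration only; the real class reads tensors from checkpoint shards.
import torch

class ToyWeights:
    def __init__(self, tensors):
        self._tensors = tensors

    def has_tensor(self, name):
        # Cheap membership test; never touches tensor data.
        return name in self._tensors

    def get_tensor(self, name, to_dtype=True):
        # Loads the tensor; raises KeyError if the name is absent, which is
        # why it cannot double as an existence check as in the old line.
        return self._tensors[name]

weights = ToyWeights({"model.layers.0.q_proj.weight": torch.ones(4)})

# Fixed pattern: check for the optional scale first, then load it.
input_scale = None
if weights.has_tensor("model.layers.0.q_proj.input_scale"):
    input_scale = weights.get_tensor("model.layers.0.q_proj.input_scale")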
@@ -213,6 +213,7 @@ class HybridFP8UnquantLoader(WeightsLoader):
             for p, shape in zip(prefixes, shapes)
             if weights.has_tensor(f"{p}.input_scale")
         ]
+        assert len(input_scale) == 0 or len(input_scale) == len(prefixes)
         input_scale = (
             torch.cat(input_scale, dim=0).reshape(-1).max()
             if len(input_scale) != 0
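For fused column weights, each prefix may ship its own input_scale; the code collapses them to a single scalar with torch.cat(...).reshape(-1).max(), and the new assert pins down the all-or-none invariant. A runnable sketch of just that merging step, with made-up prefixes and scale values:

# Runnable sketch of the scale-merging step above; the prefixes and
# scale values are made up for illustration.
import torch

prefixes = ["q_proj", "k_proj", "v_proj"]
input_scale = [torch.tensor([0.012]), torch.tensor([0.009]), torch.tensor([0.015])]

# All-or-none invariant: either every fused prefix has an input scale or none does.
assert len(input_scale) == 0 or len(input_scale) == len(prefixes)

# Taking the max is the conservative choice: the largest scale covers the
# widest activation range of any fused shard, at a small precision cost
# for the shards with smaller ranges.
input_scale = (
    torch.cat(input_scale, dim=0).reshape(-1).max()
    if len(input_scale) != 0
    else None
)
print(input_scale)  # tensor(0.0150)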
@@ -342,6 +342,7 @@ def get_model(
     model_type = config_dict.get("model_type", None)

     quantization_config = config_dict.get("quantization_config", None)
+    compression_config = config_dict.get("compression_config", None)
     if quantization_config is not None and quantize is None:
         method = quantization_config.get("quant_method", None)
         config_groups = quantization_config.get("config_groups", None)
@@ -352,7 +353,6 @@ def get_model(
             log_master(logger.info, "Auto selecting quantization method fp8")
             quantize = "fp8"
         elif config_groups is not None:
-            # Compression config renamed to quantization_config
            # TODO: at some point we should probably fully parse the compression
            # configuration to know which parameters are compressed.
            for _, group in config_groups.items():
@@ -369,6 +369,22 @@ def get_model(
                         break
         else:
             log_master(logger.warning, f"Unknown quantization method {method}")
+    elif compression_config is not None:
+        # For backward compatibility: compression_config is renamed to quantization_config
+        config_groups = compression_config.get("config_groups")
+        if config_groups is not None:
+            for _, group in config_groups.items():
+                weights_config = group.get("weights")
+                if weights_config is not None:
+                    if (
+                        weights_config["type"] == "float"
+                        and weights_config["num_bits"] == 8
+                    ):
+                        log_master(
+                            logger.info, "Auto selecting quantization method fp8"
+                        )
+                        quantize = "fp8"
+                        break

     if dtype is None:
         if quantize in ["awq", "exl2", "gptq", "marlin"]:
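Taken together, the get_model hunks auto-select fp8 when either quantization_config or the legacy compression_config (the pre-rename key used by older compressed-tensors checkpoints) declares a config group whose weights are 8-bit floats. A standalone sketch of just that config-group walk (it omits the quant_method shortcut shown earlier); detect_fp8 and the example configs are hypothetical:

# Standalone sketch of the auto-detection walk above; detect_fp8 is a
# hypothetical helper name and the config dicts are illustrative examples.
from typing import Optional

def detect_fp8(config_dict: dict) -> Optional[str]:
    # Prefer the current key; fall back to the legacy compressed-tensors name.
    quant_cfg = config_dict.get("quantization_config") or config_dict.get(
        "compression_config"
    )
    if quant_cfg is None:
        return None
    for group in (quant_cfg.get("config_groups") or {}).values():
        weights_config = group.get("weights")
        if (
            weights_config is not None
            and weights_config["type"] == "float"
            and weights_config["num_bits"] == 8
        ):
            return "fp8"
    return None

# Current checkpoints:
print(detect_fp8({"quantization_config": {"config_groups": {
    "group_0": {"weights": {"type": "float", "num_bits": 8}}}}}))  # fp8
# Older checkpoints that still use the pre-rename key:
print(detect_fp8({"compression_config": {"config_groups": {
    "group_0": {"weights": {"type": "float", "num_bits": 8}}}}}))  # fp8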