Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-12 04:44:52 +00:00)
(review_comments) fix typo and added comments
This commit is contained in:
parent  b2b5024ec8
commit  1de96279e3
@@ -165,7 +165,7 @@ class HybridFP8UnquantLoader(WeightsLoader):
             scale = scale.reshape(-1).expand(w.shape[0])

             input_scale = None
-            if weights.get_tensor(f"{prefix}.input_scale"):
+            if weights.has_tensor(f"{prefix}.input_scale"):
                 input_scale = weights.get_tensor(
                     f"{prefix}.input_scale", to_dtype=False
                 )
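
Note on the typo fix above: weights.get_tensor(...) actually loads the tensor (and fails when the key is missing), while weights.has_tensor(...) only checks for its presence, which is what this branch needs. A minimal sketch of the intended pattern, using a hypothetical stand-in for the weights container (names and signatures assumed from the diff, not the real loader):

import torch

class Weights:
    # Hypothetical stand-in for the loader's weights container, for illustration only.
    def __init__(self, tensors):
        self._tensors = tensors

    def has_tensor(self, name):
        # Cheap existence check; nothing is loaded.
        return name in self._tensors

    def get_tensor(self, name, to_dtype=True):
        # Loading a missing key is an error, which is why presence has to be
        # checked with has_tensor() first rather than with get_tensor() itself.
        return self._tensors[name]

weights = Weights({"model.layers.0.q_proj.input_scale": torch.tensor([0.5])})
prefix = "model.layers.0.q_proj"

input_scale = None
if weights.has_tensor(f"{prefix}.input_scale"):
    input_scale = weights.get_tensor(f"{prefix}.input_scale", to_dtype=False)
print(input_scale)  # tensor([0.5000])
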
@@ -213,6 +213,7 @@ class HybridFP8UnquantLoader(WeightsLoader):
                 for p, shape in zip(prefixes, shapes)
                 if weights.has_tensor(f"{p}.input_scale")
             ]
+            assert len(input_scale) == 0 or len(input_scale) == len(prefixes)
             input_scale = (
                 torch.cat(input_scale, dim=0).reshape(-1).max()
                 if len(input_scale) != 0
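
Note on the added assert above: when columns are fused, each prefix may carry its own input_scale; the loader collects the ones that exist and collapses them to a single scalar by taking the maximum. The assertion guards the only two valid states, either no prefix ships a scale or every prefix does. A toy illustration with invented values (not the loader itself):

import torch

prefixes = ["q_proj", "k_proj", "v_proj"]
# Per-prefix input scales; the numbers are made up for the example.
input_scale = [torch.tensor([0.021]), torch.tensor([0.017]), torch.tensor([0.025])]

# Either no prefix has an input_scale or all of them do.
assert len(input_scale) == 0 or len(input_scale) == len(prefixes)

# Collapse to one conservative (largest) scalar scale for the fused weight.
input_scale = (
    torch.cat(input_scale, dim=0).reshape(-1).max()
    if len(input_scale) != 0
    else None
)
print(input_scale)  # tensor(0.0250)
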
@@ -342,6 +342,7 @@ def get_model(
     model_type = config_dict.get("model_type", None)

     quantization_config = config_dict.get("quantization_config", None)
+    compression_config = config_dict.get("compression_config", None)
     if quantization_config is not None and quantize is None:
         method = quantization_config.get("quant_method", None)
         config_groups = quantization_config.get("config_groups", None)
@@ -352,7 +353,6 @@ def get_model(
             log_master(logger.info, "Auto selecting quantization method fp8")
             quantize = "fp8"
         elif config_groups is not None:
-            # Compression config renamed to quantization_config
             # TODO: at some point we should probably fully parse the compression
             # configuration to know which parameters are compressed.
             for _, group in config_groups.items():
@@ -369,6 +369,22 @@ def get_model(
                         break
         else:
             log_master(logger.warning, f"Unknown quantization method {method}")
+    elif compression_config is not None:
+        # For backward compatibility: compression_config is renamed to quantization_config
+        config_groups = compression_config.get("config_groups")
+        if config_groups is not None:
+            for _, group in config_groups.items():
+                weights_config = group.get("weights")
+                if weights_config is not None:
+                    if (
+                        weights_config["type"] == "float"
+                        and weights_config["num_bits"] == 8
+                    ):
+                        log_master(
+                            logger.info, "Auto selecting quantization method fp8"
+                        )
+                        quantize = "fp8"
+                        break

     if dtype is None:
         if quantize in ["awq", "exl2", "gptq", "marlin"]:
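
Note on the new elif branch above: older compressed-tensors checkpoints store the same information under compression_config instead of quantization_config, so get_model now falls back to it and still auto-selects fp8 when a config group declares 8-bit float weights. A standalone sketch of that selection logic with an invented config (for illustration, not the function itself):

config_dict = {
    "model_type": "llama",
    "compression_config": {
        "config_groups": {
            "group_0": {"weights": {"type": "float", "num_bits": 8}},
        }
    },
}

quantize = None
quantization_config = config_dict.get("quantization_config", None)
compression_config = config_dict.get("compression_config", None)

if quantization_config is not None and quantize is None:
    pass  # normal path, handled by the earlier hunks
elif compression_config is not None:
    # For backward compatibility: compression_config is renamed to quantization_config
    config_groups = compression_config.get("config_groups")
    if config_groups is not None:
        for _, group in config_groups.items():
            weights_config = group.get("weights")
            if weights_config is not None:
                if weights_config["type"] == "float" and weights_config["num_bits"] == 8:
                    quantize = "fp8"
                    break

print(quantize)  # fp8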