fix tp quant skip

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
jiqing-feng 2025-02-24 17:27:14 +00:00
parent a332862510
commit bc4eb25d41

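The branches added below call an is_layer_skipped_quantization helper and read a modules_to_not_convert list from the loader; neither is defined in these hunks. A minimal sketch of what such a prefix-matching check could look like, assuming modules_to_not_convert holds module-name substrings (names and logic are illustrative, not this repository's actual implementation):

from typing import List, Optional

def is_layer_skipped_quantization(
    prefix: str, modules_to_not_convert: Optional[List[str]]
) -> bool:
    # True when the layer at `prefix` (e.g. "model.layers.0.mlp.gate_proj")
    # matches one of the module names that should stay unquantized.
    if not modules_to_not_convert:
        return False
    return any(module in prefix for module in modules_to_not_convert)
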
@@ -199,6 +199,11 @@ class GPTQWeightsLoader(WeightsLoader):
         prefix: str,
         block_sizes: Union[int, List[int]],
     ):
+        if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
+            w = weights.get_packed_sharded(
+                f"{prefix}.weight", dim=0, block_sizes=block_sizes
+            )
+            return UnquantizedWeight(w)
         try:
             qweight = weights.get_packed_sharded(
                 f"{prefix}.qweight", dim=1, block_sizes=block_sizes
@@ -250,6 +255,11 @@ class GPTQWeightsLoader(WeightsLoader):
         )
 
     def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
+        if self.is_layer_skipped_quantization(prefixes[0], self.modules_to_not_convert):
+            w = torch.cat(
+                [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes], dim=dim
+            )
+            return UnquantizedWeight(w)
         try:
             qweight = torch.cat(
                 [weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
@@ -328,6 +338,9 @@ class GPTQWeightsLoader(WeightsLoader):
             log_once(logger.warning, "Disabling exllama because desc_act=True")
             use_exllama = False
 
+        if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
+            w = weights.get_sharded(f"{prefix}.weight", dim=1)
+            return UnquantizedWeight(w)
         try:
             qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
         except RuntimeError:
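
The dim values differ between the skip path and the quantized path because a plain .weight tensor is laid out [out_features, in_features], while a 4-bit GPTQ .qweight is packed [in_features // 8, out_features] (eight 4-bit values per int32). Column-parallel sharding therefore slices .weight along dim 0 but .qweight along dim 1, and row-parallel sharding slices .weight along dim 1 but .qweight along dim 0. A shapes-only sketch of that correspondence (tensor sizes and names are illustrative):

import torch

# Plain weight: [out_features, in_features]
weight = torch.empty(4096, 11008)
# 4-bit GPTQ qweight: [in_features // 8, out_features], 8 nibbles packed per int32
qweight = torch.empty(11008 // 8, 4096, dtype=torch.int32)

tp = 2  # tensor-parallel world size

# Column-parallel (split the output dim): weight dim 0 <-> qweight dim 1
w_col = weight.chunk(tp, dim=0)[0]
q_col = qweight.chunk(tp, dim=1)[0]

# Row-parallel (split the input dim): weight dim 1 <-> qweight dim 0
w_row = weight.chunk(tp, dim=1)[0]
q_row = qweight.chunk(tp, dim=0)[0]

print(w_col.shape, q_col.shape)  # torch.Size([2048, 11008]) torch.Size([1376, 2048])
print(w_row.shape, q_row.shape)  # torch.Size([4096, 5504]) torch.Size([688, 4096])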