fix tp quant skip

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
jiqing-feng 2025-02-24 17:27:14 +00:00
parent a332862510
commit bc4eb25d41


@@ -199,6 +199,11 @@ class GPTQWeightsLoader(WeightsLoader):
         prefix: str,
         block_sizes: Union[int, List[int]],
     ):
+        if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
+            w = weights.get_packed_sharded(
+                f"{prefix}.weight", dim=0, block_sizes=block_sizes
+            )
+            return UnquantizedWeight(w)
         try:
             qweight = weights.get_packed_sharded(
                 f"{prefix}.qweight", dim=1, block_sizes=block_sizes
@@ -250,6 +255,11 @@ class GPTQWeightsLoader(WeightsLoader):
         )

     def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
+        if self.is_layer_skipped_quantization(prefixes[0], self.modules_to_not_convert):
+            w = torch.cat(
+                [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes], dim=dim
+            )
+            return UnquantizedWeight(w)
         try:
             qweight = torch.cat(
                 [weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
@@ -328,6 +338,9 @@ class GPTQWeightsLoader(WeightsLoader):
             log_once(logger.warning, "Disabling exllama because desc_act=True")
             use_exllama = False

+        if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
+            w = weights.get_sharded(f"{prefix}.weight", dim=1)
+            return UnquantizedWeight(w)
         try:
             qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
         except RuntimeError:
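
Note: the skip path added above relies on an is_layer_skipped_quantization helper and a modules_to_not_convert attribute on the loader, neither of which is defined in this diff. The sketch below is only an assumption of how such a check could look (a substring match of the layer prefix against the excluded module names, written as a standalone function rather than a method), not the upstream implementation.

from typing import List, Optional

def is_layer_skipped_quantization(
    prefix: str, modules_to_not_convert: Optional[List[str]]
) -> bool:
    # Hypothetical sketch: treat a layer as unquantized when its prefix
    # mentions any module name listed in modules_to_not_convert.
    if not modules_to_not_convert:
        return False
    return any(module in prefix for module in modules_to_not_convert)

With a check of this shape, a prefix whose name matches an entry in modules_to_not_convert is loaded from the plain f"{prefix}.weight" tensor and returned as an UnquantizedWeight, while every other layer continues through the GPTQ qweight/qzeros/scales path.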