mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-22 15:32:08 +00:00
fix tp quant skip
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
This commit is contained in:
parent
a332862510
commit
bc4eb25d41
@ -199,6 +199,11 @@ class GPTQWeightsLoader(WeightsLoader):
|
||||
prefix: str,
|
||||
block_sizes: Union[int, List[int]],
|
||||
):
|
||||
if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
|
||||
w = weights.get_packed_sharded(
|
||||
f"{prefix}.weight", dim=0, block_sizes=block_sizes
|
||||
)
|
||||
return UnquantizedWeight(w)
|
||||
try:
|
||||
qweight = weights.get_packed_sharded(
|
||||
f"{prefix}.qweight", dim=1, block_sizes=block_sizes
|
||||
@ -250,6 +255,11 @@ class GPTQWeightsLoader(WeightsLoader):
|
||||
)
|
||||
|
||||
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
|
||||
if self.is_layer_skipped_quantization(prefixes[0], self.modules_to_not_convert):
|
||||
w = torch.cat(
|
||||
[weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes], dim=dim
|
||||
)
|
||||
return UnquantizedWeight(w)
|
||||
try:
|
||||
qweight = torch.cat(
|
||||
[weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
|
||||
@ -328,6 +338,9 @@ class GPTQWeightsLoader(WeightsLoader):
|
||||
log_once(logger.warning, "Disabling exllama because desc_act=True")
|
||||
use_exllama = False
|
||||
|
||||
if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
|
||||
w = weights.get_sharded(f"{prefix}.weight", dim=1)
|
||||
return UnquantizedWeight(w)
|
||||
try:
|
||||
qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
|
||||
except RuntimeError:
|
||||
|
Loading…
Reference in New Issue
Block a user