Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-04-22 23:42:06 +00:00
fix tp quant skip
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
This commit is contained in:
parent a332862510
commit bc4eb25d41
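This change makes the three GPTQ weight-loading paths (packed column, multi-prefix column, and row) skip quantized loading for layers listed in modules_to_not_convert: when a prefix matches, the loader reads the plain .weight tensor and returns an UnquantizedWeight instead of looking for GPTQ's qweight tensor, which does not exist for unconverted layers. The title's "tp" refers to tensor-parallel sharding, where each rank loads its slice of these layers. The skip helper itself is not part of this diff, so the following minimal sketch of its behavior is an assumption:

from typing import List, Optional

def is_layer_skipped_quantization(
    prefix: str, modules_to_not_convert: Optional[List[str]]
) -> bool:
    # Assumed behavior: a layer is skipped when its prefix contains any
    # module name from the quantization config's modules_to_not_convert.
    if modules_to_not_convert is None:
        return False
    return any(module in prefix for module in modules_to_not_convert)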
@@ -199,6 +199,11 @@ class GPTQWeightsLoader(WeightsLoader):
         prefix: str,
         block_sizes: Union[int, List[int]],
     ):
+        if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
+            w = weights.get_packed_sharded(
+                f"{prefix}.weight", dim=0, block_sizes=block_sizes
+            )
+            return UnquantizedWeight(w)
         try:
             qweight = weights.get_packed_sharded(
                 f"{prefix}.qweight", dim=1, block_sizes=block_sizes
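The two branches shard along different dims because GPTQ stores qweight bit-packed and transposed relative to a plain linear weight. A toy illustration with made-up shapes:

import torch

out_features, in_features, pack_factor = 1024, 4096, 8  # 4-bit values packed into int32

weight = torch.randn(out_features, in_features)  # plain layout: (out, in), output dim is 0
qweight = torch.zeros(
    in_features // pack_factor, out_features, dtype=torch.int32
)  # packed GPTQ layout: (in / pack, out), output dim is 1

assert weight.shape[0] == qweight.shape[1] == out_features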
@@ -250,6 +255,11 @@ class GPTQWeightsLoader(WeightsLoader):
         )

     def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
+        if self.is_layer_skipped_quantization(prefixes[0], self.modules_to_not_convert):
+            w = torch.cat(
+                [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes], dim=dim
+            )
+            return UnquantizedWeight(w)
         try:
             qweight = torch.cat(
                 [weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
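In the unquantized fallback above, each prefix's output shard is fetched along dim 0 and the shards are concatenated into one fused column-parallel weight. A toy version with invented shapes:

import torch

# One (out_shard, in) shard per prefix, e.g. q_proj / k_proj / v_proj.
shards = [torch.randn(128, 512) for _ in range(3)]
w = torch.cat(shards, dim=0)  # fused weight for this tensor-parallel rank
assert w.shape == (3 * 128, 512)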
@@ -328,6 +338,9 @@ class GPTQWeightsLoader(WeightsLoader):
             log_once(logger.warning, "Disabling exllama because desc_act=True")
             use_exllama = False

+        if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
+            w = weights.get_sharded(f"{prefix}.weight", dim=1)
+            return UnquantizedWeight(w)
         try:
             qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
         except RuntimeError:
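For reference, all three skip paths wrap the raw tensor in UnquantizedWeight rather than a GPTQ weight type. Only the constructor call appears in this diff, so the sketch below of such a wrapper is an assumption:

from dataclasses import dataclass
import torch

@dataclass
class UnquantizedWeight:
    # Assumed shape of the wrapper: it carries the plain tensor so the rest
    # of the loader can build an ordinary (non-GPTQ) linear layer from it.
    weight: torch.Tensor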