From bc4eb25d415c6448d3bcbc46bf7c6655382f9dd9 Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Mon, 24 Feb 2025 17:27:14 +0000
Subject: [PATCH] fix tp quant skip

Signed-off-by: jiqing-feng
---
 .../text_generation_server/layers/gptq/__init__.py | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/server/text_generation_server/layers/gptq/__init__.py b/server/text_generation_server/layers/gptq/__init__.py
index d443f94a..b2371967 100644
--- a/server/text_generation_server/layers/gptq/__init__.py
+++ b/server/text_generation_server/layers/gptq/__init__.py
@@ -199,6 +199,11 @@ class GPTQWeightsLoader(WeightsLoader):
         prefix: str,
         block_sizes: Union[int, List[int]],
     ):
+        if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
+            w = weights.get_packed_sharded(
+                f"{prefix}.weight", dim=0, block_sizes=block_sizes
+            )
+            return UnquantizedWeight(w)
         try:
             qweight = weights.get_packed_sharded(
                 f"{prefix}.qweight", dim=1, block_sizes=block_sizes
@@ -250,6 +255,11 @@ class GPTQWeightsLoader(WeightsLoader):
         )
 
     def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
+        if self.is_layer_skipped_quantization(prefixes[0], self.modules_to_not_convert):
+            w = torch.cat(
+                [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes], dim=dim
+            )
+            return UnquantizedWeight(w)
         try:
             qweight = torch.cat(
                 [weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
@@ -328,6 +338,9 @@ class GPTQWeightsLoader(WeightsLoader):
             log_once(logger.warning, "Disabling exllama because desc_act=True")
             use_exllama = False
 
+        if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
+            w = weights.get_sharded(f"{prefix}.weight", dim=1)
+            return UnquantizedWeight(w)
         try:
             qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
         except RuntimeError:
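
Note: the added branches call self.is_layer_skipped_quantization() and read self.modules_to_not_convert, neither of which is defined in this hunk; both are assumed to exist elsewhere on GPTQWeightsLoader. Purely as an illustrative sketch (not the repository's actual implementation), such a helper could be a substring match of the weight prefix against the list of modules excluded from quantization:

    # Hypothetical sketch of the helper referenced above; the real definition
    # lives elsewhere in the loader and may differ.
    from typing import List, Optional

    def is_layer_skipped_quantization(
        prefix: str, modules_to_not_convert: Optional[List[str]]
    ) -> bool:
        # A layer is skipped when its weight prefix
        # (e.g. "model.layers.0.mlp.gate_proj") contains any of the names
        # listed in modules_to_not_convert.
        if not modules_to_not_convert:
            return False
        return any(module in prefix for module in modules_to_not_convert)

Under that assumption, a skipped layer takes the new branch and loads the plain f"{prefix}.weight" tensor as an UnquantizedWeight, while quantized layers fall through to the existing qweight path.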