diff --git a/backends/gaudi/server/text_generation_server/layers/gptq/__init__.py b/backends/gaudi/server/text_generation_server/layers/gptq/__init__.py index babf3d4b..96b120b2 100644 --- a/backends/gaudi/server/text_generation_server/layers/gptq/__init__.py +++ b/backends/gaudi/server/text_generation_server/layers/gptq/__init__.py @@ -276,6 +276,63 @@ class GPTQWeightsLoader(WeightsLoader): use_exllama=use_exllama, ) + def get_multi_weights(self, weights: Weights, prefixes: List[str], dim: int): + if self.is_layer_skipped_quantization(prefixes[0], self.modules_to_not_convert): + return DefaultWeightsLoader.get_multi_weights(weights, prefixes, dim) + try: + qweight = torch.cat( + [weights.get_tensor(f"{p}.qweight") for p in prefixes], dim=1 + ) + except RuntimeError: + raise RuntimeError( + f"Cannot load `{self.quantize}` weight, make sure the model is already quantized" + ) + + scales = torch.cat([weights.get_tensor(f"{p}.scales") for p in prefixes], dim=1) + + self._get_gptq_params(weights) + + qzeros = torch.cat([weights.get_tensor(f"{p}.qzeros") for p in prefixes], dim=1) + + use_exllama = self.bits == 4 and self.quantize == "gptq" and not self.desc_act + + if self.quantize == "gptq" and self.quant_method == "gptq": + w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes] + for w2 in w[1:]: + torch.testing.assert_close(w2, w[0]) + g_idx = w[0] + elif self.quantize == "gptq" and self.quant_method == "awq": + log_once( + logger.info, "Converting AWQ model to Exllama/GPTQ packing format." + ) + from text_generation_server.layers.awq.conversion_utils import ( + fast_awq_to_gptq, + ) + + qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) + if use_exllama: + g_idx = None + else: + g_idx = ( + torch.arange( + qweight.shape[0] * (32 // self.bits), + device=qweight.device, + ) + ).to(dtype=torch.int32) + else: + g_idx = None + + return GPTQWeight( + qweight=qweight, + qzeros=qzeros, + scales=scales, + g_idx=g_idx, + bits=self.bits, + groupsize=self.groupsize, + use_awq_kernel=self.quantize == "awq", + use_exllama=use_exllama, + ) + def get_weights_row(self, weights: Weights, prefix: str): self._get_gptq_params(weights)