From 054930fbbe54f2e44b3fe6d924729b4df4baa3b7 Mon Sep 17 00:00:00 2001
From: Abhinav Kulkarni
Date: Sat, 23 Sep 2023 10:12:26 +0000
Subject: [PATCH] Minor refactor

---
 server/text_generation_server/utils/layers.py |  2 +-
 .../text_generation_server/utils/weights.py   | 79 ++++++++-----------
 2 files changed, 33 insertions(+), 48 deletions(-)

diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 3fb3766a..cfec5859 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -251,7 +251,7 @@ def get_linear(weight, bias, quantize):
             )
     elif quantize == "awq":
         try:
-            qweight, qzeros, scales, bits, groupsize = weight
+            qweight, qzeros, scales, _, bits, groupsize, _ = weight
         except Exception:
             raise NotImplementedError(
                 f"The passed weight is not `awq` compatible, loader needs to be updated."
diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index c5562a4f..fdeabbe6 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -135,34 +135,29 @@ class Weights:
         Highly specific when the underlying tensor is a simple cat of Q,K,V instead of being
         already alternating Q,K,V within the main tensor
         """
-        if quantize == "gptq":
+        if quantize in ["gptq", "awq"]:
             try:
                 qweight = self._get_qweight(f"{prefix}.qweight")
             except RuntimeError:
-                raise RuntimeError(
-                    "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
-                )
+                if quantize == "gptq":
+                    raise RuntimeError(
+                        "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
+                    )
+                else:
+                    raise RuntimeError(
+                        "Cannot load `awq` weight, make sure the model is already quantized"
+                    )
 
             qzeros = self._get_qweight(f"{prefix}.qzeros")
             scales = self._get_qweight(f"{prefix}.scales")
             scales = scales.to(dtype=self.dtype)
-            g_idx = self.get_tensor(f"{prefix}.g_idx")
+            try:
+                g_idx = self.get_tensor(f"{prefix}.g_idx")
+            except RuntimeError:
+                g_idx = None
 
             bits, groupsize = self._get_gptq_params()
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
-        if quantize == "awq":
-            try:
-                qweight = self._get_qweight(f"{prefix}.qweight")
-            except RuntimeError:
-                raise RuntimeError(
-                    "Cannot load `awq` weight, make sure the model is already quantized"
-                )
-            qzeros = self._get_qweight(f"{prefix}.qzeros")
-            scales = self._get_qweight(f"{prefix}.scales")
-            scales = scales.to(dtype=self.dtype)
-
-            bits, groupsize = self._get_gptq_params()
-            weight = (qweight, qzeros, scales, bits, groupsize)
         else:
             slice_ = self._get_slice(f"{prefix}.weight")
             total_size = slice_.get_shape()[0]
@@ -184,15 +179,20 @@ class Weights:
         return weight
 
     def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
-        if quantize == "gptq":
+        if quantize in ["gptq", "awq"]:
             try:
                 qweight = torch.cat(
                     [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
                 )
             except RuntimeError:
-                raise RuntimeError(
-                    "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
-                )
+                if quantize == "gptq":
+                    raise RuntimeError(
+                        "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
+                    )
+                else:
+                    raise RuntimeError(
+                        "Cannot load `awq` weight, make sure the model is already quantized"
+                    )
 
             qzeros = torch.cat(
                 [self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
@@ -200,32 +200,17 @@ class Weights:
             scales = torch.cat(
                 [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
             )
-            w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
-            for w2 in w[1:]:
-                torch.testing.assert_close(w2, w[0])
-            g_idx = w[0]
+
+            try:
+                w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
+                for w2 in w[1:]:
+                    torch.testing.assert_close(w2, w[0])
+                g_idx = w[0]
+            except RuntimeError:
+                g_idx = None
 
             bits, groupsize = self._get_gptq_params()
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
-        elif quantize == "awq":
-            try:
-                qweight = torch.cat(
-                    [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
-                )
-            except RuntimeError:
-                raise RuntimeError(
-                    "Cannot load `awq` weight, make sure the model is already quantized"
-                )
-
-            qzeros = torch.cat(
-                [self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
-            )
-            scales = torch.cat(
-                [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
-            )
-
-            bits, groupsize = self._get_gptq_params()
-            weight = (qweight, qzeros, scales, bits, groupsize)
         else:
             w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
             weight = torch.cat(w, dim=dim)
@@ -248,7 +233,7 @@ class Weights:
         return tensor
 
     def get_multi_weights_row(self, prefix: str, quantize: str):
-        if quantize == "gptq":
+        if quantize in "gptq":
            use_exllama = True
            bits, groupsize = self._get_gptq_params()
 
@@ -327,7 +312,7 @@ class Weights:
 
             qzeros = self.get_tensor(f"{prefix}.qzeros")
             scales = self.get_tensor(f"{prefix}.scales")
-            weight = (qweight, qzeros, scales, bits, groupsize)
+            weight = (qweight, qzeros, scales, None, bits, groupsize, None)
         else:
             weight = self.get_sharded(f"{prefix}.weight", dim=1)
         return weight
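Note on the resulting weight format: with this refactor both `gptq` and `awq` checkpoints are loaded into the same 7-element tuple (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama), with g_idx left as None and the exllama flag as False/None when the checkpoint does not carry them, which is what the updated unpacking in get_linear expects. Below is a minimal sketch of consuming that tuple; the helper name describe_quantized_weight and the toy tensors are illustrative assumptions, not code from this repository.

import torch


def describe_quantized_weight(weight, quantize):
    # Hypothetical helper, for illustration only: unpack the unified tuple
    # returned by Weights.get_weights_col_packed_qkv / get_multi_weights_col
    # after this patch.
    if quantize in ("gptq", "awq"):
        qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama = weight
        return {
            "qweight_shape": tuple(qweight.shape),
            "qzeros_shape": tuple(qzeros.shape),
            "scales_shape": tuple(scales.shape),
            "has_g_idx": g_idx is not None,      # False for AWQ checkpoints
            "bits": bits,
            "groupsize": groupsize,
            "use_exllama": bool(use_exllama),    # AWQ paths pass False/None
        }
    # Unquantized path: `weight` is an ordinary tensor.
    return {"weight_shape": tuple(weight.shape)}


if __name__ == "__main__":
    # Fabricated toy tensors standing in for a 4-bit, groupsize-128 AWQ shard.
    fake_awq = (
        torch.zeros(128, 16, dtype=torch.int32),   # qweight
        torch.zeros(1, 16, dtype=torch.int32),     # qzeros
        torch.ones(1, 128, dtype=torch.float16),   # scales
        None,                                      # g_idx (absent for AWQ)
        4,                                         # bits
        128,                                       # groupsize
        False,                                     # use_exllama
    )
    print(describe_quantized_weight(fake_awq, "awq"))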