Minor refactor

This commit is contained in:
Abhinav Kulkarni 2023-09-23 10:12:26 +00:00
parent 5d0973f484
commit 054930fbbe
2 changed files with 33 additions and 48 deletions

View File

@ -251,7 +251,7 @@ def get_linear(weight, bias, quantize):
)
elif quantize == "awq":
try:
qweight, qzeros, scales, bits, groupsize = weight
qweight, qzeros, scales, _, bits, groupsize, _ = weight
except Exception:
raise NotImplementedError(
f"The passed weight is not `awq` compatible, loader needs to be updated."

View File

@ -135,34 +135,29 @@ class Weights:
Highly specific when the underlying tensor is a simple cat of Q,K,V instead of being
already alternating Q,K,V within the main tensor
"""
if quantize == "gptq":
if quantize in ["gptq", "awq"]:
try:
qweight = self._get_qweight(f"{prefix}.qweight")
except RuntimeError:
raise RuntimeError(
"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
)
if quantize == "gptq":
raise RuntimeError(
"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
)
else:
raise RuntimeError(
"Cannot load `awq` weight, make sure the model is already quantized"
)
qzeros = self._get_qweight(f"{prefix}.qzeros")
scales = self._get_qweight(f"{prefix}.scales")
scales = scales.to(dtype=self.dtype)
g_idx = self.get_tensor(f"{prefix}.g_idx")
try:
g_idx = self.get_tensor(f"{prefix}.g_idx")
except RuntimeError:
g_idx = None
bits, groupsize = self._get_gptq_params()
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
if quantize == "awq":
try:
qweight = self._get_qweight(f"{prefix}.qweight")
except RuntimeError:
raise RuntimeError(
"Cannot load `awq` weight, make sure the model is already quantized"
)
qzeros = self._get_qweight(f"{prefix}.qzeros")
scales = self._get_qweight(f"{prefix}.scales")
scales = scales.to(dtype=self.dtype)
bits, groupsize = self._get_gptq_params()
weight = (qweight, qzeros, scales, bits, groupsize)
else:
slice_ = self._get_slice(f"{prefix}.weight")
total_size = slice_.get_shape()[0]
@ -184,15 +179,20 @@ class Weights:
return weight
def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
if quantize == "gptq":
if quantize in ["gptq", "awq"]:
try:
qweight = torch.cat(
[self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
)
except RuntimeError:
raise RuntimeError(
"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
)
if quantize == "gptq":
raise RuntimeError(
"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
)
else:
raise RuntimeError(
"Cannot load `awq` weight, make sure the model is already quantized"
)
qzeros = torch.cat(
[self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
@ -200,32 +200,17 @@ class Weights:
scales = torch.cat(
[self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
)
w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
for w2 in w[1:]:
torch.testing.assert_close(w2, w[0])
g_idx = w[0]
try:
w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
for w2 in w[1:]:
torch.testing.assert_close(w2, w[0])
g_idx = w[0]
except RuntimeError:
g_idx = None
bits, groupsize = self._get_gptq_params()
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
elif quantize == "awq":
try:
qweight = torch.cat(
[self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
)
except RuntimeError:
raise RuntimeError(
"Cannot load `awq` weight, make sure the model is already quantized"
)
qzeros = torch.cat(
[self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
)
scales = torch.cat(
[self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
)
bits, groupsize = self._get_gptq_params()
weight = (qweight, qzeros, scales, bits, groupsize)
else:
w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
weight = torch.cat(w, dim=dim)
@ -248,7 +233,7 @@ class Weights:
return tensor
def get_multi_weights_row(self, prefix: str, quantize: str):
if quantize == "gptq":
if quantize in "gptq":
use_exllama = True
bits, groupsize = self._get_gptq_params()
@ -327,7 +312,7 @@ class Weights:
qzeros = self.get_tensor(f"{prefix}.qzeros")
scales = self.get_tensor(f"{prefix}.scales")
weight = (qweight, qzeros, scales, bits, groupsize)
weight = (qweight, qzeros, scales, None, bits, groupsize, None)
else:
weight = self.get_sharded(f"{prefix}.weight", dim=1)
return weight