Mirror of https://github.com/huggingface/text-generation-inference.git

commit 054930fbbe
parent 5d0973f484

    Minor refactor
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -251,7 +251,7 @@ def get_linear(weight, bias, quantize):
         )
     elif quantize == "awq":
         try:
-            qweight, qzeros, scales, bits, groupsize = weight
+            qweight, qzeros, scales, _, bits, groupsize, _ = weight
         except Exception:
             raise NotImplementedError(
                 f"The passed weight is not `awq` compatible, loader needs to be updated."
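The hunk above is the consumer side of this refactor: the gptq and awq loaders now hand get_linear one common 7-slot tuple, and the awq branch simply discards the two slots it has no use for. A minimal sketch of that contract follows; the QuantWeight name is hypothetical, introduced here only for illustration (TGI itself passes a plain tuple):

import torch
from typing import NamedTuple, Optional


class QuantWeight(NamedTuple):
    """Hypothetical name for the 7-slot tuple both quantizers now produce."""

    qweight: torch.Tensor          # packed integer weights
    qzeros: torch.Tensor           # packed zero points
    scales: torch.Tensor           # per-group dequantization scales
    g_idx: Optional[torch.Tensor]  # gptq act-order permutation, None for awq
    bits: int
    groupsize: int
    use_exllama: Optional[bool]    # exllama kernel decision, None for awq


def unpack_for_awq(weight):
    # Same pattern as the new get_linear line: `_` drops g_idx and use_exllama,
    # which awq kernels never consume.
    qweight, qzeros, scales, _, bits, groupsize, _ = weight
    return qweight, qzeros, scales, bits, groupsize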
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -135,34 +135,29 @@ class Weights:
         Highly specific when the underlying tensor is a simple cat of Q,K,V instead of being
         already alternating Q,K,V within the main tensor
         """
-        if quantize == "gptq":
+        if quantize in ["gptq", "awq"]:
             try:
                 qweight = self._get_qweight(f"{prefix}.qweight")
             except RuntimeError:
-                raise RuntimeError(
-                    "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
-                )
+                if quantize == "gptq":
+                    raise RuntimeError(
+                        "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
+                    )
+                else:
+                    raise RuntimeError(
+                        "Cannot load `awq` weight, make sure the model is already quantized"
+                    )

             qzeros = self._get_qweight(f"{prefix}.qzeros")
             scales = self._get_qweight(f"{prefix}.scales")
             scales = scales.to(dtype=self.dtype)
-            g_idx = self.get_tensor(f"{prefix}.g_idx")
+            try:
+                g_idx = self.get_tensor(f"{prefix}.g_idx")
+            except RuntimeError:
+                g_idx = None

             bits, groupsize = self._get_gptq_params()
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
-        if quantize == "awq":
-            try:
-                qweight = self._get_qweight(f"{prefix}.qweight")
-            except RuntimeError:
-                raise RuntimeError(
-                    "Cannot load `awq` weight, make sure the model is already quantized"
-                )
-            qzeros = self._get_qweight(f"{prefix}.qzeros")
-            scales = self._get_qweight(f"{prefix}.scales")
-            scales = scales.to(dtype=self.dtype)
-
-            bits, groupsize = self._get_gptq_params()
-            weight = (qweight, qzeros, scales, bits, groupsize)
         else:
             slice_ = self._get_slice(f"{prefix}.weight")
             total_size = slice_.get_shape()[0]
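The one real behavioral change in this hunk is the guarded g_idx fetch: awq checkpoints carry no act-order tensor, so the lookup that used to be unconditional now degrades to None instead of raising. A small self-contained sketch of that fallback, with a hypothetical FakeLoader standing in for the safetensors-backed Weights object (whose get_tensor raises RuntimeError on a missing key):

import torch


class FakeLoader:
    """Hypothetical stand-in for Weights: raises RuntimeError on missing keys."""

    def __init__(self, tensors):
        self._tensors = tensors

    def get_tensor(self, name):
        if name not in self._tensors:
            raise RuntimeError(f"weight {name} does not exist")
        return self._tensors[name]


def load_g_idx(loader, prefix):
    # Same fallback as the diff: gptq checkpoints have {prefix}.g_idx, awq
    # checkpoints do not, and a missing tensor now means "no act-order".
    try:
        return loader.get_tensor(f"{prefix}.g_idx")
    except RuntimeError:
        return None


gptq_like = FakeLoader({"attn.g_idx": torch.arange(8)})
awq_like = FakeLoader({})
assert load_g_idx(gptq_like, "attn") is not None
assert load_g_idx(awq_like, "attn") is None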
@@ -184,15 +179,20 @@ class Weights:
         return weight

     def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
-        if quantize == "gptq":
+        if quantize in ["gptq", "awq"]:
             try:
                 qweight = torch.cat(
                     [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
                 )
             except RuntimeError:
-                raise RuntimeError(
-                    "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
-                )
+                if quantize == "gptq":
+                    raise RuntimeError(
+                        "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
+                    )
+                else:
+                    raise RuntimeError(
+                        "Cannot load `awq` weight, make sure the model is already quantized"
+                    )

             qzeros = torch.cat(
                 [self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
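Note the dim=1 everywhere in the quantized branch, while the unquantized else-branch (end of the next hunk) shards along dim=0. In the usual gptq/awq packing, qweight is stored as [packed_in_features, out_features], so fusing q/k/v column blocks means stacking output features along dim 1; a dense nn.Linear weight is [out_features, in_features], hence dim 0. A toy illustration, assuming that layout:

import torch

packed_in, out_q, out_k = 8, 4, 4

# Quantized layout (as assumed above): outputs live on dim 1.
qw_q = torch.zeros(packed_in, out_q, dtype=torch.int32)
qw_k = torch.ones(packed_in, out_k, dtype=torch.int32)
fused = torch.cat([qw_q, qw_k], dim=1)
assert fused.shape == (packed_in, out_q + out_k)

# Dense layout: outputs live on dim 0, so the else-branch shards/concats there.
w_q = torch.zeros(out_q, packed_in)
w_k = torch.ones(out_k, packed_in)
fused_dense = torch.cat([w_q, w_k], dim=0)
assert fused_dense.shape == (out_q + out_k, packed_in)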
@@ -200,32 +200,17 @@ class Weights:
             scales = torch.cat(
                 [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
             )
-            w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
-            for w2 in w[1:]:
-                torch.testing.assert_close(w2, w[0])
-            g_idx = w[0]
+
+            try:
+                w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
+                for w2 in w[1:]:
+                    torch.testing.assert_close(w2, w[0])
+                g_idx = w[0]
+            except RuntimeError:
+                g_idx = None

             bits, groupsize = self._get_gptq_params()
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
-        elif quantize == "awq":
-            try:
-                qweight = torch.cat(
-                    [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
-                )
-            except RuntimeError:
-                raise RuntimeError(
-                    "Cannot load `awq` weight, make sure the model is already quantized"
-                )
-
-            qzeros = torch.cat(
-                [self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
-            )
-            scales = torch.cat(
-                [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
-            )
-
-            bits, groupsize = self._get_gptq_params()
-            weight = (qweight, qzeros, scales, bits, groupsize)
         else:
             w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
             weight = torch.cat(w, dim=dim)
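Here the try also wraps the cross-prefix consistency check: for gptq, every fused projection must share one act-order permutation, or a single fused matmul would be wrong. The except only rescues the awq case, where the very first get_tensor raises. A sketch of the block in isolation (exercisable with the FakeLoader above), highlighting a subtlety of the exception types:

import torch


def shared_g_idx(loader, prefixes):
    # Only the *missing tensor* case is converted to None here:
    # torch.testing.assert_close raises AssertionError on a real mismatch,
    # which `except RuntimeError` does not catch, so inconsistent gptq
    # act-orders across fused projections still fail loudly.
    try:
        w = [loader.get_tensor(f"{p}.g_idx") for p in prefixes]
        for w2 in w[1:]:
            torch.testing.assert_close(w2, w[0])
        return w[0]
    except RuntimeError:
        return None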
@@ -248,7 +233,7 @@ class Weights:
         return tensor

     def get_multi_weights_row(self, prefix: str, quantize: str):
-        if quantize == "gptq":
+        if quantize in "gptq":
             use_exllama = True
             bits, groupsize = self._get_gptq_params()

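One subtlety worth flagging in this hunk: `quantize in "gptq"` is a substring test against the string "gptq", not membership in a list. For the values that reach this method it happens to behave like the old equality check, since "awq" is not a substring of "gptq", but it is looser than it looks:

assert "gptq" in "gptq"     # the intended match
assert "awq" not in "gptq"  # awq correctly falls through to its own branch
assert "q" in "gptq"        # but any substring, e.g. "q" or "gpt", matches too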
@@ -327,7 +312,7 @@ class Weights:
             qzeros = self.get_tensor(f"{prefix}.qzeros")
             scales = self.get_tensor(f"{prefix}.scales")

-            weight = (qweight, qzeros, scales, bits, groupsize)
+            weight = (qweight, qzeros, scales, None, bits, groupsize, None)
         else:
             weight = self.get_sharded(f"{prefix}.weight", dim=1)
         return weight
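This last hunk closes the loop with the get_linear change at the top: the row-parallel awq path now emits the same 7-slot tuple as the column paths, padding with None where awq has no g_idx and no exllama decision, so consumers can always unpack seven fields. A trivial demonstration (placeholder values only; real entries are tensors):

weight = ("qweight", "qzeros", "scales", None, 4, 128, None)
qweight, qzeros, scales, _, bits, groupsize, _ = weight
assert (bits, groupsize) == (4, 128)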