Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 11:24:53 +00:00)
add multi-weight for GPTQ weight loader
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in: parent ce8978f9ea · commit 475f6e21bc
@@ -276,6 +276,63 @@ class GPTQWeightsLoader(WeightsLoader):
             use_exllama=use_exllama,
         )
 
+    def get_multi_weights(self, weights: Weights, prefixes: List[str], dim: int):
+        if self.is_layer_skipped_quantization(prefixes[0], self.modules_to_not_convert):
+            return DefaultWeightsLoader.get_multi_weights(weights, prefixes, dim)
+        try:
+            qweight = torch.cat(
+                [weights.get_tensor(f"{p}.qweight") for p in prefixes], dim=1
+            )
+        except RuntimeError:
+            raise RuntimeError(
+                f"Cannot load `{self.quantize}` weight, make sure the model is already quantized"
+            )
+
+        scales = torch.cat([weights.get_tensor(f"{p}.scales") for p in prefixes], dim=1)
+
+        self._get_gptq_params(weights)
+
+        qzeros = torch.cat([weights.get_tensor(f"{p}.qzeros") for p in prefixes], dim=1)
+
+        use_exllama = self.bits == 4 and self.quantize == "gptq" and not self.desc_act
+
+        if self.quantize == "gptq" and self.quant_method == "gptq":
+            w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
+            for w2 in w[1:]:
+                torch.testing.assert_close(w2, w[0])
+            g_idx = w[0]
+        elif self.quantize == "gptq" and self.quant_method == "awq":
+            log_once(
+                logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
+            )
+            from text_generation_server.layers.awq.conversion_utils import (
+                fast_awq_to_gptq,
+            )
+
+            qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
+            if use_exllama:
+                g_idx = None
+            else:
+                g_idx = (
+                    torch.arange(
+                        qweight.shape[0] * (32 // self.bits),
+                        device=qweight.device,
+                    )
+                ).to(dtype=torch.int32)
+        else:
+            g_idx = None
+
+        return GPTQWeight(
+            qweight=qweight,
+            qzeros=qzeros,
+            scales=scales,
+            g_idx=g_idx,
+            bits=self.bits,
+            groupsize=self.groupsize,
+            use_awq_kernel=self.quantize == "awq",
+            use_exllama=use_exllama,
+        )
+
     def get_weights_row(self, weights: Weights, prefix: str):
         self._get_gptq_params(weights)
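For context on how the new method is meant to be used, here is a minimal usage sketch, assuming Llama-style layer prefixes and a `loader`/`weights` pair standing in for the `GPTQWeightsLoader` and `Weights` instances that text-generation-inference builds at model-load time; the prefix names, function name, and variables are illustrative, not part of this commit:

from typing import List

# Hypothetical usage sketch (not from this commit). `loader` would be a
# GPTQWeightsLoader and `weights` a text_generation_server Weights handle;
# the Llama-style prefixes are assumptions for illustration.
def load_fused_qkv(loader, weights, layer: int):
    prefixes: List[str] = [
        f"model.layers.{layer}.self_attn.q_proj",
        f"model.layers.{layer}.self_attn.k_proj",
        f"model.layers.{layer}.self_attn.v_proj",
    ]
    # qweight/qzeros/scales are concatenated along dim=1 (the packed
    # output-feature axis), so the three projections come back as one
    # fused GPTQWeight suitable for a column-parallel linear layer.
    return loader.get_multi_weights(weights, prefixes, dim=1)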
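A note on the `qweight.shape[0] * (32 // self.bits)` expression in the fallback `g_idx` branch: GPTQ packs `32 // bits` quantized values into each int32 row of `qweight`, so multiplying the packed row count back out recovers the unpacked input dimension. A standalone sketch under assumed shapes:

import torch

# GPTQ packs 32 // bits quantized values into each int32 along the input
# dimension, so a packed qweight with R rows spans R * (32 // bits)
# unpacked input features. Shapes here are illustrative assumptions.
bits = 4
in_features, out_features = 4096, 4096

packed_rows = in_features // (32 // bits)  # 4096 // 8 = 512 int32 rows
qweight = torch.zeros((packed_rows, out_features), dtype=torch.int32)

# The fallback branch builds a g_idx entry per unpacked input feature:
g_idx = torch.arange(qweight.shape[0] * (32 // bits)).to(dtype=torch.int32)
assert g_idx.shape[0] == in_features  # one entry per input feature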
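Finally, a sketch of why the `gptq` branch asserts that every prefix carries the same `g_idx` before fusing: after `qweight` is concatenated along the output dimension, a single shared `g_idx` must describe the input grouping of all fused columns, so any mismatch has to fail loudly rather than silently misquantize. The `groupsize` and feature count below are assumptions:

import torch

# Two projections can only share one fused g_idx if their per-row group
# assignments are identical; this mirrors the assert_close in the diff.
groupsize, in_features = 128, 4096
g_q = (torch.arange(in_features) // groupsize).to(torch.int32)
g_k = (torch.arange(in_features) // groupsize).to(torch.int32)

torch.testing.assert_close(g_k, g_q)  # exact match required for ints
g_idx = g_q  # safe to share once equality is established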