mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 03:14:53 +00:00
add multi-weight for GPTQ weight loader
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent
ce8978f9ea
commit
475f6e21bc
@ -276,6 +276,63 @@ class GPTQWeightsLoader(WeightsLoader):
|
||||
use_exllama=use_exllama,
|
||||
)
|
||||
|
||||
def get_multi_weights(self, weights: Weights, prefixes: List[str], dim: int):
    """Load GPTQ/AWQ-quantized weights for several prefixes and fuse them.

    The packed tensors (``qweight``, ``qzeros``, ``scales``) of each prefix
    are concatenated column-wise into a single fused weight, which is how
    multi-head projections (e.g. q/k/v) are merged before the quantized
    matmul kernels run.

    Args:
        weights: Weight store used to fetch the per-prefix tensors.
        prefixes: Tensor-name prefixes to load and concatenate.
        dim: Concatenation dimension requested by the caller.
            NOTE(review): for the quantized path the packed tensors are
            always concatenated along ``dim=1`` regardless of this value —
            confirm callers only ever request column-wise merges.

    Returns:
        A :class:`GPTQWeight` holding the fused quantized tensors, or a
        plain tensor from :class:`DefaultWeightsLoader` when the layer is
        excluded from quantization.

    Raises:
        RuntimeError: if the checkpoint does not contain the packed
            quantized tensors (i.e. the model is not actually quantized).
    """
    # Layers listed in `modules_to_not_convert` are stored unquantized;
    # delegate to the default loader for those.
    if self.is_layer_skipped_quantization(prefixes[0], self.modules_to_not_convert):
        return DefaultWeightsLoader.get_multi_weights(weights, prefixes, dim)

    try:
        qweight = torch.cat(
            [weights.get_tensor(f"{p}.qweight") for p in prefixes], dim=1
        )
    except RuntimeError as e:
        # Chain the original error so the missing-tensor cause stays visible.
        raise RuntimeError(
            f"Cannot load `{self.quantize}` weight, make sure the model is already quantized"
        ) from e

    scales = torch.cat([weights.get_tensor(f"{p}.scales") for p in prefixes], dim=1)

    # Populate self.bits / self.groupsize / self.desc_act / self.quant_method
    # from the checkpoint before they are read below.
    self._get_gptq_params(weights)

    qzeros = torch.cat([weights.get_tensor(f"{p}.qzeros") for p in prefixes], dim=1)

    # Exllama kernels only support 4-bit GPTQ without activation reordering.
    use_exllama = self.bits == 4 and self.quantize == "gptq" and not self.desc_act

    if self.quantize == "gptq" and self.quant_method == "gptq":
        # All fused projections must share the same group index; verify
        # instead of silently picking one.
        g_idx_list = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
        for other in g_idx_list[1:]:
            torch.testing.assert_close(other, g_idx_list[0])
        g_idx = g_idx_list[0]
    elif self.quantize == "gptq" and self.quant_method == "awq":
        log_once(
            logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
        )
        from text_generation_server.layers.awq.conversion_utils import (
            fast_awq_to_gptq,
        )

        qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
        if use_exllama:
            # Exllama kernels synthesize a trivial g_idx themselves.
            g_idx = None
        else:
            # AWQ has no activation reordering: build the identity group
            # index (one entry per unpacked input row; each int32 column
            # packs 32 // bits values).
            g_idx = (
                torch.arange(
                    qweight.shape[0] * (32 // self.bits),
                    device=qweight.device,
                )
            ).to(dtype=torch.int32)
    else:
        g_idx = None

    return GPTQWeight(
        qweight=qweight,
        qzeros=qzeros,
        scales=scales,
        g_idx=g_idx,
        bits=self.bits,
        groupsize=self.groupsize,
        use_awq_kernel=self.quantize == "awq",
        use_exllama=use_exllama,
    )
|
||||
def get_weights_row(self, weights: Weights, prefix: str):
|
||||
self._get_gptq_params(weights)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user