adapt awq weights to exllama/gptq kernels
This commit is contained in:
parent 8665ab07ac
commit fb59c56215
98  server/text_generation_server/utils/pack_utils.py  (new file)
@@ -0,0 +1,98 @@
import torch

from typing import List


AWQ_PACK_ORDER = [0, 2, 4, 6, 1, 3, 5, 7]
REVERSE_AWQ_PACK_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]


def pack(imatrix: torch.Tensor, direction: str = "column"):
    """
    Packs a 4-bit integer matrix into a packed 32-bit integer matrix.
    Args:
        imatrix (torch.Tensor): matrix of integers
        direction (str): direction of packing, either "column" or "row"
    Returns:
        qmatrix (torch.Tensor): packed matrix of integers
    """
    imatrix = imatrix.to(torch.int8)
    imatrix = torch.bitwise_and(imatrix, 0x0F)  # keep only the low 4 bits to guard against overflow

    shifts = torch.arange(0, 32, 4, device=imatrix.device)

    if direction == "column":
        imatrix = imatrix.view(-1, imatrix.shape[1] // (32 // 4), (32 // 4))
        qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, None, :]).sum(dim=-1)

    elif direction == "row":
        imatrix = imatrix.view(imatrix.shape[0] // (32 // 4), (32 // 4), -1)
        qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, :, None]).sum(dim=1)

    qmatrix = qmatrix.to(torch.int32)

    return qmatrix


def unpack(qmatrix: torch.Tensor, direction: str = "column"):
    """
    Unpacks a 32-bit packed integer matrix into a 4-bit integer matrix.
    Args:
        qmatrix (torch.Tensor): matrix of packed integers
        direction (str): direction of unpacking, either "column" or "row"
    Returns:
        imatrix (torch.Tensor): matrix of integers
    """
    shifts = torch.arange(0, 32, 4, device=qmatrix.device)

    if direction == "column":
        imatrix = torch.bitwise_right_shift(
            qmatrix[:, :, None], shifts[None, None, :]
        ).view(qmatrix.shape[0], -1)

    elif direction == "row":
        imatrix = torch.bitwise_right_shift(
            qmatrix[:, None, :], shifts[None, :, None]
        ).view(-1, qmatrix.shape[-1])

    imatrix = imatrix.to(torch.int8) & 0x0F  # keep only the low 4 bits to guard against overflow

    return imatrix


def apply_order(
    imatrix: torch.Tensor,
    direction: str = "column",
    order: List[int] = AWQ_PACK_ORDER,
):
    """
    Applies the order to a 4-bit integer matrix.
    Args:
        imatrix (torch.Tensor): matrix of integers
        direction (str): direction of applying order, either "column" or "row"
        order (List[int]): order to apply, default is AWQ_PACK_ORDER
    Returns:
        imatrix (torch.Tensor): matrix of integers
    """
    if direction == "column":
        imatrix = imatrix.view(-1, (32 // 4))[:, order].view(imatrix.shape)
    elif direction == "row":
        imatrix = imatrix.view((32 // 4), -1)[order, :].view(imatrix.shape)

    return imatrix


def fast_awq_to_gptq(qweight, qzeros):
    # awq uses column packing for both weights and zeros
    izeros = unpack(qzeros, direction="column")
    iweights = unpack(qweight, direction="column")

    # Reverse the AWQ pack order of the iweights and izeros tensors
    izeros = apply_order(izeros, direction="column", order=REVERSE_AWQ_PACK_ORDER)
    iweights = apply_order(iweights, direction="column", order=REVERSE_AWQ_PACK_ORDER)
    # Subtract 1 from the izeros tensor (gptq adds 1 to the zeros)
    izeros = izeros - 1
    # exllama uses row packing for weights and column packing for zeros
    qzeros = pack(izeros, direction="column")
    qweight = pack(iweights, direction="row")

    return qweight, qzeros
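A minimal usage sketch for the helpers above (illustrative only, not part of the commit; the tensor shapes and random test values are assumptions):

    import torch
    from text_generation_server.utils.pack_utils import pack, unpack, fast_awq_to_gptq

    # Eight 4-bit values (stored one per int8) fit into a single int32.
    imatrix = torch.randint(0, 16, (128, 64), dtype=torch.int8)

    # Column packing compresses the last dimension by 8: (128, 64) -> (128, 8).
    qmatrix = pack(imatrix, direction="column")
    assert qmatrix.dtype == torch.int32 and qmatrix.shape == (128, 8)

    # Unpacking in the same direction recovers the original nibbles.
    assert torch.equal(unpack(qmatrix, direction="column"), imatrix)

    # AWQ checkpoints store column-packed qweight/qzeros; the exllama/GPTQ kernels
    # expect row-packed qweight and column-packed qzeros, with the pack order
    # reversed and the zero-points shifted by 1, which fast_awq_to_gptq performs.
    qweight_awq = pack(torch.randint(0, 16, (128, 64), dtype=torch.int8), direction="column")
    qzeros_awq = pack(torch.randint(0, 16, (4, 64), dtype=torch.int8), direction="column")
    qweight_gptq, qzeros_gptq = fast_awq_to_gptq(qweight_awq, qzeros_awq)
    assert qweight_gptq.shape == (16, 64) and qzeros_gptq.shape == (4, 8)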
server/text_generation_server/utils/weights.py

@@ -7,6 +7,7 @@ from loguru import logger
 from huggingface_hub import hf_hub_download
 import json
 from text_generation_server.utils.log import log_once
+from text_generation_server.utils.pack_utils import fast_awq_to_gptq


 class Weights:
@@ -46,7 +47,6 @@ class Weights:
         return self._handles[filename]

-
     def get_filename(self, tensor_name: str) -> (str, str):
         names = [tensor_name]
         if self.prefix is not None:
             prefixed = f"{self.prefix}.{tensor_name}"
@@ -157,12 +157,20 @@ class Weights:
             qzeros = self._get_qweight(f"{prefix}.qzeros")
             scales = self._get_qweight(f"{prefix}.scales")
             scales = scales.to(dtype=self.dtype)
-            if quantize == "gptq":
+            if quantize == "gptq" and self.quant_method == "gptq":
                 g_idx = self.get_tensor(f"{prefix}.g_idx")
             else:
                 g_idx = None

-            bits, groupsize, _ = self._get_gptq_params()
+            bits, groupsize, _, _ = self._get_gptq_params()
+
+            if quantize == "gptq" and self.quant_method == "awq":
+                log_once(
+                    logger.info,
+                    "Converting AWQ weights to Exllama/GPTQ format to be used with Exllama/GPTQ kernels",
+                )
+                qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
+
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
         else:
             slice_ = self._get_slice(f"{prefix}.weight")
@@ -204,7 +212,7 @@ class Weights:
                 [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
             )

-            if quantize == "gptq":
+            if quantize == "gptq" and self.quant_method == "gptq":
                 w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
                 for w2 in w[1:]:
                     torch.testing.assert_close(w2, w[0])
@@ -212,12 +220,20 @@ class Weights:
             else:
                 g_idx = None

-            bits, groupsize, desc_act = self._get_gptq_params()
+            bits, groupsize, desc_act, quant_method = self._get_gptq_params()
             from text_generation_server.utils.layers import HAS_EXLLAMA

             use_exllama = (
                 bits == 4 and HAS_EXLLAMA and quantize == "gptq" and not desc_act
             )
+
+            if quantize == "gptq" and quant_method == "awq":
+                log_once(
+                    logger.info,
+                    "Converting AWQ weights to Exllama/GPTQ format to be used with Exllama/GPTQ kernels",
+                )
+                qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
+
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
         else:
             w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
@@ -243,7 +259,7 @@ class Weights:
     def get_multi_weights_row(self, prefix: str, quantize: str):
         if quantize == "gptq":
             use_exllama = True
-            bits, groupsize, desc_act = self._get_gptq_params()
+            bits, groupsize, desc_act, quant_method = self._get_gptq_params()

             if bits != 4:
                 use_exllama = False
@@ -252,8 +268,19 @@ class Weights:
                 log_once(logger.warning, "Disabling exllama because desc_act=True")
                 use_exllama = False

+            try:
+                qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
+            except RuntimeError:
+                raise RuntimeError(
+                    "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
+                )
+
+            if quant_method == "gptq":
+                g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
+            else:
+                g_idx = None
+
             if self.process_group.size() > 1:
-                g_idx = self.get_tensor(f"{prefix}.g_idx")
                 if g_idx is not None:
                     if (
                         not torch.equal(
@@ -269,13 +296,6 @@ class Weights:
                         # it would require to reorder input activations that are split unto several GPUs
                         use_exllama = False

-            try:
-                qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
-            except RuntimeError:
-                raise RuntimeError(
-                    "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
-                )
-
             from text_generation_server.utils.layers import HAS_EXLLAMA, CAN_EXLLAMA

             if use_exllama:
@@ -289,8 +309,6 @@ class Weights:
             else:
                 log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}")

-            g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
-
             if use_exllama and groupsize != -1:
                 qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
                 scales = self.get_sharded(f"{prefix}.scales", dim=0)
@@ -298,12 +316,19 @@ class Weights:
                 qzeros = self.get_tensor(f"{prefix}.qzeros")
                 scales = self.get_tensor(f"{prefix}.scales")

-            if use_exllama:
+            if use_exllama and g_idx is not None:
                 g_idx = g_idx - g_idx[0]

+            if quant_method == "awq":
+                log_once(
+                    logger.info,
+                    "Converting AWQ weights to Exllama/GPTQ format to be used with Exllama/GPTQ kernels",
+                )
+                qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
+
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
         elif quantize == "awq":
-            bits, groupsize, _ = self._get_gptq_params()
+            bits, groupsize, _, _ = self._get_gptq_params()

             try:
                 qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
@@ -331,11 +356,12 @@ class Weights:
             try:
                 bits = self.gptq_bits
                 groupsize = self.gptq_groupsize
+                quant_method = self.quant_method
                 desc_act = getattr(self, "gptq_desc_act", False)
             except Exception:
                 raise e

-        return bits, groupsize, desc_act
+        return bits, groupsize, desc_act, quant_method

     def _set_gptq_params(self, model_id, revision):
         filename = "config.json"
@@ -350,7 +376,8 @@ class Weights:
                 data = json.load(f)
             self.gptq_bits = data["quantization_config"]["bits"]
             self.gptq_groupsize = data["quantization_config"]["group_size"]
-            self.gptq_desc_act = data["quantization_config"]["desc_act"]
+            self.gptq_desc_act = data["quantization_config"].get("desc_act", False)
+            self.quant_method = data["quantization_config"]["quant_method"]
         except Exception:
             filename = "quantize_config.json"
             try:
@@ -364,7 +391,11 @@ class Weights:
                     data = json.load(f)
                 self.gptq_bits = data["bits"]
                 self.gptq_groupsize = data["group_size"]
-                self.gptq_desc_act = data["desc_act"]
+                self.gptq_desc_act = data.get("desc_act", False)
+                if "version" in data and data["version"] == "GEMM":
+                    self.quant_method = "awq"
+                else:
+                    self.quant_method = "gptq"
             except Exception:
                 filename = "quant_config.json"
                 try:
@@ -378,6 +409,10 @@ class Weights:
                         data = json.load(f)
                     self.gptq_bits = data["w_bit"]
                     self.gptq_groupsize = data["q_group_size"]
-                    self.gptq_desc_act = data["desc_act"]
+                    self.gptq_desc_act = data.get("desc_act", False)
+                    if "version" in data and data["version"] == "GEMM":
+                        self.quant_method = "awq"
+                    else:
+                        self.quant_method = "gptq"
                 except Exception:
                     pass
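For reference, a sketch of the quantization metadata each fallback above expects (field names are taken from the code; the concrete values and variable names are illustrative assumptions):

    # config.json (transformers-style): everything nested under "quantization_config".
    transformers_config = {"quantization_config": {"bits": 4, "group_size": 128, "desc_act": False, "quant_method": "awq"}}

    # quantize_config.json (GPTQ-style): flat keys; AWQ "GEMM" checkpoints are
    # recognized through the "version" field.
    quantize_config = {"bits": 4, "group_size": 128, "desc_act": False, "version": "GEMM"}

    # quant_config.json (AWQ-style): "w_bit" / "q_group_size" naming.
    quant_config = {"w_bit": 4, "q_group_size": 128, "version": "GEMM"}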