adapt awq weights to exllama/gptq kernels

IlyasMoutawwakil 2024-02-01 18:35:41 +00:00
parent 8665ab07ac
commit fb59c56215
2 changed files with 155 additions and 22 deletions

text_generation_server/utils/pack_utils.py (new file)

@@ -0,0 +1,98 @@
import torch
from typing import List

AWQ_PACK_ORDER = [0, 2, 4, 6, 1, 3, 5, 7]
REVERSE_AWQ_PACK_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]


def pack(imatrix: torch.Tensor, direction: str = "column"):
    """
    Packs a 4-bit integer matrix into a packed 32-bit integer matrix.
    Args:
        imatrix (torch.Tensor): matrix of integers
        direction (str): direction of packing, either "column" or "row"
    Returns:
        qmatrix (torch.Tensor): packed matrix of integers
    """
    imatrix = imatrix.to(torch.int8)
    imatrix = torch.bitwise_and(imatrix, 0x0F)  # mask to the low 4 bits to correct any overflow

    shifts = torch.arange(0, 32, 4, device=imatrix.device)

    if direction == "column":
        imatrix = imatrix.view(-1, imatrix.shape[1] // (32 // 4), (32 // 4))
        qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, None, :]).sum(dim=-1)
    elif direction == "row":
        imatrix = imatrix.view(imatrix.shape[0] // (32 // 4), (32 // 4), -1)
        qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, :, None]).sum(dim=1)

    qmatrix = qmatrix.to(torch.int32)

    return qmatrix


def unpack(qmatrix: torch.Tensor, direction: str = "column"):
    """
    Unpacks a 32-bit packed integer matrix into a 4-bit integer matrix.
    Args:
        qmatrix (torch.Tensor): matrix of packed integers
        direction (str): direction of unpacking, either "column" or "row"
    Returns:
        imatrix (torch.Tensor): matrix of integers
    """
    shifts = torch.arange(0, 32, 4, device=qmatrix.device)

    if direction == "column":
        imatrix = torch.bitwise_right_shift(
            qmatrix[:, :, None], shifts[None, None, :]
        ).view(qmatrix.shape[0], -1)
    elif direction == "row":
        imatrix = torch.bitwise_right_shift(
            qmatrix[:, None, :], shifts[None, :, None]
        ).view(-1, qmatrix.shape[-1])

    imatrix = imatrix.to(torch.int8) & 0x0F  # mask to the low 4 bits to correct any overflow

    return imatrix


def apply_order(
    imatrix: torch.Tensor,
    direction: str = "column",
    order: List[int] = AWQ_PACK_ORDER,
):
    """
    Applies the order to a 4-bit integer matrix.
    Args:
        imatrix (torch.Tensor): matrix of integers
        direction (str): direction of applying order, either "column" or "row"
        order (List[int]): order to apply, default is AWQ_PACK_ORDER
    Returns:
        imatrix (torch.Tensor): matrix of integers
    """
    if direction == "column":
        imatrix = imatrix.view(-1, (32 // 4))[:, order].view(imatrix.shape)
    elif direction == "row":
        imatrix = imatrix.view((32 // 4), -1)[order, :].view(imatrix.shape)

    return imatrix


def fast_awq_to_gptq(qweight, qzeros):
    # AWQ uses column packing for both weights and zeros
    izeros = unpack(qzeros, direction="column")
    iweights = unpack(qweight, direction="column")

    # Reverse the order of the iweight and izeros tensors
    izeros = apply_order(izeros, direction="column", order=REVERSE_AWQ_PACK_ORDER)
    iweights = apply_order(iweights, direction="column", order=REVERSE_AWQ_PACK_ORDER)
    # Subtract 1 from the izeros tensor (gptq adds 1 to the zeros)
    izeros = izeros - 1
    # exllama uses row packing for weights and column packing for zeros
    qzeros = pack(izeros, direction="column")
    qweight = pack(iweights, direction="row")

    return qweight, qzeros
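
A quick self-consistency check of the helpers above (a minimal sketch, not part of the commit; it assumes the new module is importable as text_generation_server.utils.pack_utils, matching the import added to weights.py below, and uses made-up layer dimensions):

import torch

from text_generation_server.utils.pack_utils import (
    AWQ_PACK_ORDER,
    REVERSE_AWQ_PACK_ORDER,
    apply_order,
    fast_awq_to_gptq,
    pack,
    unpack,
)

# Made-up dimensions; every packed axis must be divisible by 32 // 4 = 8.
in_features, out_features, group_size = 128, 64, 64

# Random 4-bit values standing in for unpacked integer weights and zero points.
iweight = torch.randint(0, 16, (in_features, out_features), dtype=torch.int32)
izeros = torch.randint(0, 16, (in_features // group_size, out_features), dtype=torch.int32)

# pack/unpack are exact inverses in both packing directions.
assert torch.equal(unpack(pack(iweight, "column"), "column"), iweight.to(torch.int8))
assert torch.equal(unpack(pack(iweight, "row"), "row"), iweight.to(torch.int8))

# REVERSE_AWQ_PACK_ORDER is the inverse permutation of AWQ_PACK_ORDER.
reordered = apply_order(iweight, "column", AWQ_PACK_ORDER)
assert torch.equal(apply_order(reordered, "column", REVERSE_AWQ_PACK_ORDER), iweight)

# Build column-packed, AWQ-ordered tensors the way fast_awq_to_gptq expects them,
# convert, and check the result against the GPTQ/exllama layout: row-packed
# weights and zero points shifted down by one (modulo 16).
qweight = pack(apply_order(iweight, "column", AWQ_PACK_ORDER), "column")
qzeros = pack(apply_order(izeros, "column", AWQ_PACK_ORDER), "column")
gptq_qweight, gptq_qzeros = fast_awq_to_gptq(qweight, qzeros)

assert gptq_qweight.shape == (in_features // 8, out_features)
assert torch.equal(unpack(gptq_qweight, "row"), iweight.to(torch.int8))
assert torch.equal(unpack(gptq_qzeros, "column"), ((izeros - 1) & 0x0F).to(torch.int8))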

text_generation_server/utils/weights.py

@@ -7,6 +7,7 @@ from loguru import logger
 from huggingface_hub import hf_hub_download
 import json
 from text_generation_server.utils.log import log_once
+from text_generation_server.utils.pack_utils import fast_awq_to_gptq
 
 
 class Weights:
@@ -46,7 +47,6 @@ class Weights:
         return self._handles[filename]
 
     def get_filename(self, tensor_name: str) -> (str, str):
         names = [tensor_name]
         if self.prefix is not None:
             prefixed = f"{self.prefix}.{tensor_name}"
@@ -157,12 +157,20 @@ class Weights:
             qzeros = self._get_qweight(f"{prefix}.qzeros")
             scales = self._get_qweight(f"{prefix}.scales")
             scales = scales.to(dtype=self.dtype)
-            if quantize == "gptq":
+            if quantize == "gptq" and self.quant_method == "gptq":
                 g_idx = self.get_tensor(f"{prefix}.g_idx")
             else:
                 g_idx = None
 
-            bits, groupsize, _ = self._get_gptq_params()
+            bits, groupsize, _, _ = self._get_gptq_params()
+
+            if quantize == "gptq" and self.quant_method == "awq":
+                log_once(
+                    logger.info,
+                    "Converting AWQ weights to Exllama/GPTQ format to be used with Exllama/GPTQ kernels",
+                )
+                qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
+
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
         else:
             slice_ = self._get_slice(f"{prefix}.weight")
@@ -204,7 +212,7 @@ class Weights:
                 [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
             )
 
-            if quantize == "gptq":
+            if quantize == "gptq" and self.quant_method == "gptq":
                 w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
                 for w2 in w[1:]:
                     torch.testing.assert_close(w2, w[0])
@@ -212,12 +220,20 @@ class Weights:
             else:
                 g_idx = None
 
-            bits, groupsize, desc_act = self._get_gptq_params()
+            bits, groupsize, desc_act, quant_method = self._get_gptq_params()
             from text_generation_server.utils.layers import HAS_EXLLAMA
 
             use_exllama = (
                 bits == 4 and HAS_EXLLAMA and quantize == "gptq" and not desc_act
             )
+
+            if quantize == "gptq" and quant_method == "awq":
+                log_once(
+                    logger.info,
+                    "Converting AWQ weights to Exllama/GPTQ format to be used with Exllama/GPTQ kernels",
+                )
+                qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
+
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
         else:
             w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
@@ -243,7 +259,7 @@ class Weights:
     def get_multi_weights_row(self, prefix: str, quantize: str):
         if quantize == "gptq":
             use_exllama = True
-            bits, groupsize, desc_act = self._get_gptq_params()
+            bits, groupsize, desc_act, quant_method = self._get_gptq_params()
 
             if bits != 4:
                 use_exllama = False
@@ -252,8 +268,19 @@ class Weights:
                 log_once(logger.warning, "Disabling exllama because desc_act=True")
                 use_exllama = False
 
+            try:
+                qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
+            except RuntimeError:
+                raise RuntimeError(
+                    "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
+                )
+
+            if quant_method == "gptq":
+                g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
+            else:
+                g_idx = None
+
             if self.process_group.size() > 1:
-                g_idx = self.get_tensor(f"{prefix}.g_idx")
                 if g_idx is not None:
                     if (
                         not torch.equal(
@@ -269,13 +296,6 @@ class Weights:
                         # it would require to reorder input activations that are split unto several GPUs
                        use_exllama = False
 
-            try:
-                qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
-            except RuntimeError:
-                raise RuntimeError(
-                    "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
-                )
-
             from text_generation_server.utils.layers import HAS_EXLLAMA, CAN_EXLLAMA
 
             if use_exllama:
@@ -289,8 +309,6 @@ class Weights:
                 else:
                     log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}")
 
-            g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
-
             if use_exllama and groupsize != -1:
                 qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
                 scales = self.get_sharded(f"{prefix}.scales", dim=0)
@@ -298,12 +316,19 @@ class Weights:
                 qzeros = self.get_tensor(f"{prefix}.qzeros")
                 scales = self.get_tensor(f"{prefix}.scales")
 
-            if use_exllama:
+            if use_exllama and g_idx is not None:
                 g_idx = g_idx - g_idx[0]
 
+            if quant_method == "awq":
+                log_once(
+                    logger.info,
+                    "Converting AWQ weights to Exllama/GPTQ format to be used with Exllama/GPTQ kernels",
+                )
+                qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
+
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
         elif quantize == "awq":
-            bits, groupsize, _ = self._get_gptq_params()
+            bits, groupsize, _, _ = self._get_gptq_params()
 
             try:
                 qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
@@ -331,11 +356,12 @@ class Weights:
             try:
                 bits = self.gptq_bits
                 groupsize = self.gptq_groupsize
+                quant_method = self.quant_method
                 desc_act = getattr(self, "gptq_desc_act", False)
             except Exception:
                 raise e
 
-        return bits, groupsize, desc_act
+        return bits, groupsize, desc_act, quant_method
 
     def _set_gptq_params(self, model_id, revision):
         filename = "config.json"
@@ -350,7 +376,8 @@ class Weights:
                 data = json.load(f)
             self.gptq_bits = data["quantization_config"]["bits"]
             self.gptq_groupsize = data["quantization_config"]["group_size"]
-            self.gptq_desc_act = data["quantization_config"]["desc_act"]
+            self.gptq_desc_act = data["quantization_config"].get("desc_act", False)
+            self.quant_method = data["quantization_config"]["quant_method"]
         except Exception:
             filename = "quantize_config.json"
             try:
@@ -364,7 +391,11 @@ class Weights:
                     data = json.load(f)
                 self.gptq_bits = data["bits"]
                 self.gptq_groupsize = data["group_size"]
-                self.gptq_desc_act = data["desc_act"]
+                self.gptq_desc_act = data.get("desc_act", False)
+                if "version" in data and data["version"] == "GEMM":
+                    self.quant_method = "awq"
+                else:
+                    self.quant_method = "gptq"
             except Exception:
                 filename = "quant_config.json"
                 try:
@@ -378,6 +409,10 @@ class Weights:
                         data = json.load(f)
                     self.gptq_bits = data["w_bit"]
                     self.gptq_groupsize = data["q_group_size"]
-                    self.gptq_desc_act = data["desc_act"]
+                    self.gptq_desc_act = data.get("desc_act", False)
+                    if "version" in data and data["version"] == "GEMM":
+                        self.quant_method = "awq"
+                    else:
+                        self.quant_method = "gptq"
                 except Exception:
                     pass
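
For reference, the fallback chain that _set_gptq_params walks above (config.json with a quantization_config section, then quantize_config.json, then quant_config.json, treating an explicit quant_method or a GEMM "version" field as AWQ) can be condensed into a standalone sketch. detect_quant_params is a hypothetical helper that reads local files only, not the server's hf_hub_download-based implementation:

import json
from pathlib import Path


def detect_quant_params(model_dir: str):
    # Same files, same order as _set_gptq_params above.
    candidates = [
        ("config.json", lambda d: d["quantization_config"], "bits", "group_size"),
        ("quantize_config.json", lambda d: d, "bits", "group_size"),
        ("quant_config.json", lambda d: d, "w_bit", "q_group_size"),
    ]
    for filename, section, bits_key, group_key in candidates:
        path = Path(model_dir) / filename
        if not path.exists():
            continue
        try:
            data = section(json.loads(path.read_text()))
            bits = data[bits_key]
            groupsize = data[group_key]
        except (KeyError, json.JSONDecodeError):
            continue
        desc_act = data.get("desc_act", False)
        # AWQ checkpoints either declare quant_method directly (config.json)
        # or expose a GEMM "version" field in their own quantization config.
        if data.get("quant_method") == "awq" or data.get("version") == "GEMM":
            quant_method = "awq"
        else:
            quant_method = "gptq"
        return bits, groupsize, desc_act, quant_method
    raise FileNotFoundError(f"no quantization config found in {model_dir}")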