pass g_idx instead of changing triton kernel

IlyasMoutawwakil 2024-02-02 14:34:15 +01:00 committed by Nicolas Patry
parent 646ab28285
commit bbe5bedea5
3 changed files with 50 additions and 37 deletions
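As the title says, the fix is to always hand the GPTQ kernels a g_idx instead of special-casing AWQ inside the Triton kernel: AWQ checkpoints are repacked to GPTQ layout and paired with an all-zeros g_idx. A minimal sketch of that flow, assuming the tensors are already loaded (load_awq_as_gptq is a hypothetical helper; fast_awq_to_gptq and the returned tuple layout are taken from the diff below):

import torch

def load_awq_as_gptq(qweight, qzeros, scales, bits, groupsize):
    # Hypothetical helper sketching this commit's approach: repack AWQ
    # tensors into GPTQ layout, then synthesize a trivial g_idx so the
    # unchanged GPTQ/Exllama kernels can consume the weights.
    from text_generation_server.utils.awq.conversion_utils import fast_awq_to_gptq

    qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
    # Each packed int32 row holds 32 // bits quantized rows, so this is
    # exactly one group index per input feature.
    g_idx = torch.zeros(
        qweight.shape[0] * 32 // bits, dtype=torch.int32, device=qweight.device
    )
    return qweight, qzeros, scales, g_idx, bits, groupsize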

View File

@@ -15,10 +15,9 @@ def pack(imatrix: torch.Tensor, direction: str = "column"):
     Returns:
         qmatrix (torch.Tensor): packed matrix of integers
     """
-    imatrix = imatrix.to(torch.int8)
-    imatrix = torch.bitwise_and(imatrix, 0x0F)  # eventually correct overflow
+    shifts = torch.arange(0, 32, 4, dtype=torch.int32, device=imatrix.device)

-    shifts = torch.arange(0, 32, 4, device=imatrix.device)
+    imatrix = imatrix.to(torch.int8) & 0x0F  # eventually correct overflow

     if direction == "column":
         imatrix = imatrix.view(-1, imatrix.shape[1] // (32 // 4), (32 // 4))
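For context on the hunk above: pack() packs eight 4-bit values into each int32 using the 0, 4, ..., 28 bit shifts, and masking to 4 bits before shifting keeps sign bits from negative int8 inputs out of neighbouring nibbles. A self-contained illustrative sketch of the column-direction case (pack_columns_4bit is a stand-in, not the repo's pack()):

import torch

def pack_columns_4bit(imatrix: torch.Tensor) -> torch.Tensor:
    # Illustrative sketch: pack eight 4-bit values per int32 along the
    # column dimension with the same shift pattern as the hunk above.
    shifts = torch.arange(0, 32, 4, dtype=torch.int32, device=imatrix.device)
    # Mask to the low 4 bits first ("eventually correct overflow").
    masked = (imatrix.to(torch.int8) & 0x0F).to(torch.int32)
    # Group eight consecutive columns and OR each nibble into place.
    grouped = masked.view(-1, imatrix.shape[1] // 8, 8)
    packed = torch.zeros(grouped.shape[:-1], dtype=torch.int32, device=imatrix.device)
    for i in range(8):
        packed |= grouped[..., i] << shifts[i]
    return packed

For example, a (2, 8) matrix of values 0..15 packs into a (2, 1) int32 tensor.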

View File

@@ -182,7 +182,7 @@ try:
             )  # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
             zeros = (zeros >> zeros_shifter[None, :]) & maxq
-            zeros = (zeros + 1) & maxq  # add 1 and avoid overflow
+            zeros = (zeros + 1) & maxq  # eventually avoid overflow

             a = tl.load(a_ptrs, mask=a_mask, other=0.0)  # (BLOCK_SIZE_M, BLOCK_SIZE_K)
             b = tl.load(b_ptrs)  # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
@@ -251,17 +251,7 @@ class QuantLinear(nn.Module):
         self.register_buffer("qweight", qweight)
         self.register_buffer("qzeros", qzeros)
         self.register_buffer("scales", scales)
-        if g_idx is not None:
-            self.register_buffer("g_idx", g_idx)
-        else:
-            self.register_buffer(
-                "g_idx",
-                torch.tensor(
-                    [i // groupsize for i in range(qweight.shape[0] * 32 // bits)],
-                    device=qweight.device,
-                    dtype=torch.int32,
-                ),
-            )
+        self.register_buffer("g_idx", g_idx)
         if bias is not None:
             self.register_buffer("bias", bias)
         else:
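QuantLinear now requires the caller to supply g_idx instead of building a default. For reference, the removed fallback mapped input row i to group i // groupsize; a hedged sketch of that default (default_g_idx is a hypothetical name, the expression comes from the deleted branch):

import torch

def default_g_idx(qweight: torch.Tensor, bits: int, groupsize: int) -> torch.Tensor:
    # Row i of the unpacked weight belongs to quantization group i // groupsize.
    in_features = qweight.shape[0] * 32 // bits
    return torch.tensor(
        [i // groupsize for i in range(in_features)],
        dtype=torch.int32,
        device=qweight.device,
    )

With bits=4, groupsize=128 and a qweight of 512 packed rows this gives 512 * 32 // 4 = 4096 entries, i.e. 32 groups of 128.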

View File

@@ -7,7 +7,6 @@ from loguru import logger
 from huggingface_hub import hf_hub_download
 import json
 from text_generation_server.utils.log import log_once
-from text_generation_server.utils.pack_utils import fast_awq_to_gptq


 class Weights:
@@ -162,15 +161,22 @@ class Weights:
             if quantize == "gptq" and quant_method == "gptq":
                 g_idx = self.get_tensor(f"{prefix}.g_idx")
-            else:
-                g_idx = None
-
-            if quantize == "gptq" and quant_method == "awq":
+            elif quantize == "gptq" and quant_method == "awq":
                 log_once(
                     logger.info,
-                    "Converting AWQ weights to Exllama/GPTQ format to be used with Exllama/GPTQ kernels",
+                    "Converting AWQ weights to Exllama/GPTQ packing format, "
+                    "in order to be used with Exllama/GPTQ kernels.",
                 )
+                from text_generation_server.utils.awq.conversion_utils import (
+                    fast_awq_to_gptq,
+                )
+
                 qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
+                g_idx = torch.zeros(
+                    (qweight.shape[0] * 32 // bits),
+                    dtype=torch.int32,
+                    device=qweight.device,
+                )

             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
         else:
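The synthesized g_idx has to contain one entry per input feature, and qweight.shape[0] * 32 // bits recovers that count because each packed int32 row stores 32 // bits quantized rows. A small sanity check with assumed example shapes:

import torch

# Assumed example: a 4-bit layer with 4096 input features packs
# 32 // 4 = 8 rows per int32, so qweight has 4096 // 8 = 512 rows.
bits, in_features = 4, 4096
qweight = torch.zeros((in_features * bits // 32, 256), dtype=torch.int32)

g_idx = torch.zeros((qweight.shape[0] * 32 // bits), dtype=torch.int32)
assert g_idx.shape[0] == in_features  # 512 * 32 // 4 == 4096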
@@ -220,8 +226,22 @@ class Weights:
                 for w2 in w[1:]:
                     torch.testing.assert_close(w2, w[0])
                 g_idx = w[0]
-            else:
-                g_idx = None
+            elif quantize == "gptq" and quant_method == "awq":
+                log_once(
+                    logger.info,
+                    "Converting AWQ weights to Exllama/GPTQ packing format, "
+                    "in order to be used with Exllama/GPTQ kernels.",
+                )
+                from text_generation_server.utils.awq.conversion_utils import (
+                    fast_awq_to_gptq,
+                )
+
+                qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
+                g_idx = torch.zeros(
+                    (qweight.shape[0] * 32 // bits),
+                    dtype=torch.int32,
+                    device=qweight.device,
+                )

             from text_generation_server.utils.layers import HAS_EXLLAMA
@@ -229,13 +249,6 @@ class Weights:
                 bits == 4 and HAS_EXLLAMA and quantize == "gptq" and not desc_act
             )

-            if quantize == "gptq" and quant_method == "awq":
-                log_once(
-                    logger.info,
-                    "Converting AWQ weights to Exllama/GPTQ format to be used with Exllama/GPTQ kernels",
-                )
-                qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
-
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
         else:
             w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
@@ -279,7 +292,7 @@ class Weights:
             if quant_method == "gptq":
                 g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
-            else:
+            elif quant_method == "awq":
                 g_idx = None

             if self.process_group.size() > 1:
@@ -324,9 +337,19 @@ class Weights:
             if quant_method == "awq":
                 log_once(
                     logger.info,
-                    "Converting AWQ weights to Exllama/GPTQ format to be used with Exllama/GPTQ kernels",
+                    "Converting AWQ weights to Exllama/GPTQ packing format, "
+                    "in order to be used with Exllama/GPTQ kernels.",
                 )
+                from text_generation_server.utils.awq.conversion_utils import (
+                    fast_awq_to_gptq,
+                )
+
                 qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
+                g_idx = torch.zeros(
+                    (qweight.shape[0] * 32 // bits),
+                    dtype=torch.int32,
+                    device=qweight.device,
+                )

             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
         elif quantize == "awq":
@@ -353,13 +376,14 @@ class Weights:
         try:
             bits = self.get_tensor("gptq_bits").item()
             groupsize = self.get_tensor("gptq_groupsize").item()
+            quant_method = "gptq"
             desc_act = False
         except (SafetensorError, RuntimeError) as e:
             try:
                 bits = self.gptq_bits
                 groupsize = self.gptq_groupsize
-                quant_method = self.quant_method
                 desc_act = getattr(self, "gptq_desc_act", False)
+                quant_method = getattr(self, "quant_method", "gptq")
             except Exception:
                 raise e
@@ -378,8 +402,8 @@ class Weights:
                 data = json.load(f)
             self.gptq_bits = data["quantization_config"]["bits"]
             self.gptq_groupsize = data["quantization_config"]["group_size"]
+            self.gptq_desc_act = data["quantization_config"].get("desc_act", False)
             self.quant_method = data["quantization_config"]["quant_method"]
-            self.gptq_desc_act = data["quantization_config"]["desc_act"]
         except Exception:
             filename = "quantize_config.json"
             try:
@@ -393,11 +417,11 @@ class Weights:
                 data = json.load(f)
             self.gptq_bits = data["bits"]
             self.gptq_groupsize = data["group_size"]
+            self.gptq_desc_act = data.get("desc_act", False)
             if "version" in data and data["version"] == "GEMM":
                 self.quant_method = "awq"
             else:
                 self.quant_method = "gptq"
-            self.gptq_desc_act = data["desc_act"]
         except Exception:
             filename = "quant_config.json"
             try:
@@ -411,10 +435,10 @@ class Weights:
                 data = json.load(f)
             self.gptq_bits = data["w_bit"]
             self.gptq_groupsize = data["q_group_size"]
+            self.gptq_desc_act = data.get("desc_act", False)
             if "version" in data and data["version"] == "GEMM":
                 self.quant_method = "awq"
             else:
                 self.quant_method = "gptq"
-            self.gptq_desc_act = data["desc_act"]
         except Exception:
             pass
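The last three hunks also make desc_act optional when reading quantization configs and keep deriving quant_method from the GEMM version tag. A hedged sketch of the same fallback rule applied to a plain dict (parse_quant_config is hypothetical; the field names are the ones read above):

def parse_quant_config(data: dict):
    # Sketch of the fallback rule for quantize_config.json / quant_config.json
    # style payloads; not the repo's code.
    bits = data.get("bits", data.get("w_bit"))
    groupsize = data.get("group_size", data.get("q_group_size"))
    desc_act = data.get("desc_act", False)
    quant_method = "awq" if data.get("version") == "GEMM" else "gptq"
    return bits, groupsize, desc_act, quant_method

# Example: a typical AWQ quant_config.json payload (illustrative values).
print(parse_quant_config({"w_bit": 4, "q_group_size": 128, "version": "GEMM"}))
# -> (4, 128, False, 'awq')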