Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 12:24:53 +00:00)

Commit 994ed8e10d: pass g_idx instead of changing triton kernel
Parent: bcdb02e41a
@@ -15,10 +15,9 @@ def pack(imatrix: torch.Tensor, direction: str = "column"):
     Returns:
         qmatrix (torch.Tensor): packed matrix of integers
     """
-    imatrix = imatrix.to(torch.int8)
-    imatrix = torch.bitwise_and(imatrix, 0x0F)  # eventually correct overflow
-
-    shifts = torch.arange(0, 32, 4, device=imatrix.device)
+    shifts = torch.arange(0, 32, 4, dtype=torch.int32, device=imatrix.device)
+
+    imatrix = imatrix.to(torch.int8) & 0x0F  # eventually correct overflow
 
     if direction == "column":
         imatrix = imatrix.view(-1, imatrix.shape[1] // (32 // 4), (32 // 4))
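For orientation, here is a minimal, self-contained sketch of the packing idea these shifts implement: eight 4-bit values share one int32, lowest nibble first. The tensors and values are illustrative, not the repository's pack() helper.

import torch

# Eight small values (all < 16), packed into a single int32 with 4-bit strides.
values = torch.tensor([1, 2, 3, 4, 5, 6, 7, 0], dtype=torch.int32)
shifts = torch.arange(0, 32, 4, dtype=torch.int32)   # 0, 4, 8, ..., 28

packed = (values << shifts).sum(dim=-1)    # one int32 holding all eight nibbles
unpacked = (packed >> shifts) & 0x0F       # recover the original values

print(hex(packed.item()))    # 0x7654321 (nibble i sits at bits 4*i .. 4*i+3)
print(unpacked.tolist())     # [1, 2, 3, 4, 5, 6, 7, 0]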
@@ -182,7 +182,7 @@ try:
             )  # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
 
             zeros = (zeros >> zeros_shifter[None, :]) & maxq
-            zeros = (zeros + 1) & maxq  # add 1 and avoid overflow
+            zeros = (zeros + 1) & maxq  # eventually avoid overflow
 
             a = tl.load(a_ptrs, mask=a_mask, other=0.0)  # (BLOCK_SIZE_M, BLOCK_SIZE_K)
             b = tl.load(b_ptrs)  # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
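The + 1 compensates for the common GPTQ storage convention of packing zero points offset by one. A rough, non-Triton sketch of the same unpack-and-dequantize step, with illustrative names and a hypothetical packed word:

import torch

bits = 4
maxq = (1 << bits) - 1                              # 0b1111

packed_zeros = torch.tensor([0x76543210], dtype=torch.int32)
shifter = torch.arange(0, 32, bits, dtype=torch.int32)

zeros = (packed_zeros >> shifter) & maxq            # unpack eight 4-bit zero points
zeros = (zeros + 1) & maxq                          # undo the "minus one" convention

q = torch.full((8,), 5, dtype=torch.int32)          # already-unpacked 4-bit weights
scales = torch.full((8,), 0.1)
print(((q - zeros) * scales).tolist())              # w = (q - z) * s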
@@ -251,17 +251,7 @@ class QuantLinear(nn.Module):
         self.register_buffer("qweight", qweight)
         self.register_buffer("qzeros", qzeros)
         self.register_buffer("scales", scales)
-        if g_idx is not None:
-            self.register_buffer("g_idx", g_idx)
-        else:
-            self.register_buffer(
-                "g_idx",
-                torch.tensor(
-                    [i // groupsize for i in range(qweight.shape[0] * 32 // bits)],
-                    device=qweight.device,
-                    dtype=torch.int32,
-                ),
-            )
+        self.register_buffer("g_idx", g_idx)
         if bias is not None:
             self.register_buffer("bias", bias)
         else:
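QuantLinear now always receives g_idx from the caller instead of rebuilding a default. For reference, g_idx[i] is the quantization-group index of unpacked input row i, and the removed fallback was simply the sequential mapping; a small sketch with hypothetical sizes:

import torch

bits, groupsize = 4, 16
packed_rows = 8                               # stands in for qweight.shape[0]
in_features = packed_rows * 32 // bits        # 8 int32 rows hold 64 4-bit rows

g_idx = torch.tensor(
    [i // groupsize for i in range(in_features)], dtype=torch.int32
)
print(g_idx[:20].tolist())   # sixteen 0s then four 1s: row i belongs to group i // groupsize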
@@ -7,7 +7,6 @@ from loguru import logger
 from huggingface_hub import hf_hub_download
 import json
 from text_generation_server.utils.log import log_once
-from text_generation_server.utils.pack_utils import fast_awq_to_gptq
 
 
 class Weights:
@@ -162,15 +161,22 @@ class Weights:
 
             if quantize == "gptq" and quant_method == "gptq":
                 g_idx = self.get_tensor(f"{prefix}.g_idx")
-            else:
-                g_idx = None
-
-            if quantize == "gptq" and quant_method == "awq":
+            elif quantize == "gptq" and quant_method == "awq":
                 log_once(
                     logger.info,
-                    "Converting AWQ weights to Exllama/GPTQ format to be used with Exllama/GPTQ kernels",
+                    "Converting AWQ weights to Exllama/GPTQ packing format, "
+                    "in order used with Exllama/GPTQ kernels.",
                 )
+                from text_generation_server.utils.awq.conversion_utils import (
+                    fast_awq_to_gptq,
+                )
+
                 qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
+                g_idx = torch.zeros(
+                    (qweight.shape[0] * 32 // bits),
+                    dtype=torch.int32,
+                    device=qweight.device,
+                )
 
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
         else:
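fast_awq_to_gptq repacks AWQ checkpoints so the existing GPTQ/Exllama kernels can consume them. A conceptual sketch of the core idea (nibble reordering inside each int32), assuming the usual AWQ interleaving order; the real conversion also has to handle the different packing axes and the zero-point offset, so this is not the repository's implementation:

import torch

AWQ_ORDER = [0, 2, 4, 6, 1, 3, 5, 7]              # assumed AWQ nibble interleaving
shifts = torch.arange(0, 32, 4, dtype=torch.int32)

def unpack_awq_word(word: torch.Tensor) -> torch.Tensor:
    nibbles = (word.unsqueeze(-1) >> shifts) & 0x0F
    logical = torch.empty_like(nibbles)
    logical[..., AWQ_ORDER] = nibbles              # undo the interleaved packing order
    return logical

def pack_gptq_word(logical: torch.Tensor) -> torch.Tensor:
    return (logical << shifts).sum(dim=-1)         # GPTQ packs nibbles sequentially

word = torch.tensor([0x76543210], dtype=torch.int32)
print(hex(pack_gptq_word(unpack_awq_word(word)).item()))   # 0x73625140 under these assumptions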
@@ -220,8 +226,22 @@ class Weights:
                 for w2 in w[1:]:
                     torch.testing.assert_close(w2, w[0])
                 g_idx = w[0]
-            else:
-                g_idx = None
+            elif quantize == "gptq" and quant_method == "awq":
+                log_once(
+                    logger.info,
+                    "Converting AWQ weights to Exllama/GPTQ packing format, "
+                    "in order used with Exllama/GPTQ kernels.",
+                )
+                from text_generation_server.utils.awq.conversion_utils import (
+                    fast_awq_to_gptq,
+                )
+
+                qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
+                g_idx = torch.zeros(
+                    (qweight.shape[0] * 32 // bits),
+                    dtype=torch.int32,
+                    device=qweight.device,
+                )
 
             from text_generation_server.utils.layers import HAS_EXLLAMA
 
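A note on the g_idx length used above: a GPTQ-packed qweight stores 32 // bits quantized rows per int32 row, so qweight.shape[0] * 32 // bits recovers the unpacked number of input features. With hypothetical sizes:

bits = 4
in_features = 4096                                 # hypothetical layer width
packed_rows = in_features * bits // 32             # rows of the packed int32 qweight
assert packed_rows == 512
assert packed_rows * 32 // bits == in_features     # qweight.shape[0] * 32 // bits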
@@ -229,13 +249,6 @@ class Weights:
                 bits == 4 and HAS_EXLLAMA and quantize == "gptq" and not desc_act
             )
 
-            if quantize == "gptq" and quant_method == "awq":
-                log_once(
-                    logger.info,
-                    "Converting AWQ weights to Exllama/GPTQ format to be used with Exllama/GPTQ kernels",
-                )
-                qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
-
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
         else:
             w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
@@ -279,7 +292,7 @@ class Weights:
 
             if quant_method == "gptq":
                 g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
-            else:
+            elif quant_method == "awq":
                 g_idx = None
 
             if self.process_group.size() > 1:
@@ -324,9 +337,19 @@ class Weights:
             if quant_method == "awq":
                 log_once(
                     logger.info,
-                    "Converting AWQ weights to Exllama/GPTQ format to be used with Exllama/GPTQ kernels",
+                    "Converting AWQ weights to Exllama/GPTQ packing format, "
+                    "in order used with Exllama/GPTQ kernels.",
                 )
+                from text_generation_server.utils.awq.conversion_utils import (
+                    fast_awq_to_gptq,
+                )
+
                 qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
+                g_idx = torch.zeros(
+                    (qweight.shape[0] * 32 // bits),
+                    dtype=torch.int32,
+                    device=qweight.device,
+                )
 
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
         elif quantize == "awq":
@@ -353,13 +376,14 @@ class Weights:
         try:
             bits = self.get_tensor("gptq_bits").item()
             groupsize = self.get_tensor("gptq_groupsize").item()
             quant_method = "gptq"
             desc_act = False
         except (SafetensorError, RuntimeError) as e:
             try:
                 bits = self.gptq_bits
                 groupsize = self.gptq_groupsize
-                quant_method = self.quant_method
+                desc_act = getattr(self, "gptq_desc_act", False)
+                quant_method = getattr(self, "quant_method", "gptq")
             except Exception:
                 raise e
 
@@ -378,8 +402,8 @@ class Weights:
             data = json.load(f)
             self.gptq_bits = data["quantization_config"]["bits"]
             self.gptq_groupsize = data["quantization_config"]["group_size"]
-            self.gptq_desc_act = data["quantization_config"].get("desc_act", False)
             self.quant_method = data["quantization_config"]["quant_method"]
+            self.gptq_desc_act = data["quantization_config"]["desc_act"]
         except Exception:
             filename = "quantize_config.json"
             try:
@@ -393,11 +417,11 @@ class Weights:
             data = json.load(f)
             self.gptq_bits = data["bits"]
             self.gptq_groupsize = data["group_size"]
-            self.gptq_desc_act = data.get("desc_act", False)
             if "version" in data and data["version"] == "GEMM":
                 self.quant_method = "awq"
             else:
                 self.quant_method = "gptq"
+            self.gptq_desc_act = data["desc_act"]
         except Exception:
             filename = "quant_config.json"
             try:
@@ -411,10 +435,10 @@ class Weights:
             data = json.load(f)
             self.gptq_bits = data["w_bit"]
             self.gptq_groupsize = data["q_group_size"]
-            self.gptq_desc_act = data.get("desc_act", False)
             if "version" in data and data["version"] == "GEMM":
                 self.quant_method = "awq"
             else:
                 self.quant_method = "gptq"
+            self.gptq_desc_act = data["desc_act"]
         except Exception:
             pass
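Taken together, the surrounding _set_gptq_params tries config.json, then quantize_config.json, then quant_config.json, and treats a "version": "GEMM" entry as the marker of an AWQ checkpoint. A simplified sketch of that fallback (field names taken from the diff, file handling and error paths reduced; read_quant_params is a hypothetical helper, not the actual method):

import json

def read_quant_params(model_dir: str) -> dict:
    # 1) config.json with a quantization_config block
    try:
        with open(f"{model_dir}/config.json") as f:
            cfg = json.load(f)["quantization_config"]
        return {
            "bits": cfg["bits"],
            "groupsize": cfg["group_size"],
            "desc_act": cfg["desc_act"],
            "quant_method": cfg["quant_method"],
        }
    except Exception:
        pass
    # 2) quantize_config.json (GPTQ layout), then 3) quant_config.json (AWQ layout)
    for filename, bits_key, group_key in [
        ("quantize_config.json", "bits", "group_size"),
        ("quant_config.json", "w_bit", "q_group_size"),
    ]:
        try:
            with open(f"{model_dir}/{filename}") as f:
                data = json.load(f)
            return {
                "bits": data[bits_key],
                "groupsize": data[group_key],
                "desc_act": data.get("desc_act", False),
                "quant_method": "awq" if data.get("version") == "GEMM" else "gptq",
            }
        except Exception:
            continue
    raise RuntimeError("no quantization config found")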