Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-11 04:14:52 +00:00
pass g_idx instead of changing triton kernel
parent 646ab28285
commit bbe5bedea5
@@ -15,10 +15,9 @@ def pack(imatrix: torch.Tensor, direction: str = "column"):
     Returns:
         qmatrix (torch.Tensor): packed matrix of integers
     """
-    imatrix = imatrix.to(torch.int8)
-    imatrix = torch.bitwise_and(imatrix, 0x0F)  # eventually correct overflow
-
-    shifts = torch.arange(0, 32, 4, device=imatrix.device)
+    shifts = torch.arange(0, 32, 4, dtype=torch.int32, device=imatrix.device)
+
+    imatrix = imatrix.to(torch.int8) & 0x0F  # eventually correct overflow

     if direction == "column":
         imatrix = imatrix.view(-1, imatrix.shape[1] // (32 // 4), (32 // 4))
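Note (not part of the commit): a minimal standalone sketch of the 4-bit packing scheme the hunk above relies on, where eight nibbles share one 32-bit word at shift offsets 0, 4, ..., 28. The values are made up, and the demo sums in int64 only to sidestep int32 wraparound for the top nibble.

import torch

# Eight hypothetical 4-bit values to pack into a single word.
nibbles = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=torch.int64)
shifts = torch.arange(0, 32, 4, dtype=torch.int64)  # 0, 4, ..., 28

# Mask to the low nibble, shift each value into its slot, and sum.
packed = torch.bitwise_left_shift(nibbles & 0x0F, shifts).sum()

# Unpacking reverses the shifts and masks the low nibble again.
unpacked = torch.bitwise_right_shift(packed, shifts) & 0x0F
assert torch.equal(unpacked, nibbles)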
@@ -182,7 +182,7 @@ try:
            ) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)

            zeros = (zeros >> zeros_shifter[None, :]) & maxq
-            zeros = (zeros + 1) & maxq # add 1 and avoid overflow
+            zeros = (zeros + 1) & maxq # eventually avoid overflow

            a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
            b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
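Note (not part of the commit): the restored comment concerns the zero-point arithmetic only; a plain-torch sketch of the same two lines, assuming the common GPTQ packing convention that eight 4-bit zero-points share one int32 and are stored offset by one. The packed value below is made up for illustration.

import torch

bits = 4
maxq = (1 << bits) - 1  # 15 for 4-bit weights

# One packed int32 whose eight nibbles hold the stored zero-points 0..7.
packed_zeros = torch.tensor([0x76543210], dtype=torch.int32)
zeros_shifter = torch.arange(0, 32, bits, dtype=torch.int32)

zeros = (packed_zeros >> zeros_shifter) & maxq  # -> 0, 1, ..., 7
zeros = (zeros + 1) & maxq  # add the offset back, mask any wraparound
assert zeros.tolist() == [1, 2, 3, 4, 5, 6, 7, 8]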
@@ -251,17 +251,7 @@ class QuantLinear(nn.Module):
         self.register_buffer("qweight", qweight)
         self.register_buffer("qzeros", qzeros)
         self.register_buffer("scales", scales)
-        if g_idx is not None:
-            self.register_buffer("g_idx", g_idx)
-        else:
-            self.register_buffer(
-                "g_idx",
-                torch.tensor(
-                    [i // groupsize for i in range(qweight.shape[0] * 32 // bits)],
-                    device=qweight.device,
-                    dtype=torch.int32,
-                ),
-            )
+        self.register_buffer("g_idx", g_idx)
         if bias is not None:
             self.register_buffer("bias", bias)
         else:
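Note (not part of the commit): with the fallback above deleted, callers are expected to hand QuantLinear a concrete g_idx. A sketch of the default mapping the removed branch used to build, where input feature i simply belongs to group i // groupsize; default_g_idx is a hypothetical helper name, not a function from the repository.

import torch

def default_g_idx(in_features: int, groupsize: int, device: str = "cpu") -> torch.Tensor:
    # The same ramp the removed fallback constructed: feature i -> group i // groupsize.
    return torch.tensor(
        [i // groupsize for i in range(in_features)],
        dtype=torch.int32,
        device=device,
    )

# For a GPTQ-packed qweight, in_features is qweight.shape[0] * 32 // bits.
assert default_g_idx(8, 4).tolist() == [0, 0, 0, 0, 1, 1, 1, 1]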
@@ -7,7 +7,6 @@ from loguru import logger
 from huggingface_hub import hf_hub_download
 import json
 from text_generation_server.utils.log import log_once
-from text_generation_server.utils.pack_utils import fast_awq_to_gptq


 class Weights:
@@ -162,15 +161,22 @@ class Weights:

             if quantize == "gptq" and quant_method == "gptq":
                 g_idx = self.get_tensor(f"{prefix}.g_idx")
-            else:
-                g_idx = None
-
-            if quantize == "gptq" and quant_method == "awq":
+            elif quantize == "gptq" and quant_method == "awq":
                 log_once(
                     logger.info,
-                    "Converting AWQ weights to Exllama/GPTQ format to be used with Exllama/GPTQ kernels",
+                    "Converting AWQ weights to Exllama/GPTQ packing format, "
+                    "in order used with Exllama/GPTQ kernels.",
                 )
+                from text_generation_server.utils.awq.conversion_utils import (
+                    fast_awq_to_gptq,
+                )
+
                 qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
+                g_idx = torch.zeros(
+                    (qweight.shape[0] * 32 // bits),
+                    dtype=torch.int32,
+                    device=qweight.device,
+                )

             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
         else:
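Note (not part of the commit): the zero-filled g_idx created above is a placeholder sized to the unpacked input dimension, so the rest of the GPTQ path always receives a g_idx tensor (which is the point of this commit). A sketch of the shape reasoning, assuming qweight is GPTQ-packed along dim 0 so that each int32 row holds 32 // bits input features; the tensor sizes are hypothetical.

import torch

bits = 4
qweight = torch.zeros((512, 4096), dtype=torch.int32)  # hypothetical packed weight
in_features = qweight.shape[0] * 32 // bits  # 512 * 8 = 4096 input features
g_idx = torch.zeros(in_features, dtype=torch.int32)  # placeholder: every feature in group 0
assert g_idx.shape[0] == 4096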
@@ -220,8 +226,22 @@ class Weights:
                 for w2 in w[1:]:
                     torch.testing.assert_close(w2, w[0])
                 g_idx = w[0]
-            else:
-                g_idx = None
+            elif quantize == "gptq" and quant_method == "awq":
+                log_once(
+                    logger.info,
+                    "Converting AWQ weights to Exllama/GPTQ packing format, "
+                    "in order used with Exllama/GPTQ kernels.",
+                )
+                from text_generation_server.utils.awq.conversion_utils import (
+                    fast_awq_to_gptq,
+                )
+
+                qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
+                g_idx = torch.zeros(
+                    (qweight.shape[0] * 32 // bits),
+                    dtype=torch.int32,
+                    device=qweight.device,
+                )

             from text_generation_server.utils.layers import HAS_EXLLAMA

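Note (not part of the commit): the loop over w[1:] kept as context above is a consistency check that every column shard reports the same g_idx, since the concatenated columns must share one group mapping. A small standalone illustration with made-up shard values:

import torch

# Three hypothetical shards that all carry the same group indices.
w = [torch.arange(8, dtype=torch.int32) // 2 for _ in range(3)]
for w2 in w[1:]:
    torch.testing.assert_close(w2, w[0])  # raises if any shard disagrees
g_idx = w[0]
assert g_idx.tolist() == [0, 0, 1, 1, 2, 2, 3, 3]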
@@ -229,13 +249,6 @@ class Weights:
                 bits == 4 and HAS_EXLLAMA and quantize == "gptq" and not desc_act
             )

-            if quantize == "gptq" and quant_method == "awq":
-                log_once(
-                    logger.info,
-                    "Converting AWQ weights to Exllama/GPTQ format to be used with Exllama/GPTQ kernels",
-                )
-                qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
-
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
         else:
             w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
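Note (not part of the commit): the use_exllama expression kept as context above gates the Exllama kernel. Spelled out with hypothetical values:

# 4-bit, kernel available, GPTQ quantization, and no act-order -> Exllama path.
bits, HAS_EXLLAMA, quantize, desc_act = 4, True, "gptq", False
use_exllama = bits == 4 and HAS_EXLLAMA and quantize == "gptq" and not desc_act
assert use_exllama

# An act-order checkpoint (desc_act=True) falls back to the Triton path instead.
desc_act = True
assert not (bits == 4 and HAS_EXLLAMA and quantize == "gptq" and not desc_act)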
@@ -279,7 +292,7 @@ class Weights:

             if quant_method == "gptq":
                 g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
-            else:
+            elif quant_method == "awq":
                 g_idx = None

             if self.process_group.size() > 1:
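Note (not part of the commit): for the row-parallel case above, g_idx is sharded along dim 0 so each rank keeps the group indices of the input features it owns. A simplified stand-in for that slicing (not the repository's get_sharded implementation), with hypothetical world_size and rank:

import torch

world_size, rank, groupsize = 2, 0, 4  # hypothetical values
full_g_idx = torch.arange(16, dtype=torch.int32) // groupsize
shard = full_g_idx.chunk(world_size, dim=0)[rank]
assert shard.tolist() == [0, 0, 0, 0, 1, 1, 1, 1]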
@@ -324,9 +337,19 @@ class Weights:
             if quant_method == "awq":
                 log_once(
                     logger.info,
-                    "Converting AWQ weights to Exllama/GPTQ format to be used with Exllama/GPTQ kernels",
+                    "Converting AWQ weights to Exllama/GPTQ packing format, "
+                    "in order used with Exllama/GPTQ kernels.",
                 )
+                from text_generation_server.utils.awq.conversion_utils import (
+                    fast_awq_to_gptq,
+                )
+
                 qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
+                g_idx = torch.zeros(
+                    (qweight.shape[0] * 32 // bits),
+                    dtype=torch.int32,
+                    device=qweight.device,
+                )

             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
         elif quantize == "awq":
@@ -353,13 +376,14 @@ class Weights:
         try:
             bits = self.get_tensor("gptq_bits").item()
             groupsize = self.get_tensor("gptq_groupsize").item()
+            quant_method = "gptq"
             desc_act = False
         except (SafetensorError, RuntimeError) as e:
             try:
                 bits = self.gptq_bits
                 groupsize = self.gptq_groupsize
-                quant_method = self.quant_method
                 desc_act = getattr(self, "gptq_desc_act", False)
+                quant_method = getattr(self, "quant_method", "gptq")
             except Exception:
                 raise e

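Note (not part of the commit): the reordered fallback above means a checkpoint whose metadata never recorded a quant_method is treated as plain GPTQ instead of raising AttributeError. A minimal sketch of that getattr behaviour with a hypothetical stand-in object:

class _FakeWeights:  # hypothetical stand-in for a Weights instance
    gptq_bits = 4
    gptq_groupsize = 128
    # no quant_method attribute was ever set by a config file

w = _FakeWeights()
quant_method = getattr(w, "quant_method", "gptq")
assert quant_method == "gptq"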
@@ -378,8 +402,8 @@ class Weights:
                 data = json.load(f)
             self.gptq_bits = data["quantization_config"]["bits"]
             self.gptq_groupsize = data["quantization_config"]["group_size"]
-            self.gptq_desc_act = data["quantization_config"].get("desc_act", False)
             self.quant_method = data["quantization_config"]["quant_method"]
+            self.gptq_desc_act = data["quantization_config"]["desc_act"]
         except Exception:
             filename = "quantize_config.json"
             try:
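Note (not part of the commit): a hypothetical config.json fragment carrying the quantization_config fields read above; the concrete values are illustrative. desc_act must now be present explicitly, since the .get("desc_act", False) fallback was dropped and a missing key falls through to the next config file via the surrounding except.

import json

data = json.loads("""
{
  "quantization_config": {
    "bits": 4,
    "group_size": 128,
    "quant_method": "gptq",
    "desc_act": false
  }
}
""")

bits = data["quantization_config"]["bits"]
groupsize = data["quantization_config"]["group_size"]
quant_method = data["quantization_config"]["quant_method"]
desc_act = data["quantization_config"]["desc_act"]
assert (bits, groupsize, quant_method, desc_act) == (4, 128, "gptq", False)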
@@ -393,11 +417,11 @@ class Weights:
                     data = json.load(f)
                 self.gptq_bits = data["bits"]
                 self.gptq_groupsize = data["group_size"]
-                self.gptq_desc_act = data.get("desc_act", False)
                 if "version" in data and data["version"] == "GEMM":
                     self.quant_method = "awq"
                 else:
                     self.quant_method = "gptq"
+                self.gptq_desc_act = data["desc_act"]
             except Exception:
                 filename = "quant_config.json"
                 try:
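Note (not part of the commit): the "version" == "GEMM" check kept above is how the loader tells an AWQ-style quantize_config.json apart from a GPTQ one; as with config.json, desc_act is now read without a default. A sketch with a hypothetical config dict:

# Hypothetical quantize_config.json contents for an AWQ GEMM checkpoint.
data = {"bits": 4, "group_size": 128, "version": "GEMM", "desc_act": False}
if "version" in data and data["version"] == "GEMM":
    quant_method = "awq"
else:
    quant_method = "gptq"
assert quant_method == "awq"
assert data["desc_act"] is False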
@@ -411,10 +435,10 @@ class Weights:
                         data = json.load(f)
                     self.gptq_bits = data["w_bit"]
                     self.gptq_groupsize = data["q_group_size"]
-                    self.gptq_desc_act = data.get("desc_act", False)
                     if "version" in data and data["version"] == "GEMM":
                         self.quant_method = "awq"
                     else:
                         self.quant_method = "gptq"
+                    self.gptq_desc_act = data["desc_act"]
                 except Exception:
                     pass