log message

IlyasMoutawwakil 2024-02-05 09:26:47 +01:00 committed by Nicolas Patry
parent 76834c9989
commit 2629193efa


@@ -163,20 +163,17 @@ class Weights:
                 g_idx = self.get_tensor(f"{prefix}.g_idx")
             elif quantize == "gptq" and quant_method == "awq":
                 log_once(
-                    logger.info,
-                    "Converting AWQ weights to Exllama/GPTQ packing format, "
-                    "in order used with Exllama/GPTQ kernels.",
+                    logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
                 )
                 from text_generation_server.utils.awq.conversion_utils import (
                     fast_awq_to_gptq,
                 )
 
                 qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
-                g_idx = torch.zeros(
-                    (qweight.shape[0] * 32 // bits),
-                    dtype=torch.int32,
-                    device=qweight.device,
-                )
+                g_idx = (
+                    torch.arange(qweight.shape[0] * (32 // bits), device=qweight.device)
+                    // groupsize
+                ).to(dtype=torch.int32)
             else:
                 g_idx = None
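
A note on the g_idx change above (not part of the diff; the shapes and numbers below are hypothetical, for illustration): GPTQ packs `32 // bits` quantized values into each int32 along dim 0 of `qweight`, so the unpacked input dimension is `qweight.shape[0] * (32 // bits)`. `g_idx[i]` names the quantization group whose scale and zero-point apply to input feature `i`, which for sequential (non act-order) quantization is simply `i // groupsize`. The removed `torch.zeros` version assigned every feature to group 0, which is only correct when a single group spans the whole input dimension. A minimal sketch of the difference:

    import torch

    bits, groupsize = 4, 128
    # Hypothetical packed weight: 4096 input features, 8 nibbles per int32 -> 512 rows.
    qweight = torch.empty(512, 256, dtype=torch.int32)

    in_features = qweight.shape[0] * (32 // bits)  # 512 * 8 = 4096

    # Removed version: every feature assigned to group 0.
    g_idx_old = torch.zeros((in_features,), dtype=torch.int32)

    # New version: feature i belongs to group i // groupsize.
    g_idx_new = (torch.arange(in_features) // groupsize).to(dtype=torch.int32)

    assert g_idx_new[127] == 0   # last feature of the first group
    assert g_idx_new[128] == 1   # first feature of the second group
    assert g_idx_new[-1] == 31   # 4096 / 128 = 32 groups in total
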
@@ -230,20 +227,17 @@ class Weights:
                 g_idx = w[0]
             elif quantize == "gptq" and quant_method == "awq":
                 log_once(
-                    logger.info,
-                    "Converting AWQ weights to Exllama/GPTQ packing format, "
-                    "in order used with Exllama/GPTQ kernels.",
+                    logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
                 )
                 from text_generation_server.utils.awq.conversion_utils import (
                     fast_awq_to_gptq,
                 )
 
                 qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
-                g_idx = torch.zeros(
-                    (qweight.shape[0] * 32 // bits),
-                    dtype=torch.int32,
-                    device=qweight.device,
-                )
+                g_idx = (
+                    torch.arange(qweight.shape[0] * (32 // bits), device=qweight.device)
+                    // groupsize
+                ).to(dtype=torch.int32)
             else:
                 g_idx = None
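
`fast_awq_to_gptq` performs the actual repacking in each of these hunks. As a rough sketch of the idea rather than the library's implementation (function names here are illustrative): both formats pack 4-bit integers into int32 words, but AWQ interleaves the nibbles within each word in the order [0, 2, 4, 6, 1, 3, 5, 7] and packs weights along columns, while GPTQ packs nibbles sequentially, packs weights along rows, and stores zero-points minus one (the kernels add the one back). Conversion therefore amounts to unpack, reorder, adjust zeros, repack:

    import torch

    # Nibble slot k of an AWQ word holds logical value AWQ_ORDER[k];
    # REVERSE_AWQ_ORDER is the inverse permutation, used to read logical order.
    AWQ_ORDER = [0, 2, 4, 6, 1, 3, 5, 7]
    REVERSE_AWQ_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]

    def unpack_awq_cols(packed: torch.Tensor) -> torch.Tensor:
        # (m, n) int32 -> (m, n * 8) 4-bit values in logical order.
        shifts = 4 * torch.tensor(REVERSE_AWQ_ORDER, dtype=torch.int32)
        return ((packed.unsqueeze(-1) >> shifts) & 0xF).flatten(-2)

    def pack_rows(vals: torch.Tensor) -> torch.Tensor:
        # Sequential packing along dim 0 (GPTQ weights): (m, n) -> (m // 8, n).
        vals = (vals & 0xF).reshape(vals.shape[0] // 8, 8, vals.shape[1])
        shifts = torch.arange(0, 32, 4, dtype=torch.int32).view(1, 8, 1)
        return (vals << shifts).sum(dim=1, dtype=torch.int32)

    def pack_cols(vals: torch.Tensor) -> torch.Tensor:
        # Sequential packing along dim 1 (GPTQ zeros): (m, n) -> (m, n // 8).
        vals = (vals & 0xF).reshape(vals.shape[0], vals.shape[1] // 8, 8)
        shifts = torch.arange(0, 32, 4, dtype=torch.int32).view(1, 1, 8)
        return (vals << shifts).sum(dim=-1, dtype=torch.int32)

    def awq_to_gptq_sketch(qweight, qzeros):
        iweight = unpack_awq_cols(qweight)    # (in_features, out_features)
        izeros = unpack_awq_cols(qzeros) - 1  # GPTQ's zero - 1 convention
        return pack_rows(iweight), pack_cols(izeros)
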
@@ -340,20 +334,17 @@ class Weights:
             if quant_method == "awq":
                 log_once(
-                    logger.info,
-                    "Converting AWQ weights to Exllama/GPTQ packing format, "
-                    "in order used with Exllama/GPTQ kernels.",
+                    logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
                 )
                 from text_generation_server.utils.awq.conversion_utils import (
                     fast_awq_to_gptq,
                 )
 
                 qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
-                g_idx = torch.zeros(
-                    (qweight.shape[0] * 32 // bits),
-                    dtype=torch.int32,
-                    device=qweight.device,
-                )
+                g_idx = (
+                    torch.arange(qweight.shape[0] * (32 // bits), device=qweight.device)
+                    // groupsize
+                ).to(dtype=torch.int32)
 
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
         elif quantize == "awq":
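
For completeness, a hedged sketch of where `g_idx` enters GPTQ dequantization (names and shapes hypothetical, not TGI's kernels): each input feature is dequantized with the scale and zero-point of its own group, which is why the removed all-zeros `g_idx` silently applied group 0's parameters to every row whenever more than one group existed.

    import torch

    def dequantize_sketch(iweight, izeros, scales, g_idx):
        # iweight: (in_features, out_features) unpacked 4-bit values
        # izeros:  (n_groups, out_features) stored zero-points (zero - 1)
        # scales:  (n_groups, out_features) per-group scales
        # g_idx:   (in_features,) group index of each input feature
        zeros = izeros[g_idx] + 1  # per-row zero-point, GPTQ's +1 convention
        scale = scales[g_idx]      # per-row scale
        return (iweight - zeros) * scale
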