generate g_idx only for triton kernel

IlyasMoutawwakil 2024-02-08 16:05:09 +01:00
parent e29fb799cb
commit bc157af9b0


@@ -220,6 +220,12 @@ class Weights:
             bits, groupsize, desc_act, quant_method = self._get_gptq_params()
+
+            from text_generation_server.utils.layers import HAS_EXLLAMA
+
+            use_exllama = (
+                bits == 4 and HAS_EXLLAMA and quantize == "gptq" and not desc_act
+            )
             if quantize == "gptq" and quant_method == "gptq":
                 w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
                 for w2 in w[1:]:
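
The new use_exllama flag decides which GPTQ kernel the weights are being prepared for. A runnable sketch of the same condition, with stand-in values (in the diff, HAS_EXLLAMA comes from text_generation_server.utils.layers and the other values from _get_gptq_params):

# Stand-in values for illustration only.
HAS_EXLLAMA = False  # assumption: exllama kernels unavailable in this build
bits, groupsize, desc_act = 4, 128, False
quantize = "gptq"

# Same condition as the diff: exllama only serves 4-bit GPTQ checkpoints
# without act-order (desc_act); everything else falls back to triton.
use_exllama = bits == 4 and HAS_EXLLAMA and quantize == "gptq" and not desc_act
print(use_exllama)  # False here, so g_idx will be generated for triton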
@ -234,19 +240,18 @@ class Weights:
) )
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
g_idx = ( if use_exllama:
torch.arange(qweight.shape[0] * (32 // bits), device=qweight.device) g_idx = None
// groupsize else:
).to(dtype=torch.int32) g_idx = (
torch.arange(
qweight.shape[0] * (32 // bits), device=qweight.device
)
// groupsize
).to(dtype=torch.int32)
else: else:
g_idx = None g_idx = None
from text_generation_server.utils.layers import HAS_EXLLAMA
use_exllama = (
bits == 4 and HAS_EXLLAMA and quantize == "gptq" and not desc_act
)
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama) weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
else: else:
w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes] w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
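
The expression both versions share is worth unpacking: qweight packs 32 // bits quantized values into each int32 row, so qweight.shape[0] * (32 // bits) recovers the unpacked input dimension, and integer division by groupsize maps each input row to its quantization group. A self-contained sketch with hypothetical sizes:

import torch

# Hypothetical 4-bit GPTQ layer: 4096 input features, groups of 128.
bits, groupsize, in_features = 4, 128, 4096
packed_rows = in_features * bits // 32              # 512 int32 rows
qweight = torch.zeros(packed_rows, 11008, dtype=torch.int32)

# Identical construction to the diff's triton branch.
g_idx = (
    torch.arange(qweight.shape[0] * (32 // bits), device=qweight.device)
    // groupsize
).to(dtype=torch.int32)

assert g_idx.shape == (in_features,)                # one entry per input row
assert g_idx[0] == 0 and g_idx[-1] == in_features // groupsize - 1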
@@ -341,10 +346,15 @@ class Weights:
                 )
                 qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)

-                g_idx = (
-                    torch.arange(qweight.shape[0] * (32 // bits), device=qweight.device)
-                    // groupsize
-                ).to(dtype=torch.int32)
+                if use_exllama:
+                    g_idx = None
+                else:
+                    g_idx = (
+                        torch.arange(
+                            qweight.shape[0] * (32 // bits), device=qweight.device
+                        )
+                        // groupsize
+                    ).to(dtype=torch.int32)

             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
         elif quantize == "awq":
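
The row-sharded path above gets the same treatment as the column path, so after this commit both call sites satisfy one invariant: g_idx is materialized exactly when the triton kernel will consume it, and the exllama path gets None. A toy check with placeholder tensors (illustrative shapes, not the TGI API):

import torch

bits, groupsize, use_exllama = 4, 128, True
qweight = torch.zeros(512, 11008, dtype=torch.int32)   # packed 4-bit weights
qzeros = torch.zeros(32, 1376, dtype=torch.int32)      # packed zero points
scales = torch.ones(32, 11008, dtype=torch.float16)

g_idx = (
    None
    if use_exllama
    else (torch.arange(qweight.shape[0] * (32 // bits)) // groupsize).to(
        dtype=torch.int32
    )
)
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)

# Invariant introduced by this commit: no g_idx on the exllama path.
assert (weight[3] is None) == use_exllama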
@@ -371,8 +381,8 @@ class Weights:
         try:
             bits = self.get_tensor("gptq_bits").item()
             groupsize = self.get_tensor("gptq_groupsize").item()
-            quant_method = "gptq"
             desc_act = False
+            quant_method = "gptq"
         except (SafetensorError, RuntimeError) as e:
             try:
                 bits = self.gptq_bits
@@ -397,8 +407,8 @@ class Weights:
                 data = json.load(f)
             self.gptq_bits = data["quantization_config"]["bits"]
             self.gptq_groupsize = data["quantization_config"]["group_size"]
-            self.quant_method = data["quantization_config"]["quant_method"]
             self.gptq_desc_act = data["quantization_config"]["desc_act"]
+            self.quant_method = data["quantization_config"]["quant_method"]
         except Exception:
             filename = "quantize_config.json"
             try:
@@ -412,11 +422,9 @@ class Weights:
                     data = json.load(f)
                 self.gptq_bits = data["bits"]
                 self.gptq_groupsize = data["group_size"]
+                self.gptq_desc_act = data["desc_act"]
                 if "version" in data and data["version"] == "GEMM":
                     self.quant_method = "awq"
-                else:
-                    self.quant_method = "gptq"
-                self.gptq_desc_act = data["desc_act"]
             except Exception:
                 filename = "quant_config.json"
                 try:
@@ -430,10 +438,8 @@ class Weights:
                         data = json.load(f)
                     self.gptq_bits = data["w_bit"]
                     self.gptq_groupsize = data["q_group_size"]
+                    self.gptq_desc_act = data["desc_act"]
                     if "version" in data and data["version"] == "GEMM":
                         self.quant_method = "awq"
-                    else:
-                        self.quant_method = "gptq"
-                    self.gptq_desc_act = data["desc_act"]
                 except Exception:
                     pass
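
The reorderings in the config readers look cosmetic but plausibly are not: inside a try block, assignments made before a raised KeyError stick, while everything after it never runs. Reading desc_act before setting quant_method means a config that lacks desc_act aborts the branch before quant_method is mutated, leaving clean state for the next fallback file. A minimal illustration with a synthetic config (not a real checkpoint):

import json

# Synthetic quantize_config.json without "desc_act", as some AWQ
# checkpoints ship.
data = json.loads('{"bits": 4, "group_size": 128, "version": "GEMM"}')

state = {}
try:
    state["bits"] = data["bits"]
    state["desc_act"] = data["desc_act"]  # KeyError raised here...
    state["quant_method"] = "awq"         # ...so this never runs
except KeyError:
    pass                                  # fall through to the next config file

# Assignments before the failure stuck; nothing after it ran.
assert "bits" in state and "quant_method" not in state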