have a single gptq quantization type

Felix Marty 2023-07-12 15:43:20 +00:00
parent a6e387404d
commit 4462854e1b
7 changed files with 63 additions and 71 deletions

View File

@ -19,7 +19,6 @@ mod env_runtime;
enum Quantization {
Bitsandbytes,
Gptq,
Gptq_cuda,
}
impl std::fmt::Display for Quantization {
@ -32,9 +31,6 @@ impl std::fmt::Display for Quantization {
Quantization::Gptq => {
write!(f, "gptq")
}
Quantization::Gptq_cuda => {
write!(f, "gptq-cuda")
}
}
}
}

View File

@ -14,7 +14,6 @@ app = typer.Typer()
class Quantization(str, Enum):
bitsandbytes = "bitsandbytes"
gptq = "gptq"
gptq_cuda = "gptq-cuda"
class Dtype(str, Enum):

View File

@ -286,7 +286,7 @@ def get_model(
if sharded:
raise ValueError("sharded is not supported for AutoModel")
if quantize in ["gptq", "gptq-cuda"]:
if quantize == "gptq":
raise ValueError(
"gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
)

View File

@ -27,7 +27,7 @@ from custom_kernels.exllama import prepare_buffers, set_tuning_params
def load_multi_mqa(
config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
):
if config.quantize in ["gptq", "gptq-cuda"]:
if config.quantize == "gptq":
layer = _load_multi_mqa_gptq(
config, prefix, weights, bias, head_size, num_heads, hidden_size
)
@ -79,7 +79,10 @@ def _load_multi_mqa_gptq(
bits = weights.get_tensor("gptq_bits").item()
groupsize = weights.get_tensor("gptq_groupsize").item()
weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
qweight = qweight.to(weights.device)
qzeros = qzeros.to(weights.device)
scales = scales.to(weights.device)
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
if bias:
slice_ = weights._get_slice(f"{prefix}.c_attn.bias")
@ -93,7 +96,9 @@ def _load_multi_mqa_gptq(
kv_tensor = slice_[-2 * head_size :]
bias = torch.cat([q_tensor, kv_tensor], dim=0)
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize, device=weights.device))
bias = bias.to(weights.device)
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
else:
raise NotImplementedError("Gptq loading with santacoder is not implemented")
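With the separate gptq-cuda type gone, the kernel choice travels inside the weight itself: the loader packs a 7-tuple whose last element is the `use_triton_kernel` flag consumed by `get_linear`, and it moves qweight/qzeros/scales (and the bias) to `weights.device` on its own, since `get_linear` no longer accepts a `device` argument. A minimal sketch of that convention (helper name and signature are mine, not the repository API):

    def pack_gptq_weight(qweight, qzeros, scales, g_idx, bits, groupsize,
                         use_triton_kernel=False, device=None):
        # Sketch only: column-parallel loads pass use_triton_kernel=False
        # (exllama kernel); row-parallel loads may flip it to True when
        # act-order is detected (see get_multi_weights_row further below).
        if device is not None:
            # The loader moves the packed tensors itself now that get_linear
            # no longer takes a device argument; g_idx is left alone because
            # Ex4bitLinear keeps it on CPU.
            qweight, qzeros, scales = (t.to(device) for t in (qweight, qzeros, scales))
        return (qweight, qzeros, scales, g_idx, bits, groupsize, use_triton_kernel)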
@ -166,7 +171,7 @@ def _load_multi_mqa(
assert list(bias.shape) == [
(num_heads + 2) * head_size
], f"{weight.shape} != {[(num_heads + 2) * head_size]}"
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize, device=weights.device))
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
def load_col(config, prefix: str, weights, bias: bool):
@ -181,23 +186,11 @@ def load_col(config, prefix: str, weights, bias: bool):
bias = weights.get_sharded(f"{prefix}.bias", dim=0)
else:
bias = None
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize, device=weights.device))
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
def load_row(config, prefix: str, weights, bias: bool):
quantize = config.quantize
if quantize == "gptq-cuda" and weights.process_group.size() > 1:
g_idx = weights.get_tensor(f"{prefix}.g_idx")
groupsize = weights.get_tensor("gptq_groupsize").item()
act_order = True
if g_idx is not None:
if torch.equal(g_idx.cpu(), torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32)) or (g_idx == 0).all():
act_order = False
else:
# Exllama implementation does not support row tensor parallelism with act-order, as
# it would require reordering input activations that are split onto several GPUs
quantize = "gptq"
if config.transpose:
weight = weights.get_sharded(f"{prefix}.weight", dim=0).T
@ -209,12 +202,9 @@ def load_row(config, prefix: str, weights, bias: bool):
bias = weights.get_tensor(f"{prefix}.bias")
else:
bias = None
if quantize == "gptq-cuda" and not act_order:
weight[3] = None # remove g_idx to indicate to exllama that act-order is not used
return TensorParallelRowLinear(
get_linear(weight, bias, quantize, device=weights.device), process_group=weights.process_group
get_linear(weight, bias, quantize), process_group=weights.process_group
)
@ -480,7 +470,7 @@ class FlashSantacoderForCausalLM(nn.Module):
# Buffers need to be persistent to avoid any bug.
self.buffers = {}
if config.quantize == "gptq-cuda":
if config.quantize == "gptq":
max_dq_buffer_size = 0
for name, submodule in self.named_modules():
if isinstance(submodule, (TensorParallelColumnLinear, TensorParallelRowLinear)) and isinstance(submodule.linear, Ex4bitLinear):

View File

@ -392,29 +392,29 @@ def ext_q4_matmul(x, q4, q4_width):
class Ex4bitLinear:
"""Linear layer implementation with per-group 4-bit quantization of the weights"""
def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize, device):
def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
assert bits == 4
self.device = device
self.qweight = qweight.to(device)
self.qzeros = qzeros.to(device)
self.scales = scales.to(device)
self.device = qweight.device
self.qweight = qweight
self.qzeros = qzeros
self.scales = scales
self.g_idx = g_idx.cpu() if g_idx is not None else None
self.bias = bias.to(device) if bias is not None else None
if self.g_idx is not None and (self.g_idx == 0).all():
self.bias = bias if bias is not None else None
if self.g_idx is not None and ((self.g_idx == 0).all() or torch.equal(g_idx.cpu(), torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32))):
self.empty_g_idx = True
self.g_idx = None
assert device.type == "cuda"
assert device.index is not None
assert self.device.type == "cuda"
assert self.device.index is not None
self.q4 = ext_make_q4(
self.qweight,
self.qzeros,
self.scales,
self.g_idx,
device.index
self.device.index
)
self.height = qweight.shape[0] * 8
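The condition above widens the old `(g_idx == 0).all()` test: a g_idx equal to the trivial i // groupsize pattern (what GPTQ produces without act-order) can also be dropped, letting exllama skip the reordering entirely. A self-contained sketch of that test (the helper name is mine):

    import torch

    def g_idx_is_trivial(g_idx: torch.Tensor, groupsize: int) -> bool:
        # Trivial means row i simply belongs to group i // groupsize, i.e. the
        # checkpoint was quantized without act-order.
        if bool((g_idx == 0).all()):
            return True
        expected = torch.tensor(
            [i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32
        )
        return torch.equal(g_idx.cpu(), expected)

    # Example: 256 rows, groupsize 128 -> g_idx is 0,...,0,1,...,1 and can be dropped.
    assert g_idx_is_trivial(torch.arange(256, dtype=torch.int32) // 128, groupsize=128)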

View File

@ -18,7 +18,7 @@ from accelerate import init_empty_weights
from text_generation_server.utils.gptq.quant_linear import QuantLinear, Ex4bitLinear
from typing import Optional
from loguru import logger
# Monkey patching
@classmethod
def load_layer_norm(cls, prefix, weights, eps):
@ -131,7 +131,7 @@ class Linear8bitLt(nn.Module):
return out
def get_linear(weight, bias, quantize, device = None):
def get_linear(weight, bias, quantize):
if quantize is None:
linear = FastLinear(weight, bias)
elif quantize == "bitsandbytes":
@ -145,30 +145,24 @@ def get_linear(weight, bias, quantize, device = None):
linear.bias = nn.Parameter(bias)
elif quantize == "gptq":
try:
qweight, qzeros, scales, g_idx, bits, groupsize = weight
qweight, qzeros, scales, g_idx, bits, groupsize, use_triton_kernel = weight
except Exception:
raise NotImplementedError(
f"The passed weight is not `gptq` compatible, loader needs to be updated."
)
linear = QuantLinear(
qweight,
qzeros,
scales,
g_idx,
bias,
bits,
groupsize,
)
elif quantize == "gptq-cuda":
try:
qweight, qzeros, scales, g_idx, bits, groupsize = weight
except Exception:
raise NotImplementedError(
f"The passed weight is not `gptq` compatible, loader needs to be updated."
if use_triton_kernel:
linear = QuantLinear(
qweight,
qzeros,
scales,
g_idx,
bias,
bits,
groupsize,
)
linear = Ex4bitLinear(qweight, qzeros, scales, g_idx, bias, bits, groupsize, device)
else:
linear = Ex4bitLinear(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
else:
raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
return linear
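`get_linear` now keys the kernel on that 7th tuple element instead of on a second quantize string: use_triton_kernel=True builds the triton QuantLinear, anything else the exllama Ex4bitLinear. A condensed, standalone restatement of the branch (function name and return values are illustrative only):

    def pick_gptq_kernel(weight) -> str:
        # Mirrors the dispatch in get_linear: a malformed tuple means the
        # loader was not updated for the new 7-element format.
        try:
            qweight, qzeros, scales, g_idx, bits, groupsize, use_triton_kernel = weight
        except ValueError:
            raise NotImplementedError(
                "The passed weight is not `gptq` compatible, loader needs to be updated."
            )
        return "QuantLinear (triton)" if use_triton_kernel else "Ex4bitLinear (exllama)"

    # Dummy placeholders stand in for the packed tensors:
    assert pick_gptq_kernel((None,) * 6 + (False,)) == "Ex4bitLinear (exllama)"
    assert pick_gptq_kernel((None,) * 6 + (True,)) == "QuantLinear (triton)"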
@ -193,12 +187,12 @@ class TensorParallelHead(SuperLayer):
weight = weights.get_sharded(f"{prefix}.weight", dim=0)
# GPTQ doesn't quantize heads (nor embeddings)
if config.quantize in ["gptq", "gptq-cuda"]:
if config.quantize == "gptq":
quantize = None
else:
quantize = config.quantize
return TensorParallelHead(
get_linear(weight, bias=None, quantize=quantize, device=weights.device),
get_linear(weight, bias=None, quantize=quantize),
process_group=weights.process_group,
)
@ -254,7 +248,7 @@ class TensorParallelColumnLinear(SuperLayer):
bias = torch.cat(b, dim=dim)
else:
bias = None
linear = get_linear(weight, bias, config.quantize, device=weights.device)
linear = get_linear(weight, bias, config.quantize)
return cls(linear)
@ -273,7 +267,7 @@ class TensorParallelRowLinear(SuperLayer):
else:
bias = None
return cls(
get_linear(weight, bias, config.quantize, device=weights.device),
get_linear(weight, bias, config.quantize),
process_group=weights.process_group,
)

View File

@ -2,7 +2,7 @@ from pathlib import Path
from typing import List, Dict, Optional
from safetensors import safe_open
import torch
from loguru import logger
class Weights:
def __init__(
@ -99,7 +99,7 @@ class Weights:
return tensor
def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
if quantize in ["gptq", "gptq-cuda"]:
if quantize == "gptq":
try:
qweight = torch.cat([self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1)
except RuntimeError:
@ -114,20 +114,33 @@ class Weights:
bits = self.get_tensor("gptq_bits").item()
groupsize = self.get_tensor("gptq_groupsize").item()
weight = [qweight, qzeros, scales, g_idx, bits, groupsize]
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
else:
w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
weight = torch.cat(w, dim=dim)
return weight
def get_multi_weights_row(self, prefix: str, quantize: str):
if quantize in ["gptq", "gptq-cuda"]:
def get_multi_weights_row(self, prefix: str, quantize: str):
if quantize == "gptq":
use_triton_kernel = False
if self.process_group.size() > 1:
g_idx = self.get_tensor(f"{prefix}.g_idx")
groupsize = self.get_tensor("gptq_groupsize").item()
if g_idx is not None:
if not torch.equal(g_idx.cpu(), torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32)) and not (g_idx == 0).all():
# Exllama implementation does not support row tensor parallelism with act-order, as
# it would require reordering input activations that are split onto several GPUs
use_triton_kernel = True
try:
qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
except RuntimeError:
raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`")
if quantize == "gptq":
if use_triton_kernel:
# The triton kernel reorders the scales/zero points instead of the weight/activation.
# Thus, each rank needs the full qzeros/scales.
qzeros = self.get_tensor(f"{prefix}.qzeros")
scales = self.get_tensor(f"{prefix}.scales")
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
@ -146,7 +159,7 @@ class Weights:
bits = self.get_tensor("gptq_bits").item()
groupsize = self.get_tensor("gptq_groupsize").item()
weight = [qweight, qzeros, scales, g_idx, bits, groupsize]
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_triton_kernel)
else:
weight = self.get_sharded(f"{prefix}.weight", dim=1)
return weight
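`get_multi_weights_row` is where the flag is actually set: with more than one shard and a non-trivial g_idx (act-order), exllama cannot do row tensor parallelism, so the triton kernel is selected and each rank keeps the full qzeros/scales, since that kernel reorders scales/zero points rather than the split activations. A hedged sketch of the decision (helper name is mine; it reuses the same triviality test as above):

    import torch

    def choose_row_kernel(g_idx, groupsize: int, world_size: int) -> str:
        # Single shard, or no group index at all: exllama handles the sharded rows.
        if world_size <= 1 or g_idx is None:
            return "exllama"
        trivial = torch.tensor(
            [i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32
        )
        if bool((g_idx == 0).all()) or torch.equal(g_idx.cpu(), trivial):
            return "exllama"
        # Act-order with tensor parallelism: exllama would have to reorder input
        # activations that are already split onto several GPUs, so fall back to
        # the triton QuantLinear.
        return "triton"

    # Example: an interleaved (act-order-like) group index on 2 shards picks triton.
    assert choose_row_kernel(torch.arange(256, dtype=torch.int32) % 2, 128, world_size=2) == "triton"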