Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-07-24 00:30:18 +00:00

Commit: have a single gptq quantization type

This commit collapses the separate "gptq" and "gptq-cuda" quantization types into a single "gptq" type. The choice between the triton kernel and the exllama (cuda) kernel is now made internally, through a use_triton_kernel flag carried as the last element of the gptq weight tuple.

parent a6e387404d
commit 4462854e1b
@@ -19,7 +19,6 @@ mod env_runtime;
 enum Quantization {
     Bitsandbytes,
     Gptq,
-    Gptq_cuda,
 }
 
 impl std::fmt::Display for Quantization {
@@ -32,9 +31,6 @@ impl std::fmt::Display for Quantization {
             Quantization::Gptq => {
                 write!(f, "gptq")
             }
-            Quantization::Gptq_cuda => {
-                write!(f, "gptq-cuda")
-            }
         }
     }
 }
@@ -14,7 +14,6 @@ app = typer.Typer()
 class Quantization(str, Enum):
     bitsandbytes = "bitsandbytes"
     gptq = "gptq"
-    gptq_cuda = "gptq-cuda"
 
 
 class Dtype(str, Enum):
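After this hunk the server CLI accepts only two quantization values. A minimal sketch of the resulting enum (reconstructed from the hunk above, not copied verbatim from the repository), showing that "gptq-cuda" is no longer a valid --quantize choice:

```python
from enum import Enum


class Quantization(str, Enum):
    bitsandbytes = "bitsandbytes"
    gptq = "gptq"


# "gptq-cuda" is gone: the value no longer round-trips through the enum,
# so typer would reject `--quantize gptq-cuda` at the CLI boundary.
assert [q.value for q in Quantization] == ["bitsandbytes", "gptq"]
assert "gptq-cuda" not in {q.value for q in Quantization}
```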
@@ -286,7 +286,7 @@ def get_model(
 
     if sharded:
         raise ValueError("sharded is not supported for AutoModel")
-    if quantize in ["gptq", "gptq-cuda"]:
+    if quantize == "gptq":
         raise ValueError(
             "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
         )
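For illustration, a trimmed-down, hypothetical version of the AutoModel guard above; check_automodel_args is not a function in the repository, it just isolates the two checks:

```python
from typing import Optional


def check_automodel_args(sharded: bool, quantize: Optional[str]) -> None:
    # Same two guards as in get_model's AutoModel fallback above.
    if sharded:
        raise ValueError("sharded is not supported for AutoModel")
    if quantize == "gptq":
        raise ValueError(
            "gptq quantization is not supported for AutoModel, you can try to quantize it "
            "with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
        )


check_automodel_args(sharded=False, quantize=None)            # fine
check_automodel_args(sharded=False, quantize="bitsandbytes")  # also fine
```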
@@ -27,7 +27,7 @@ from custom_kernels.exllama import prepare_buffers, set_tuning_params
 def load_multi_mqa(
     config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
 ):
-    if config.quantize in ["gptq", "gptq-cuda"]:
+    if config.quantize == "gptq":
         layer = _load_multi_mqa_gptq(
             config, prefix, weights, bias, head_size, num_heads, hidden_size
         )
@@ -79,7 +79,10 @@ def _load_multi_mqa_gptq(
         bits = weights.get_tensor("gptq_bits").item()
         groupsize = weights.get_tensor("gptq_groupsize").item()
 
-        weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
+        qweight = qweight.to(weights.device)
+        qzeros = qzeros.to(weights.device)
+        scales = scales.to(weights.device)
+        weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
 
         if bias:
             slice_ = weights._get_slice(f"{prefix}.c_attn.bias")
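The gptq weight is now handed to get_linear as a 7-element tuple whose last element selects the kernel; False here means the exllama path. A small, self-contained illustration of that layout (the pack_gptq_weight helper and the dummy tensor shapes are mine, not part of the repository):

```python
import torch


def pack_gptq_weight(qweight, qzeros, scales, g_idx, bits, groupsize, use_triton_kernel=False):
    # Layout used by the loaders above:
    # (qweight, qzeros, scales, g_idx, bits, groupsize, use_triton_kernel)
    return (qweight, qzeros, scales, g_idx, bits, groupsize, use_triton_kernel)


weight = pack_gptq_weight(
    qweight=torch.zeros(64, 8, dtype=torch.int32),
    qzeros=torch.zeros(4, 8, dtype=torch.int32),
    scales=torch.ones(4, 64),
    g_idx=torch.arange(512, dtype=torch.int32) // 128,
    bits=4,
    groupsize=128,
)
*_, use_triton_kernel = weight
assert use_triton_kernel is False  # column-parallel loads always take the exllama kernel
```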
@@ -93,7 +96,9 @@ def _load_multi_mqa_gptq(
             kv_tensor = slice_[-2 * head_size :]
             bias = torch.cat([q_tensor, kv_tensor], dim=0)
 
-        return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize, device=weights.device))
+            bias = bias.to(weights.device)
+
+        return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
     else:
         raise NotImplementedError("Gptq loading with santacoder is not implemented")
 
@@ -166,7 +171,7 @@ def _load_multi_mqa(
         assert list(bias.shape) == [
             (num_heads + 2) * head_size
         ], f"{weight.shape} != {[(num_heads + 2) * head_size]}"
-    return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize, device=weights.device))
+    return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
 
 
 def load_col(config, prefix: str, weights, bias: bool):
@@ -181,23 +186,11 @@ def load_col(config, prefix: str, weights, bias: bool):
         bias = weights.get_sharded(f"{prefix}.bias", dim=0)
     else:
         bias = None
-    return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize, device=weights.device))
+    return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
 
 
 def load_row(config, prefix: str, weights, bias: bool):
     quantize = config.quantize
-    if quantize == "gptq-cuda" and weights.process_group.size() > 1:
-        g_idx = weights.get_tensor(f"{prefix}.g_idx")
-        groupsize = weights.get_tensor("gptq_groupsize").item()
-
-        act_order = True
-        if g_idx is not None:
-            if torch.equal(g_idx.cpu(), torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32)) or (g_idx == 0).all():
-                act_order = False
-            else:
-                # Exllama implementation does not support row tensor parallelism with act-order, as
-                # it would require to reorder input activations that are split unto several GPUs
-                quantize = "gptq"
 
     if config.transpose:
         weight = weights.get_sharded(f"{prefix}.weight", dim=0).T
@@ -209,12 +202,9 @@ def load_row(config, prefix: str, weights, bias: bool):
         bias = weights.get_tensor(f"{prefix}.bias")
     else:
         bias = None
 
-    if quantize == "gptq-cuda" and not act_order:
-        weight[3] = None # remove g_idx to indicate to exllama that act-order is not used
-
     return TensorParallelRowLinear(
-        get_linear(weight, bias, quantize, device=weights.device), process_group=weights.process_group
+        get_linear(weight, bias, quantize), process_group=weights.process_group
     )
 
 
@@ -480,7 +470,7 @@ class FlashSantacoderForCausalLM(nn.Module):
 
         # Buffers need to be persistent to avoid any bug.
         self.buffers = {}
-        if config.quantize == "gptq-cuda":
+        if config.quantize == "gptq":
             max_dq_buffer_size = 0
             for name, submodule in self.named_modules():
                 if isinstance(submodule, (TensorParallelColumnLinear, TensorParallelRowLinear)) and isinstance(submodule.linear, Ex4bitLinear):
@@ -392,29 +392,29 @@ def ext_q4_matmul(x, q4, q4_width):
 
 class Ex4bitLinear:
     """Linear layer implementation with per-group 4-bit quantization of the weights"""
-    def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize, device):
+    def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
         assert bits == 4
 
-        self.device = device
-        self.qweight = qweight.to(device)
-        self.qzeros = qzeros.to(device)
-        self.scales = scales.to(device)
+        self.device = qweight.device
+        self.qweight = qweight
+        self.qzeros = qzeros
+        self.scales = scales
         self.g_idx = g_idx.cpu() if g_idx is not None else None
-        self.bias = bias.to(device) if bias is not None else None
+        self.bias = bias if bias is not None else None
 
-        if self.g_idx is not None and (self.g_idx == 0).all():
+        if self.g_idx is not None and ((self.g_idx == 0).all() or torch.equal(g_idx.cpu(), torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32))):
             self.empty_g_idx = True
             self.g_idx = None
 
-        assert device.type == "cuda"
-        assert device.index is not None
+        assert self.device.type == "cuda"
+        assert self.device.index is not None
 
         self.q4 = ext_make_q4(
             self.qweight,
             self.qzeros,
             self.scales,
             self.g_idx,
-            device.index
+            self.device.index
         )
 
         self.height = qweight.shape[0] * 8
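Two behavioural changes in this hunk: Ex4bitLinear now infers its device from qweight instead of taking a device argument, and a g_idx that is all zeros or exactly the sequential [i // groupsize] ramp is treated as "no act-order" so it can be dropped before calling the exllama kernel. A standalone sketch of just that detection, using plain torch (the helper name is mine):

```python
import torch


def act_order_was_used(g_idx: torch.Tensor, groupsize: int) -> bool:
    # A group index that is all zeros or exactly the sequential ramp
    # [i // groupsize] carries no reordering information, which is the
    # condition checked above before setting empty_g_idx and dropping g_idx.
    ramp = torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32)
    trivial = bool((g_idx == 0).all()) or torch.equal(g_idx.cpu(), ramp)
    return not trivial


assert not act_order_was_used(torch.arange(8, dtype=torch.int32) // 4, groupsize=4)
assert act_order_was_used(torch.tensor([1, 0, 1, 0], dtype=torch.int32), groupsize=2)
```

Inferring the device from qweight works because the loaders above now move qweight, qzeros, and scales to weights.device before building the weight tuple.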
@@ -18,7 +18,7 @@ from accelerate import init_empty_weights
 from text_generation_server.utils.gptq.quant_linear import QuantLinear, Ex4bitLinear
 
 from typing import Optional
-
+from loguru import logger
 # Monkey patching
 @classmethod
 def load_layer_norm(cls, prefix, weights, eps):
@@ -131,7 +131,7 @@ class Linear8bitLt(nn.Module):
         return out
 
 
-def get_linear(weight, bias, quantize, device = None):
+def get_linear(weight, bias, quantize):
     if quantize is None:
         linear = FastLinear(weight, bias)
     elif quantize == "bitsandbytes":
@@ -145,30 +145,24 @@ def get_linear(weight, bias, quantize, device = None):
         linear.bias = nn.Parameter(bias)
     elif quantize == "gptq":
         try:
-            qweight, qzeros, scales, g_idx, bits, groupsize = weight
+            qweight, qzeros, scales, g_idx, bits, groupsize, use_triton_kernel = weight
         except Exception:
             raise NotImplementedError(
                 f"The passed weight is not `gptq` compatible, loader needs to be updated."
             )
 
-        linear = QuantLinear(
-            qweight,
-            qzeros,
-            scales,
-            g_idx,
-            bias,
-            bits,
-            groupsize,
-        )
-    elif quantize == "gptq-cuda":
-        try:
-            qweight, qzeros, scales, g_idx, bits, groupsize = weight
-        except Exception:
-            raise NotImplementedError(
-                f"The passed weight is not `gptq` compatible, loader needs to be updated."
-            )
-
-        linear = Ex4bitLinear(qweight, qzeros, scales, g_idx, bias, bits, groupsize, device)
+        if use_triton_kernel:
+            linear = QuantLinear(
+                qweight,
+                qzeros,
+                scales,
+                g_idx,
+                bias,
+                bits,
+                groupsize,
+            )
+        else:
+            linear = Ex4bitLinear(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
     else:
         raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
     return linear
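get_linear no longer needs a device argument or a separate "gptq-cuda" branch: the use_triton_kernel flag carried in the weight tuple picks the kernel. A toy dispatcher with stand-in classes (not the repository's QuantLinear/Ex4bitLinear) showing the same control flow:

```python
class TritonStandIn:
    """Stand-in for the triton-based QuantLinear."""
    def __init__(self, *args):
        self.kind = "triton"


class ExllamaStandIn:
    """Stand-in for the exllama-based Ex4bitLinear."""
    def __init__(self, *args):
        self.kind = "exllama"


def pick_gptq_linear(weight, bias):
    qweight, qzeros, scales, g_idx, bits, groupsize, use_triton_kernel = weight
    if use_triton_kernel:
        return TritonStandIn(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
    return ExllamaStandIn(qweight, qzeros, scales, g_idx, bias, bits, groupsize)


assert pick_gptq_linear((None,) * 6 + (True,), bias=None).kind == "triton"
assert pick_gptq_linear((None,) * 6 + (False,), bias=None).kind == "exllama"
```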
@@ -193,12 +187,12 @@ class TensorParallelHead(SuperLayer):
         weight = weights.get_sharded(f"{prefix}.weight", dim=0)
 
         # GPTQ doesn't quantize heads (nor embeddings)
-        if config.quantize in ["gptq", "gptq-cuda"]:
+        if config.quantize == "gptq":
             quantize = None
         else:
             quantize = config.quantize
         return TensorParallelHead(
-            get_linear(weight, bias=None, quantize=quantize, device=weights.device),
+            get_linear(weight, bias=None, quantize=quantize),
             process_group=weights.process_group,
         )
 
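As the comment in this hunk notes, GPTQ leaves the LM head (and embeddings) unquantized, so the head weight goes through the plain FastLinear path. A tiny sketch of that selection (the helper is hypothetical):

```python
from typing import Optional


def head_quantize(config_quantize: Optional[str]) -> Optional[str]:
    # GPTQ doesn't quantize heads (nor embeddings); any other mode is passed through.
    return None if config_quantize == "gptq" else config_quantize


assert head_quantize("gptq") is None
assert head_quantize("bitsandbytes") == "bitsandbytes"
assert head_quantize(None) is None
```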
@@ -254,7 +248,7 @@ class TensorParallelColumnLinear(SuperLayer):
             bias = torch.cat(b, dim=dim)
         else:
             bias = None
-        linear = get_linear(weight, bias, config.quantize, device=weights.device)
+        linear = get_linear(weight, bias, config.quantize)
         return cls(linear)
 
 
@@ -273,7 +267,7 @@ class TensorParallelRowLinear(SuperLayer):
         else:
             bias = None
         return cls(
-            get_linear(weight, bias, config.quantize, device=weights.device),
+            get_linear(weight, bias, config.quantize),
             process_group=weights.process_group,
         )
 
@@ -2,7 +2,7 @@ from pathlib import Path
 from typing import List, Dict, Optional
 from safetensors import safe_open
 import torch
-
+from loguru import logger
 
 class Weights:
     def __init__(
@@ -99,7 +99,7 @@ class Weights:
         return tensor
 
     def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
-        if quantize in ["gptq", "gptq-cuda"]:
+        if quantize == "gptq":
            try:
                qweight = torch.cat([self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1)
            except RuntimeError:
@@ -114,20 +114,33 @@ class Weights:
 
             bits = self.get_tensor("gptq_bits").item()
             groupsize = self.get_tensor("gptq_groupsize").item()
-            weight = [qweight, qzeros, scales, g_idx, bits, groupsize]
+            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
         else:
             w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
             weight = torch.cat(w, dim=dim)
         return weight
 
     def get_multi_weights_row(self, prefix: str, quantize: str):
-        if quantize in ["gptq", "gptq-cuda"]:
+        if quantize == "gptq":
+            use_triton_kernel = False
+            if self.process_group.size() > 1:
+                g_idx = self.get_tensor(f"{prefix}.g_idx")
+                groupsize = self.get_tensor("gptq_groupsize").item()
+
+                if g_idx is not None:
+                    if not torch.equal(g_idx.cpu(), torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32)) and not (g_idx == 0).all():
+                        # Exllama implementation does not support row tensor parallelism with act-order, as
+                        # it would require to reorder input activations that are split unto several GPUs
+                        use_triton_kernel = True
+
             try:
                 qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
             except RuntimeError:
                 raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`")
 
-            if quantize == "gptq":
+            if use_triton_kernel:
+                # The triton kernel reorders the scales/zero points instead of the weight/activation.
+                # Thus, each rank needs the full qzeros/scales.
                 qzeros = self.get_tensor(f"{prefix}.qzeros")
                 scales = self.get_tensor(f"{prefix}.scales")
                 g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
@@ -146,7 +159,7 @@ class Weights:
             bits = self.get_tensor("gptq_bits").item()
             groupsize = self.get_tensor("gptq_groupsize").item()
 
-            weight = [qweight, qzeros, scales, g_idx, bits, groupsize]
+            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_triton_kernel)
         else:
             weight = self.get_sharded(f"{prefix}.weight", dim=1)
         return weight
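The act-order handling that previously lived in load_row now sits in Weights.get_multi_weights_row, and the weight changes from a mutable list to a 7-element tuple ending in use_triton_kernel. With more than one shard, an act-order g_idx forces the triton kernel, because the exllama kernel cannot do row tensor parallelism when input activations would need reordering. A condensed, self-contained sketch of just that decision (world_size and the helper name are stand-ins for self.process_group.size() and the inline check above):

```python
from typing import Optional

import torch


def wants_triton_kernel(world_size: int, g_idx: Optional[torch.Tensor], groupsize: int) -> bool:
    # Exllama cannot handle row tensor parallelism with act-order, since that
    # would require reordering input activations already split across GPUs,
    # so that single case falls back to the triton kernel.
    if world_size <= 1 or g_idx is None:
        return False
    ramp = torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32)
    return not (torch.equal(g_idx.cpu(), ramp) or bool((g_idx == 0).all()))


# Single shard: always the exllama path, act-order or not.
assert wants_triton_kernel(1, torch.tensor([3, 1, 2, 0], dtype=torch.int32), groupsize=2) is False
# Sharded with an act-order g_idx: use the triton kernel instead.
assert wants_triton_kernel(2, torch.tensor([3, 1, 2, 0], dtype=torch.int32), groupsize=2) is True
```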