Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-21 23:12:07 +00:00)
Quantized weights were loaded in the `Weights` class, but this was getting unwieldy: every higher-level weight-loading method was a long conditional covering all the different quantizers.

This change moves loading of quantized weights out of the `Weights` class. It does so by defining a simple `WeightsLoader` interface that is implemented by `Exl2WeightsLoader`, `GPTQWeightsLoader`, and `MarlinWeightsLoader`, each living in its quantizer's module. The `Weights` class still provides the low-level load operations (such as loading tensors or sharded tensors), but delegates loads that need quantizer-specific weight processing to a loader. The loaders in turn use the low-level functionality provided by `Weights`.

I initially tried a hierarchy where a class like `GPTQWeights` would inherit from `Weights`, but it was not very flexible (e.g. it does not work well with the new weight storage mock used in tests) and the implicit indirection made the code harder to follow.
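As a rough illustration of the interface described above (the actual `WeightsLoader` definition lives in `text_generation_server/utils/weights.py`, not in the file below, and its exact signatures may differ), a loader exposes quantizer-specific variants of the high-level load operations and calls back into the low-level `Weights` primitives:

# Hedged sketch of the WeightsLoader interface; method names mirror the ones
# GPTQWeightsLoader implements in the file below.
from abc import ABC, abstractmethod
from typing import List, Union


class WeightsLoader(ABC):
    """Quantizer-specific weight processing, layered on the low-level Weights ops."""

    @abstractmethod
    def get_weights_col_packed(
        self, weights, prefix: str, block_sizes: Union[int, List[int]]
    ):
        """Load a packed, column-sharded weight (e.g. fused QKV) for `prefix`."""
        ...

    @abstractmethod
    def get_multi_weights_col(self, weights, prefixes: List[str], dim: int):
        """Load and concatenate column-sharded weights for several prefixes."""
        ...

    @abstractmethod
    def get_weights_row(self, weights, prefix: str):
        """Load a row-sharded weight (sharded over input features) for `prefix`."""
        ...

However the loader is wired in, the flow matches the description above: `Weights` forwards quantizer-specific loads to the loader, and the loader uses `Weights` primitives such as `get_sharded` and `get_packed_sharded`, exactly as `GPTQWeightsLoader` does in the file below.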
404 lines · 13 KiB · Python
from dataclasses import dataclass
from loguru import logger
import os
from typing import List, Optional, Union
from safetensors import SafetensorError
from text_generation_server.utils.weights import Weights, WeightsLoader
import torch
from text_generation_server.utils.import_utils import (
    SYSTEM,
)
from text_generation_server.utils.log import log_once


@dataclass
class GPTQWeight:
    qweight: torch.Tensor
    qzeros: torch.Tensor
    scales: torch.Tensor
    g_idx: Optional[torch.Tensor]
    bits: int
    groupsize: int
    use_exllama: bool

    def __post_init__(self):
        if self.scales.dtype == torch.float:
            self.scales = self.scales.half()

    @property
    def device(self) -> torch.device:
        return self.qweight.device


# Probe the CUDA compute capability; fall back when no GPU is available.
try:
    major, _minor = torch.cuda.get_device_capability()
except Exception:
    major = 1

# HAS_EXLLAMA holds the kernel version ("1" or "2") once the exllama kernels
# import successfully, and stays False otherwise.
HAS_EXLLAMA = False
CAN_EXLLAMA = major >= 8 or SYSTEM == "rocm"
V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
if os.getenv("DISABLE_EXLLAMA") == "True":
    HAS_EXLLAMA = False
elif CAN_EXLLAMA:
    try:
        if V2:
            from text_generation_server.layers.gptq.exllamav2 import (
                QuantLinear as ExllamaQuantLinear,
                create_exllama_buffers,
                set_device,
            )

            HAS_EXLLAMA = "2"
        else:
            from text_generation_server.layers.gptq.exllama import (
                Ex4bitLinear as ExllamaQuantLinear,
                create_exllama_buffers,
                set_device,
            )

            HAS_EXLLAMA = "1"

    except ImportError:
        pass

from text_generation_server.layers.gptq.quant_linear import QuantLinear


class GPTQWeightsLoader(WeightsLoader):
    """
    Loader for GPTQ- and AWQ-quantized weights.
    """

    def __init__(
        self,
        *,
        bits: int,
        desc_act: bool,
        groupsize: int,
        quant_method: str,
        quantize: str,
        sym: bool,
    ):
        self.bits = bits
        self.desc_act = desc_act
        self.groupsize = groupsize
        self.quant_method = quant_method
        self.quantize = quantize
        self.sym = sym

    def get_weights_col_packed(
        self,
        weights: Weights,
        prefix: str,
        block_sizes: Union[int, List[int]],
    ):
        from text_generation_server.layers.marlin import (
            can_use_gptq_marlin,
            repack_gptq_for_marlin,
        )

        try:
            qweight = weights.get_packed_sharded(
                f"{prefix}.qweight", dim=1, block_sizes=block_sizes
            )
        except RuntimeError:
            raise RuntimeError(
                f"Cannot load `{self.quantize}` weight, make sure the model is already quantized."
            )
        scales = weights.get_packed_sharded(
            f"{prefix}.scales", dim=1, block_sizes=block_sizes
        )
        scales = scales.to(dtype=weights.dtype)

        self._get_gptq_params(weights)
        if can_use_gptq_marlin(
            bits=self.bits,
            groupsize=self.groupsize,
            quant_method=self.quant_method,
            quantize=self.quantize,
            sym=self.sym,
        ):
            g_idx = weights.get_tensor(f"{prefix}.g_idx")
            return repack_gptq_for_marlin(
                qweight=qweight,
                scales=scales,
                g_idx=g_idx,
                bits=self.bits,
                desc_act=self.desc_act,
                groupsize=self.groupsize,
                sym=self.sym,
                sharded_infeatures=False,
            )

        qzeros = weights.get_packed_sharded(
            f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
        )
        if self.quantize == "gptq" and self.quant_method == "gptq":
            g_idx = weights.get_tensor(f"{prefix}.g_idx")
        elif self.quantize == "gptq" and self.quant_method == "awq":
            log_once(
                logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
            )
            from text_generation_server.layers.awq.conversion_utils import (
                fast_awq_to_gptq,
            )

            qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
            g_idx = (
                torch.arange(
                    qweight.shape[0] * (32 // self.bits),
                    device=qweight.device,
                )
                // self.groupsize
            ).to(dtype=torch.int32)
        else:
            g_idx = None

        return GPTQWeight(
            qweight=qweight,
            qzeros=qzeros,
            scales=scales,
            g_idx=g_idx,
            bits=self.bits,
            groupsize=self.groupsize,
            use_exllama=False,
        )

    def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
        from text_generation_server.layers.marlin import (
            can_use_gptq_marlin,
            repack_gptq_for_marlin,
        )

        try:
            qweight = torch.cat(
                [weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
            )
        except RuntimeError:
            raise RuntimeError(
                f"Cannot load `{self.quantize}` weight, make sure the model is already quantized"
            )

        scales = torch.cat(
            [weights.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
        )

        self._get_gptq_params(weights)
        if can_use_gptq_marlin(
            bits=self.bits,
            groupsize=self.groupsize,
            quant_method=self.quant_method,
            quantize=self.quantize,
            sym=self.sym,
        ):
            w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
            for w2 in w[1:]:
                torch.testing.assert_close(w2, w[0])
            g_idx = w[0]

            return repack_gptq_for_marlin(
                qweight=qweight,
                scales=scales,
                g_idx=g_idx,
                bits=self.bits,
                desc_act=self.desc_act,
                groupsize=self.groupsize,
                sym=self.sym,
                sharded_infeatures=False,
            )

        qzeros = torch.cat(
            [weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
        )

        from text_generation_server.layers.gptq import HAS_EXLLAMA

        use_exllama = (
            self.bits == 4
            and HAS_EXLLAMA
            and self.quantize == "gptq"
            and not self.desc_act
        )

        if self.quantize == "gptq" and self.quant_method == "gptq":
            w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
            for w2 in w[1:]:
                torch.testing.assert_close(w2, w[0])
            g_idx = w[0]
        elif self.quantize == "gptq" and self.quant_method == "awq":
            log_once(
                logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
            )
            from text_generation_server.layers.awq.conversion_utils import (
                fast_awq_to_gptq,
            )

            qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
            if use_exllama:
                g_idx = None
            else:
                g_idx = (
                    torch.arange(
                        qweight.shape[0] * (32 // self.bits),
                        device=qweight.device,
                    )
                    // self.groupsize
                ).to(dtype=torch.int32)
        else:
            g_idx = None

        return GPTQWeight(
            qweight=qweight,
            qzeros=qzeros,
            scales=scales,
            g_idx=g_idx,
            bits=self.bits,
            groupsize=self.groupsize,
            use_exllama=use_exllama,
        )

    def get_weights_row(self, weights: Weights, prefix: str):
        from text_generation_server.layers.marlin import (
            can_use_gptq_marlin,
            repack_gptq_for_marlin,
        )

        self._get_gptq_params(weights)
        if can_use_gptq_marlin(
            bits=self.bits,
            groupsize=self.groupsize,
            quant_method=self.quant_method,
            quantize=self.quantize,
            sym=self.sym,
        ):
            log_once(logger.info, "Using GPTQ-Marlin kernels")
            try:
                qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
            except RuntimeError:
                raise RuntimeError(
                    f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
                )

            g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0)
            if self.desc_act or self.groupsize == -1:
                scales = weights.get_tensor(f"{prefix}.scales")
            else:
                scales = weights.get_sharded(f"{prefix}.scales", dim=0)

            sharded_in_features = weights.process_group.size() > 1

            return repack_gptq_for_marlin(
                qweight=qweight,
                scales=scales,
                g_idx=g_idx,
                bits=self.bits,
                desc_act=self.desc_act,
                groupsize=self.groupsize,
                sym=self.sym,
                sharded_infeatures=sharded_in_features,
            )

        use_exllama = True
        if self.bits != 4:
            use_exllama = False

        if self.desc_act:
            log_once(logger.warning, "Disabling exllama because desc_act=True")
            use_exllama = False

        try:
            qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
        except RuntimeError:
            raise RuntimeError(
                "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
            )

        if self.quantize == "gptq" and self.quant_method == "gptq":
            g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0)
        else:
            g_idx = None

        if weights.process_group.size() > 1:
            if g_idx is not None:
                if (
                    not torch.equal(
                        g_idx.cpu(),
                        torch.tensor(
                            [i // self.groupsize for i in range(g_idx.shape[0])],
                            dtype=torch.int32,
                        ),
                    )
                    and not (g_idx == 0).all()
                ):
                    # The exllama implementation does not support row tensor parallelism with
                    # act-order, as it would require reordering input activations that are
                    # split onto several GPUs.
                    use_exllama = False

        from text_generation_server.layers.gptq import (
            HAS_EXLLAMA,
            CAN_EXLLAMA,
            GPTQWeight,
        )

        if use_exllama:
            if not HAS_EXLLAMA:
                if CAN_EXLLAMA:
                    log_once(
                        logger.warning,
                        "Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True",
                    )
                use_exllama = False
            else:
                log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}")

        if use_exllama and self.groupsize != -1:
            qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0)
            scales = weights.get_sharded(f"{prefix}.scales", dim=0)
        else:
            qzeros = weights.get_tensor(f"{prefix}.qzeros")
            scales = weights.get_tensor(f"{prefix}.scales")

        if use_exllama and g_idx is not None:
            g_idx = g_idx - g_idx[0]

        if self.quantize == "gptq" and self.quant_method == "awq":
            log_once(
                logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
            )
            from text_generation_server.layers.awq.conversion_utils import (
                fast_awq_to_gptq,
            )

            qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
            if use_exllama:
                g_idx = None
            else:
                g_idx = (
                    torch.arange(
                        qweight.shape[0] * (32 // self.bits),
                        device=qweight.device,
                    )
                    // self.groupsize
                ).to(dtype=torch.int32)

        return GPTQWeight(
            qweight=qweight,
            qzeros=qzeros,
            scales=scales,
            g_idx=g_idx,
            bits=self.bits,
            groupsize=self.groupsize,
            use_exllama=use_exllama,
        )

    def _get_gptq_params(self, weights: Weights):
        try:
            self.bits = weights.get_tensor("gptq_bits").item()
            self.groupsize = weights.get_tensor("gptq_groupsize").item()
            self.desc_act = False
            self.sym = False
            self.quant_method = "gptq"
        except (SafetensorError, RuntimeError):
            # The checkpoint does not carry `gptq_bits`/`gptq_groupsize`
            # tensors; keep the parameters passed to __init__.
            pass
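For completeness, a minimal usage sketch (not part of the file above; the `weights` object and the layer prefix are placeholders): construct the loader with the model's GPTQ parameters, then ask it for a quantizer-processed weight.

# Hypothetical usage sketch. `weights` is assumed to be an already-constructed
# Weights instance backed by GPTQ-quantized safetensors; the prefix below is a
# made-up layer name.
loader = GPTQWeightsLoader(
    bits=4,
    desc_act=False,
    groupsize=128,
    quant_method="gptq",
    quantize="gptq",
    sym=True,
)

# Returns a GPTQWeight, or a Marlin-repacked weight when the GPTQ-Marlin path
# is available.
# weight = loader.get_weights_row(weights, "model.layers.0.mlp.down_proj")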