remove unused quantization code and enable awq/gptq int4

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
Wang, Yi A 2025-03-21 18:28:58 -07:00
parent fdf0733f56
commit 9914ffe1f1
23 changed files with 291 additions and 2373 deletions

View File

@ -7,10 +7,6 @@ from text_generation_server.utils.weights import (
)
from text_generation_server.layers.gptq import GPTQWeight, GPTQWeightsLoader
from text_generation_server.layers.exl2 import Exl2Weight, Exl2WeightsLoader
from text_generation_server.layers.marlin.marlin import (
MarlinWeight,
MarlinWeightsLoader,
)
from types import SimpleNamespace
from typing import List, Optional, Dict, Union
from pathlib import Path
@ -40,11 +36,6 @@ def gptq_weights_loader_awq():
)
@pytest.fixture
def marlin_weights_loader():
return MarlinWeightsLoader(bits=4, is_marlin_24=False)
dummy_file_system = {
"test_weights": {
"layer.0.weight": torch.tensor(
@ -125,10 +116,6 @@ dummy_file_system = {
"gptq_bits": torch.tensor([8], dtype=torch.float32),
"gptq_groupsize": torch.tensor([2], dtype=torch.float32),
},
"test_get_weights_col_marlin": {
"weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
"weight.s": torch.tensor([[0.5000], [0.2500]], dtype=torch.float16),
},
"test_get_weights_row_gptq": {
"weight.qweight": torch.tensor(
[
@ -273,18 +260,6 @@ dummy_file_system = {
"weight.q_scale_max": torch.tensor([100], dtype=torch.float16),
"weight.q_groups": torch.tensor([4], dtype=torch.int16),
},
"test_get_weights_row_marlin": {
"weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
"weight.s": torch.tensor([[0.5], [0.25]], dtype=torch.float16),
},
"test_get_multi_weights_col_marlin": {
"weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
"weight.s": torch.tensor([[0.5], [0.25]], dtype=torch.float16),
},
"test_get_weights_col_packed_marlin": {
"weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
"weight.s": torch.tensor([[0.5], [0.25]], dtype=torch.float16),
},
}
@ -718,33 +693,6 @@ def test_get_weights_col_exl2():
assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch"
def test_get_weights_col_marlin(marlin_weights_loader):
weights = MockWeights(
[
"test_get_weights_col_marlin",
],
device="cpu",
dtype=torch.float16,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
weights_loader=marlin_weights_loader,
)
prefix = "weight"
w = weights.get_weights_col(
prefix=prefix,
)
expected_weight = MarlinWeight(
B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
s=torch.tensor([[0.5000], [0.2500]], dtype=torch.float16),
)
assert torch.allclose(w.B, expected_weight.B), "B mismatch"
assert torch.allclose(w.s, expected_weight.s), "s mismatch"
# test_get_weights_col_packed
@ -868,36 +816,6 @@ def test_get_weights_col_packed_gptq(gptq_weights_loader):
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
def test_get_weights_col_packed_marlin(marlin_weights_loader):
weights = MockWeights(
[
"test_get_weights_col_packed_marlin",
],
device="cpu",
dtype=torch.float16,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
weights_loader=marlin_weights_loader,
)
prefix = "weight"
w = weights.get_multi_weights_col(
prefixes=[prefix],
dim=0,
)
expected_weight = MarlinWeight(
B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
s=torch.tensor([[0.5000], [0.2500]], dtype=torch.float16),
)
print(expected_weight)
assert torch.allclose(w.B, expected_weight.B), "B mismatch"
assert torch.allclose(w.s, expected_weight.s), "s mismatch"
# test_get_multi_weights_col
@ -1004,34 +922,6 @@ def test_get_multi_weights_col_gptq(gptq_weights_loader):
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
def test_get_multi_weights_col_marlin(marlin_weights_loader):
weights = MockWeights(
[
"test_get_multi_weights_col_marlin",
],
device="cpu",
dtype=torch.float16,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
weights_loader=marlin_weights_loader,
)
prefix = "weight"
w = weights.get_multi_weights_col(
prefixes=[prefix],
dim=0,
)
expected_weight = MarlinWeight(
B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
s=torch.tensor([[0.5000], [0.2500]], dtype=torch.float16),
)
assert torch.allclose(w.B, expected_weight.B), "B mismatch"
assert torch.allclose(w.s, expected_weight.s), "s mismatch"
# test_get_weights_row
@ -1148,30 +1038,3 @@ def test_get_weights_row_gptq(gptq_weights_loader):
assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch"
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
def test_get_weights_row_marlin(marlin_weights_loader):
weights = MockWeights(
[
"test_get_weights_row_marlin",
],
device="cpu",
dtype=torch.float16,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
weights_loader=marlin_weights_loader,
)
prefix = "weight"
w = weights.get_weights_row(
prefix=prefix,
)
expected_weight = MarlinWeight(
B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
s=torch.tensor([[0.5000], [0.2500]], dtype=torch.float16),
)
assert torch.allclose(w.B, expected_weight.B), "B mismatch"
assert torch.allclose(w.s, expected_weight.s), "s mismatch"

View File

@ -16,15 +16,9 @@ app = typer.Typer()
class Quantization(str, Enum):
bitsandbytes = "bitsandbytes"
bitsandbytes_nf4 = "bitsandbytes-nf4"
bitsandbytes_fp4 = "bitsandbytes-fp4"
gptq = "gptq"
awq = "awq"
eetq = "eetq"
exl2 = "exl2"
fp8 = "fp8"
marlin = "marlin"
class Dtype(str, Enum):

View File

@ -1,19 +1,93 @@
# Copied logic from https://github.com/mit-han-lab/llm-awq/blob/f084f40bd996f3cf3a0633c1ad7d9d476c318aaa/awq/quantize/qmodule.py
from typing import Optional
import torch
import torch.nn as nn
import awq_inference_engine # with CUDA kernels
try:
import habana_frameworks.torch.hpu # noqa: F401
convert_from_uint4 = torch.ops.hpu.convert_from_uint4
except Exception as e:
hpu_import_exception = e
def error_raiser_hpu(*args, **kwargs):
raise ValueError(
f"Trying to use HPU, but could not import the HPU framework with the following error: {hpu_import_exception}"
)
convert_from_uint4 = error_raiser_hpu
AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
# class ScaledActivation(nn.Module):
# def __init__(self, module, scales):
# super().__init__()
# self.act = module
# self.scales = nn.Parameter(scales.data)
#
# def forward(self, x):
# return self.act(x) / self.scales.view(1, 1, -1).to(x.device)
def unpack_awq(qweight: torch.Tensor, qzeros: torch.Tensor, bits: int):
shifts = torch.arange(0, 32, bits, device=qzeros.device)
# unpacking columnwise
iweights = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to(
torch.int8 # smallest dtype available
)
iweights = iweights.view(iweights.shape[0], -1)
# unpacking columnwise
if qzeros is not None:
izeros = torch.bitwise_right_shift(
qzeros[:, :, None], shifts[None, None, :]
).to(
torch.int8 # smallest dtype available
)
izeros = izeros.view(izeros.shape[0], -1)
else:
izeros = qzeros
return iweights, izeros
def reverse_awq_order(iweights: torch.Tensor, izeros: torch.Tensor, bits: int):
reverse_order_tensor = torch.arange(
iweights.shape[-1],
dtype=torch.int32,
device=izeros.device,
)
reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits)
reverse_order_tensor = reverse_order_tensor[:, AWQ_REVERSE_ORDER]
reverse_order_tensor = reverse_order_tensor.view(-1)
if izeros is not None:
izeros = izeros[:, reverse_order_tensor]
iweights = iweights[:, reverse_order_tensor]
return iweights, izeros
def unpack_weight_and_zeros(qweight, qzeros, bits):
# Unpack the qweight and qzeros tensors
iweight, izeros = unpack_awq(qweight, qzeros, bits)
# Reverse the order of the iweight and izeros tensors
iweight, izeros = reverse_awq_order(iweight, izeros, bits)
# overflow checks
iweight = torch.bitwise_and(iweight, (2**bits) - 1)
izeros = torch.bitwise_and(izeros, (2**bits) - 1)
return iweight, izeros
def pack_tensor(input, bits=4):
normal = input.to(torch.int32)
q = torch.zeros(
(normal.shape[0], normal.shape[1] // 32 * bits),
dtype=torch.int32,
device=input.device,
)
i = 0
col = 0
while col < q.shape[1]:
for j in range(i, i + (32 // bits)):
q[:, col] |= normal[:, j] << (bits * (j - i))
i += 32 // bits
col += 1
q = q.to(torch.int32)
return q
class WQLinear(nn.Module):
@ -38,12 +112,23 @@ class WQLinear(nn.Module):
self.qzeros = qzeros
self.scales = scales
self.bias = bias
self._preprocessing()
def _preprocessing(self):
device = self.qweight.device
weight, zeros = unpack_weight_and_zeros(
self.qweight.cpu(), self.qzeros.cpu(), self.w_bit
)
self.qweight = pack_tensor(weight).to(device)
self.qzeros = pack_tensor(zeros).to(device)
@torch.no_grad()
def forward(self, x):
out_shape = x.shape[:-1] + (self.out_features,)
out = awq_inference_engine.gemm_forward_cuda(
x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8
)
out = out + self.bias if self.bias is not None else out
return out.reshape(out_shape)
x = x.reshape(-1, x.shape[-1])
weights = convert_from_uint4(self.qweight, self.scales, self.qzeros, x.dtype)
outputs = torch.matmul(x, weights)
outputs = outputs + self.bias if self.bias is not None else outputs
outputs = outputs.reshape(out_shape)
return outputs

View File

@ -1,3 +0,0 @@
from .loader import CompressedTensorsLoader
__all__ = ["CompressedTensorsLoader"]

View File

@ -1,196 +0,0 @@
from typing import Any, Dict, List, Union
from compressed_tensors import QuantizationConfig, QuantizationStatus
from compressed_tensors.config import CompressionFormat
from compressed_tensors.quantization import (
QuantizationScheme,
QuantizationType,
find_name_or_class_matches,
)
from loguru import logger
from pydantic import ValidationError
from torch import nn
from text_generation_server.layers.compressed_tensors.w8an_fp import W8ANFpLoader
from text_generation_server.layers.compressed_tensors.w8a8_int import W8A8IntLoader
from text_generation_server.layers.compressed_tensors.wna16_int_24 import (
WNA16Int24Loader,
)
from text_generation_server.layers.compressed_tensors.wna16_int import WNA16IntLoader
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import (
DefaultWeightsLoader,
UnquantizedWeight,
Weights,
WeightsLoader,
)
# compressed-tensors can match modules as quantization targets. However,
# they need to be objects rather than classes or class names. Since we
# need to match `Linear` targets, make an instance that can be re-used.
_EMPTY_LINEAR: nn.Module = nn.Linear(0, 0)
class CompressedTensorsLoader(WeightsLoader):
"""Loader for checkpoints stored in the compressed-tensors format."""
def __init__(self, config: Dict[str, Any]):
quantization_config_raw = config.get("quantization_config")
if quantization_config_raw is None:
# `compression_config` was renamed to `quantization_config`; support
# retained for backward compatibility.
quantization_config_raw = config.get("compression_config")
if quantization_config_raw is None:
raise ValueError(
"Checkpoint does not have compressed-tensors configuration"
)
try:
quantization_config = QuantizationConfig.model_validate(
quantization_config_raw
)
except ValidationError as e:
raise ValueError("Cannot parse compressed-tensors configuration") from e
if quantization_config.quantization_status not in (
QuantizationStatus.COMPRESSED,
QuantizationStatus.FROZEN,
):
raise ValueError(
f"Model quantization was not finished, status was: {quantization_config.quantization_status}"
)
self.ignore = (
quantization_config.ignore if quantization_config.ignore is not None else []
)
self.loaders = self._get_target_loaders(quantization_config)
for target, loader in self.loaders.items():
log_once(
logger.info,
f"Using {loader} for compressed-tensors target '{target}'",
)
def get_weights(self, weights: Weights, prefix: str):
loader = self._lookup_loader(prefix)
return loader.get_weights(weights, prefix)
def get_weights_col_packed(
self,
weights: "Weights",
prefix: str,
block_sizes: Union[int, List[int]],
):
loader = self._lookup_loader(prefix)
return loader.get_weights_col_packed(weights, prefix, block_sizes)
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
loader = self._lookup_loader(prefixes[0])
return loader.get_multi_weights_col(weights, prefixes, dim)
def get_weights_row(self, weights: Weights, prefix: str):
loader = self._lookup_loader(prefix)
return loader.get_weights_row(weights, prefix)
def _get_target_loaders(
self, quantization_config: QuantizationConfig
) -> Dict[str, WeightsLoader]:
"""
A compressed-tensors checkpoint can use different quantizations
for different targets. This method returns a dictionary with a
loader per target.
"""
loaders: Dict[str, WeightsLoader] = {}
format = quantization_config.format
for group_name, group in quantization_config.config_groups.items():
# The group configuration can be a string, but does that ever
# happen in a serialized quantization config?
assert isinstance(group, QuantizationScheme)
loader = self._create_loader_for_group(format, group_name, group)
# A quantized parameter group can have multiple targets, add the
# loader for all the targets.
for target in group.targets:
if target in loaders:
raise ValueError(
f"Target '{target} has multiple configured loaders'"
)
loaders[target] = loader
return loaders
def _create_loader_for_group(
self, format: str, group_name: str, group: QuantizationScheme
) -> WeightsLoader:
"""
Find and create a loader for the group with the given quantization
scheme.
"""
# NOTE: we ignore group.output_activations because we don't support
# output quantization yet.
input_activations = group.input_activations
weights = group.weights
if (
format
in {
CompressionFormat.float_quantized.value,
CompressionFormat.naive_quantized.value,
}
and weights is not None
and weights.type == QuantizationType.FLOAT
and weights.num_bits == 8
):
# FP W8A8 or W8A16.
return W8ANFpLoader(input_activations=input_activations, weights=weights)
elif (
format == CompressionFormat.pack_quantized.value
and weights is not None
and weights.type == QuantizationType.INT
and weights.num_bits in (4, 8)
):
# INT W4A16 or W8A16 (GPTQ/AWQ-like).
return WNA16IntLoader(weights)
elif (
format == CompressionFormat.marlin_24.value
and weights is not None
and weights.type == QuantizationType.INT
and weights.num_bits in (4, 8)
):
return WNA16Int24Loader(weights)
elif (
format
in {
CompressionFormat.int_quantized.value,
CompressionFormat.naive_quantized.value,
}
and weights is not None
and weights.type == QuantizationType.INT
and weights.num_bits == 8
):
return W8A8IntLoader(input_args=input_activations, weight_args=weights)
else:
raise ValueError(
f"Group '{group_name}' has unsupported compressed-tensors configurtion"
)
def _lookup_loader(self, prefix: str) -> WeightsLoader:
"""
Look up the loader to use for a given parameter name (prefix).
"""
if len(find_name_or_class_matches(prefix, _EMPTY_LINEAR, self.ignore)) > 0:
return DefaultWeightsLoader(UnquantizedWeight)
# We currently only handle linear layers, so unconditionally pass
# a `Linear` instance.
targets = find_name_or_class_matches(prefix, _EMPTY_LINEAR, self.loaders.keys())
if len(targets) == 0:
raise ValueError(
f"Cannot find compressed-tensors target for prefix: {prefix}"
)
return self.loaders[targets[0]]

View File

@ -1,239 +0,0 @@
from typing import List, Optional, Union, TypeVar
from dataclasses import dataclass
from loguru import logger
import torch
from compressed_tensors.quantization import QuantizationArgs, QuantizationType
from text_generation_server.layers.fp8 import _load_scalar_or_matrix_scale
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
quantization = None
class W8A8IntLoader(WeightsLoader):
"""
Loader for w8a8 integer compressed-tensors parameters.
"""
def __init__(
self,
*,
input_args: Optional[QuantizationArgs],
weight_args: QuantizationArgs,
):
if weight_args.type != QuantizationType.INT and weight_args.num_bits != 8:
raise ValueError(
f"{type(self).__name__} only supports w8a8 int checkpoints"
)
if not weight_args.symmetric:
raise ValueError("Checkpoints with asymmetric weights are not supported")
self.load_weight_scale = not weight_args.dynamic
if input_args is not None:
self.input_symmetric = input_args.symmetric
if not input_args.dynamic:
log_once(
logger.warning,
"Forcing dynamic input quantization for compressed_tensors w8a8 int checkpoint (for better accuracy).",
)
else:
self.input_symmetric = True
def __str__(self) -> str:
def scale_to_str(scale):
return "static" if scale else "dynamic"
def symmetric_to_str(symmetric):
return "symmetric" if symmetric else "asymmetric"
return f"{self.__class__.__name__} (w8a8 int, input: dynamic/{symmetric_to_str(self.input_symmetric)}, weight: {scale_to_str(self.load_weight_scale)}/symmetric))"
def get_weights(self, weights: "Weights", prefix: str):
w = weights.get_tensor(f"{prefix}.weight", to_dtype=False)
weight_scale = None
if self.load_weight_scale:
weight_scale = weights.get_tensor(
f"{prefix}.weight_scale", to_dtype=False
).reshape(-1)
return Int8Weight(
input_symmetric=self.input_symmetric,
weight=w,
weight_scale=weight_scale,
)
def get_weights_col_packed(
self,
weights: Weights,
prefix: str,
block_sizes: Union[int, List[int]],
):
w = weights.get_packed_sharded(
f"{prefix}.weight", dim=0, block_sizes=block_sizes, to_dtype=False
)
weight_scale = None
if self.load_weight_scale:
weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
if weight_scale.numel() > 1:
weight_scale = weights.get_packed_sharded(
f"{prefix}.weight_scale",
dim=0,
block_sizes=block_sizes,
to_dtype=False,
)
weight_scale = weight_scale.reshape(-1)
return Int8Weight(
input_symmetric=self.input_symmetric,
weight=w,
weight_scale=weight_scale,
)
def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
w = [
weights.get_sharded(f"{p}.weight", dim=0, to_dtype=False) for p in prefixes
]
shapes = [x.shape for x in w]
w = torch.cat(w, dim=dim)
weight_scale = None
if self.load_weight_scale:
weight_scale = [
_load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
for p, shape in zip(prefixes, shapes)
]
weight_scale = torch.cat(weight_scale, dim=0).reshape(-1, 1)
return Int8Weight(
input_symmetric=self.input_symmetric,
weight=w,
weight_scale=weight_scale,
)
def get_weights_row(self, weights: "Weights", prefix: str):
w = weights.get_sharded(f"{prefix}.weight", dim=1, to_dtype=False)
weight_scale = None
if self.load_weight_scale:
weight_scale = weights.get_tensor(
f"{prefix}.weight_scale", to_dtype=False
).reshape(-1)
return Int8Weight(
input_symmetric=self.input_symmetric,
weight=w,
weight_scale=weight_scale,
)
OtherT = TypeVar("OtherT")
def _get_tensor_or_else(
weights: Weights, prefix: str, other: OtherT
) -> Union[torch.Tensor, OtherT]:
# Even if a checkpoint uses e.g. zero-points, they can be elided:
# https://github.com/neuralmagic/compressed-tensors/blob/db6ccb25b265e8370813ecab5e95714a6728b5a6/src/compressed_tensors/compressors/quantized_compressors/base.py#L105
if weights.has_tensor(prefix):
return weights.get_tensor(prefix, to_dtype=False)
else:
return other
@dataclass
class Int8Weight(Weight):
input_symmetric: bool
weight: torch.Tensor
weight_scale: Optional[torch.Tensor]
def get_linear(self, bias: torch.Tensor):
if self.weight_scale is None:
assert quantization is not None
qweight, weight_scale, _ = quantization.scaled_int8_quant(self.weight)
return W8A8IntLinear(
bias=bias,
input_symmetric=self.input_symmetric,
weight=qweight,
weight_scale=weight_scale,
)
else:
return W8A8IntLinear(
bias=bias,
input_symmetric=self.input_symmetric,
weight=self.weight,
weight_scale=self.weight_scale,
)
class W8A8IntLinear(torch.nn.Module):
def __init__(
self,
*,
bias: Optional[torch.Tensor],
input_symmetric: bool,
weight: torch.Tensor,
weight_scale: torch.Tensor,
):
super().__init__()
weight_scale = weight_scale.to(torch.float32)
self.bias = bias
self.input_symmetric = input_symmetric
# cutlass kernels require transposed weights.
self.weight = weight.t()
self.weight_scale = weight_scale
if input_symmetric:
self.zero_point_adj = None
else:
# https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md#scaledepilogueazp
self.zero_point_adj = self.weight.sum(
dim=0, keepdim=True, dtype=torch.int32
)
def forward(self, input: torch.Tensor) -> torch.Tensor:
assert quantization is not None
qinput, input_scale, input_zero_point = quantization.scaled_int8_quant(
input=input,
scale=None,
azp=None,
symmetric=self.input_symmetric,
)
if self.input_symmetric:
return quantization.cutlass_scaled_mm(
a=qinput,
b=self.weight,
scale_a=input_scale,
scale_b=self.weight_scale,
out_dtype=input.dtype,
bias=self.bias,
)
else:
assert (
self.zero_point_adj is not None
and input_scale is not None
and (self.input_symmetric or input_zero_point is not None)
)
return quantization.cutlass_scaled_mm_azp(
a=qinput,
b=self.weight,
scale_a=input_scale,
scale_b=self.weight_scale,
out_dtype=input.dtype,
azp_adj=self.zero_point_adj,
azp=input_zero_point,
bias=self.bias,
)

View File

@ -1,168 +0,0 @@
from typing import List, Optional, Union
import torch
from compressed_tensors.quantization import QuantizationArgs, QuantizationType
from text_generation_server.layers.fp8 import (
Fp8Weight,
_load_scalar_or_matrix_scale,
)
from text_generation_server.utils.weights import Weights, WeightsLoader
class W8ANFpLoader(WeightsLoader):
"""
Loader for W8A8/W8A16 FP compressed-tensors parameters.
"""
def __init__(
self,
*,
input_activations: Optional[QuantizationArgs],
weights: QuantizationArgs,
):
assert weights.type == QuantizationType.FLOAT and weights.num_bits == 8
# We ignore the `strategy` option which sets the scales to be
# per-tensor, per-channel or per-token. What scales are supported
# is dependent on the kernels used (e.g. cutlass can do tokenwise,
# Torch cannot, and FP8-Marlin does not quantize inputs at all).
# So, instead we try to use the best-possible configuration.
self.load_weight_scale = not weights.dynamic
self.load_input_scale = (
input_activations is not None and not input_activations.dynamic
)
self.force_w8a16 = (
input_activations is not None and input_activations.num_bits == 16
)
def __str__(self) -> str:
def scale_to_str(scale):
return "static" if scale else "dynamic"
quantization_type = f"W8A{16 if self.force_w8a16 else 8}"
return f"{self.__class__.__name__} ({quantization_type}, weight: {scale_to_str(self.load_weight_scale)}, input: {scale_to_str(self.load_input_scale)})"
def get_weights(self, weights: "Weights", prefix: str):
w = weights.get_tensor(f"{prefix}.weight")
weight_scale = None
if self.load_weight_scale:
weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
input_scale = None
if self.load_input_scale:
input_scale = weights.get_tensor(
f"{prefix}.input_scale", to_dtype=False
).reshape(-1)
return Fp8Weight(
weight=w,
weight_scale=weight_scale,
input_scale=input_scale,
dtype=weights.dtype,
force_w8a16=self.force_w8a16,
)
def get_weights_col_packed(
self,
weights: Weights,
prefix: str,
block_sizes: Union[int, List[int]],
):
w = weights.get_packed_sharded(
f"{prefix}.weight", dim=0, block_sizes=block_sizes
)
weight_scale = None
if self.load_weight_scale:
weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
if weight_scale.numel() > 1:
weight_scale = weights.get_packed_sharded(
f"{prefix}.weight_scale",
dim=0,
block_sizes=block_sizes,
to_dtype=False,
)
input_scale = None
if self.load_input_scale:
input_scale = weights.get_tensor(f"{prefix}.input_scale", to_dtype=False)
if input_scale.numel() > 1:
input_scale = weights.get_packed_sharded(
f"{prefix}.input_scale",
dim=0,
block_sizes=block_sizes,
to_dtype=False,
)
input_scale = input_scale.reshape(-1).max()
return Fp8Weight(
weight=w,
weight_scale=weight_scale,
input_scale=input_scale,
dtype=weights.dtype,
force_w8a16=self.force_w8a16,
)
def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
# FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet
w = [
weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes
]
shapes = [x.shape for x in w]
# Concat then send to the device
w = torch.cat(w, dim=dim).to(weights.device)
weight_scale = None
if self.load_weight_scale:
weight_scale = [
_load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
for p, shape in zip(prefixes, shapes)
]
weight_scale = torch.cat(weight_scale, dim=0).reshape(-1)
input_scale = None
if self.load_input_scale:
input_scale = [
_load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
for p, shape in zip(prefixes, shapes)
if weights.has_tensor(f"{p}.input_scale")
]
assert len(input_scale) == 0 or len(input_scale) == len(prefixes)
input_scale = (
torch.cat(input_scale, dim=0).reshape(-1).max()
if len(input_scale) != 0
else None
)
return Fp8Weight(
weight=w,
weight_scale=weight_scale,
input_scale=input_scale,
dtype=weights.dtype,
force_w8a16=self.force_w8a16,
)
def get_weights_row(self, weights: "Weights", prefix: str):
w = weights.get_sharded(f"{prefix}.weight", dim=1)
weight_scale = None
if self.load_weight_scale:
weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
input_scale = None
if self.load_input_scale:
input_scale = weights.get_tensor(
f"{prefix}.input_scale", to_dtype=False
).reshape(-1)
return Fp8Weight(
weight=w,
weight_scale=weight_scale,
input_scale=input_scale,
dtype=weights.dtype,
force_w8a16=self.force_w8a16,
)

View File

@ -1,188 +0,0 @@
from typing import List, Union
import torch
from compressed_tensors.quantization import ActivationOrdering, QuantizationArgs
from loguru import logger
from text_generation_server.layers.marlin.gptq import repack_gptq_for_marlin
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import Weights, WeightsLoader
class WNA16IntLoader(WeightsLoader):
"""
Loader for W4A16/W8A16 INT compressed-tensors parameters.
"""
def __init__(self, weights: QuantizationArgs):
self.weights = weights
self.desc_act = self.weights.actorder == ActivationOrdering.GROUP
self.groupsize = (
-1 if self.weights.group_size is None else self.weights.group_size
)
def __str__(self) -> str:
quantization_type = f"W{self.weights.num_bits}A16"
return f"{self.__class__.__name__} ({quantization_type})"
def get_weights(self, weights: Weights, prefix: str):
log_once(logger.info, "Using GPTQ-Marlin kernels")
try:
weight_packed = weights.get_tensor(f"{prefix}.weight_packed").t()
except RuntimeError:
raise RuntimeError(
f"Cannot load w{self.weights.num_bits}a16 weight, make sure the model is already quantized"
)
zero_point = None
if not self.weights.symmetric:
zero_point = weights.get_tensor(f"{prefix}.weight_zero_point").t()
g_idx = None
if self.desc_act:
g_idx = weights.get_tensor(f"{prefix}.weight_g_idx")
scales = weights.get_tensor(f"{prefix}.weight.scales").t()
return repack_gptq_for_marlin(
qweight=weight_packed.contiguous(),
scales=scales,
qzeros=zero_point,
g_idx=g_idx,
bits=self.weights.num_bits,
desc_act=self.desc_act,
groupsize=self.groupsize,
quant_method="compressed-tensors",
sym=self.weights.symmetric,
sharded_infeatures=False,
)
def get_weights_col_packed(
self,
weights: Weights,
prefix: str,
block_sizes: Union[int, List[int]],
):
try:
weight_packed = weights.get_packed_sharded(
f"{prefix}.weight_packed", dim=0, block_sizes=block_sizes
).t()
except RuntimeError:
raise RuntimeError(
f"Cannot load w{self.weights.num_bits}a16 weight, make sure the model is already quantized"
)
scales = weights.get_packed_sharded(
f"{prefix}.weight_scale", dim=0, block_sizes=block_sizes
).t()
scales = scales.to(dtype=weights.dtype)
zero_point = None
if not self.weights.symmetric:
zero_point = weights.get_packed_sharded(
f"{prefix}.qzeros", dim=0, block_sizes=block_sizes
).t()
g_idx = None
if self.desc_act:
g_idx = weights.get_tensor(f"{prefix}.g_idx")
return repack_gptq_for_marlin(
qweight=weight_packed.contiguous(),
scales=scales,
qzeros=zero_point,
g_idx=g_idx,
bits=self.weights.num_bits,
desc_act=self.desc_act,
groupsize=self.groupsize,
quant_method="compressed-tensors",
sym=self.weights.symmetric,
sharded_infeatures=False,
)
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
try:
weight_packed = torch.cat(
[
weights.get_sharded(f"{p}.weight_packed", dim=0).t()
for p in prefixes
],
dim=1,
)
except RuntimeError:
raise RuntimeError(
f"Cannot load w{self.weights.num_bits}a16 weight, make sure the model is already quantized"
)
scales = torch.cat(
[weights.get_sharded(f"{p}.weight_scale", dim=0).t() for p in prefixes],
dim=1,
)
zero_point = None
if not self.weights.symmetric:
zero_point = torch.cat(
[weights.get_sharded(f"{p}.qzeros", dim=0).t() for p in prefixes], dim=1
).t()
g_idx = None
if self.desc_act:
w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
for w2 in w[1:]:
torch.testing.assert_close(w2, w[0])
g_idx = w[0]
return repack_gptq_for_marlin(
qweight=weight_packed.contiguous(),
scales=scales,
qzeros=zero_point,
g_idx=g_idx,
bits=self.weights.num_bits,
desc_act=self.desc_act,
groupsize=self.groupsize,
quant_method="compressed-tensors",
sym=self.weights.symmetric,
sharded_infeatures=False,
)
def get_weights_row(self, weights: Weights, prefix: str):
log_once(logger.info, "Using GPTQ-Marlin kernels")
try:
weight_packed = weights.get_sharded(f"{prefix}.weight_packed", dim=1).t()
except RuntimeError:
raise RuntimeError(
f"Cannot load `{self.quantize}` weight, make sure the model is already quantized."
)
zero_point = None
if not self.weights.symmetric:
if self.desc_act or self.groupsize == -1:
zero_point = weights.get_tensor(f"{prefix}.weight_zero_point").t()
else:
zero_point = weights.get_sharded(
f"{prefix}.weight_zero_point", dim=1
).t()
g_idx = None
if self.desc_act:
g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0)
if self.desc_act or self.groupsize == -1:
scales = weights.get_tensor(f"{prefix}.weight_scale").t()
else:
scales = weights.get_sharded(f"{prefix}.weight_scale", dim=1).t()
sharded_in_features = weights.process_group.size() > 1
return repack_gptq_for_marlin(
qweight=weight_packed.contiguous(),
scales=scales,
qzeros=zero_point,
g_idx=g_idx,
bits=self.weights.num_bits,
desc_act=self.desc_act,
groupsize=self.groupsize,
quant_method="compressed-tensors",
sym=self.weights.symmetric,
sharded_infeatures=sharded_in_features,
)

View File

@ -1,101 +0,0 @@
from typing import List, Union
import torch
from compressed_tensors.quantization import QuantizationArgs, QuantizationType
from text_generation_server.layers.marlin.marlin import GPTQMarlin24Weight
from text_generation_server.utils.weights import Weights, WeightsLoader
class WNA16Int24Loader(WeightsLoader):
"""
Loader for W4A16/W8A16 INT 2:4 sparsity compressed-tensors checkpoints.
"""
def __init__(self, weight_args: QuantizationArgs):
super().__init__()
if weight_args.type != QuantizationType.INT:
raise ValueError(
f"{type(self).__name__} only supports wNa8 int checkpoints"
)
if weight_args.strategy == "group" and weight_args.group_size is None:
raise ValueError("`group_size` must be set when `actorder` is `group`")
self.bits = weight_args.num_bits
self.group_size = weight_args.group_size
def __str__(self) -> str:
quantization_type = f"W{self.bits}A16 2:4 sparsity"
return f"{self.__class__.__name__} ({quantization_type})"
def get_weights(self, weights: Weights, prefix: str):
"""
Get weights at the given prefix and apply without tensor paralllism.
"""
weight_packed = weights.get_tensor(f"{prefix}.weight_packed")
meta = weights.get_tensor(f"{prefix}.meta")
scale_packed = weights.get_tensor(f"{prefix}.scale_packed")
return GPTQMarlin24Weight(
weight_packed=weight_packed,
meta=meta,
scale_packed=scale_packed,
bits=self.bits,
)
def get_weights_col_packed(
self,
weights: Weights,
prefix: str,
block_sizes: Union[int, List[int]],
):
weight_packed = weights.get_packed_sharded(
f"{prefix}.weight_packed", dim=1, block_sizes=block_sizes
)
meta = weights.get_packed_sharded(
f"{prefix}.meta", dim=1, block_sizes=block_sizes
)
scale_packed = weights.get_packed_sharded(
f"{prefix}.scale_packed", dim=1, block_sizes=block_sizes
)
return GPTQMarlin24Weight(
weight_packed=weight_packed,
meta=meta,
scale_packed=scale_packed,
bits=self.bits,
)
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
weight_packed = torch.cat(
[weights.get_sharded(f"{p}.weight_packed", dim=1) for p in prefixes], dim=1
)
meta = torch.cat(
[weights.get_sharded(f"{p}.meta", dim=1) for p in prefixes], dim=1
)
scale_packed = torch.cat(
[weights.get_sharded(f"{p}.scale_packed", dim=1) for p in prefixes], dim=1
)
return GPTQMarlin24Weight(
weight_packed=weight_packed,
meta=meta,
scale_packed=scale_packed,
bits=self.bits,
)
def get_weights_row(self, weights: Weights, prefix: str):
weight_packed = weights.get_sharded(f"{prefix}.weight_packed", dim=0)
meta = weights.get_sharded(f"{prefix}.meta", dim=0)
if self.group_size is None:
scale_packed = weights.get_tensor(f"{prefix}.scale_packed")
else:
scale_packed = weights.get_sharded(f"{prefix}.scale_packed", dim=0)
return GPTQMarlin24Weight(
weight_packed=weight_packed,
meta=meta,
scale_packed=scale_packed,
bits=self.bits,
)

View File

@ -7,7 +7,7 @@ from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
QuantLinear = None
from .hpu import QuantLinear
@dataclass
@ -215,14 +215,7 @@ class GPTQWeightsLoader(WeightsLoader):
[weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
)
from text_generation_server.layers.gptq import HAS_EXLLAMA
use_exllama = (
self.bits == 4
and HAS_EXLLAMA
and self.quantize == "gptq"
and not self.desc_act
)
use_exllama = self.bits == 4 and self.quantize == "gptq" and not self.desc_act
if self.quantize == "gptq" and self.quant_method == "gptq":
w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
@ -362,6 +355,3 @@ class GPTQWeightsLoader(WeightsLoader):
else False
)
self.quant_method = "gptq"
HAS_EXLLAMA = False

View File

@ -1,125 +1,181 @@
import math
import numpy as np
import torch
import torch.nn as nn
import intel_extension_for_pytorch as ipex
class QuantLinear(nn.Module):
def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
super().__init__()
self.register_buffer("qweight", qweight)
self.register_buffer("qzeros", qzeros)
self.register_buffer("scales", scales)
self.register_buffer("g_idx", g_idx)
if bias is not None:
self.register_buffer("bias", bias)
else:
self.bias = None
if bits not in [4]:
raise NotImplementedError("Only 4 bits are supported.")
self.bits = bits
self.maxq = 2**self.bits - 1
self.groupsize = groupsize
self.outfeatures = qweight.shape[1]
self.infeatures = qweight.shape[0] * 32 // bits
self.woq_linear = (
ipex.llm.quantization.IPEXWeightOnlyQuantizedLinear.from_weight(
self.qweight,
self.scales,
self.qzeros,
self.infeatures,
self.outfeatures,
bias=self.bias,
group_size=self.groupsize,
g_idx=g_idx,
quant_method=ipex.llm.quantization.QuantMethod.GPTQ_GEMM,
dtype=ipex.llm.quantization.QuantDtype.INT4,
)
)
@classmethod
def new(cls, bits, groupsize, infeatures, outfeatures, bias):
if bits not in [4]:
raise NotImplementedError("Only 4 bits are supported.")
qweight = torch.zeros((infeatures // 32 * bits, outfeatures), dtype=torch.int32)
qzeros = torch.zeros(
(math.ceil(infeatures / groupsize), outfeatures // 32 * bits),
dtype=torch.int32,
)
scales = torch.zeros(
(math.ceil(infeatures / groupsize), outfeatures), dtype=torch.float16
)
g_idx = torch.tensor(
[i // groupsize for i in range(infeatures)], dtype=torch.int32
)
if bias:
bias = torch.zeros((outfeatures), dtype=torch.float16)
else:
bias = None
return cls(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
def pack(self, linear, scales, zeros, g_idx=None):
self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx
scales = scales.t().contiguous()
zeros = zeros.t().contiguous()
scale_zeros = zeros * scales
self.scales = scales.clone().half()
if linear.bias is not None:
self.bias = linear.bias.clone().half()
intweight = []
for idx in range(self.infeatures):
intweight.append(
torch.round(
(linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]])
/ self.scales[self.g_idx[idx]]
).to(torch.int)[:, None]
)
intweight = torch.cat(intweight, dim=1)
intweight = intweight.t().contiguous()
intweight = intweight.numpy().astype(np.uint32)
qweight = np.zeros(
(intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32
)
i = 0
row = 0
while row < qweight.shape[0]:
if self.bits in [4]:
for j in range(i, i + (32 // self.bits)):
qweight[row] |= intweight[j] << (self.bits * (j - i))
i += 32 // self.bits
row += 1
else:
raise NotImplementedError("Only 4 bits are supported.")
qweight = qweight.astype(np.int32)
self.qweight = torch.from_numpy(qweight)
zeros -= 1
zeros = zeros.numpy().astype(np.uint32)
qzeros = np.zeros(
(zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32
)
i = 0
col = 0
while col < qzeros.shape[1]:
if self.bits in [4]:
for j in range(i, i + (32 // self.bits)):
qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
i += 32 // self.bits
col += 1
else:
raise NotImplementedError("Only 4 bits are supported.")
qzeros = qzeros.astype(np.int32)
self.qzeros = torch.from_numpy(qzeros)
def forward(self, x):
out_shape = x.shape[:-1] + (self.outfeatures,)
out = self.woq_linear(x.reshape(-1, x.shape[-1]))
return out.reshape(out_shape)
import math
import numpy as np
import torch
import torch.nn as nn
try:
convert_from_uint4 = torch.ops.hpu.convert_from_uint4
except Exception as e:
hpu_import_exception = e
def error_raiser_hpu(*args, **kwargs):
raise ValueError(
f"Trying to use HPU, but could not import the HPU framework with the following error: {hpu_import_exception}"
)
convert_from_uint4 = error_raiser_hpu
def pack_tensor(input, bits=4):
normal = input.to(torch.int32)
q = torch.zeros((normal.shape[0], normal.shape[1] // 32 * bits), dtype=torch.int32)
i = 0
col = 0
while col < q.shape[1]:
for j in range(i, i + (32 // bits)):
q[:, col] |= normal[:, j] << (bits * (j - i))
i += 32 // bits
col += 1
q = q.to(torch.int32)
return q
class QuantLinear(nn.Module):
def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
super().__init__()
self.register_buffer("qweight", qweight)
self.register_buffer("qzeros", qzeros)
self.register_buffer("scales", scales)
self.register_buffer("g_idx", g_idx)
if bias is not None:
self.register_buffer("bias", bias)
else:
self.bias = None
if bits not in [4]:
raise NotImplementedError("Only 4 bits are supported.")
self.bits = bits
self.maxq = 2**self.bits - 1
self.groupsize = groupsize
self.outfeatures = qweight.shape[1]
self.infeatures = qweight.shape[0] * 32 // bits
self._preprocessing()
def unpack_zeros_from_cuda_old_format(self):
zeros = torch.bitwise_right_shift(
torch.unsqueeze(self.qzeros, 2).expand(-1, -1, 32 // self.bits),
self.wf.unsqueeze(0),
).to(torch.int16 if self.bits == 8 else torch.int8)
zeros = zeros + 1
zeros = torch.bitwise_and(zeros, (2**self.bits) - 1).to(
self.scales.dtype
) # NOTE: It appears that casting here after the `zeros = zeros + 1` is important.
zeros = zeros.reshape(-1, zeros.shape[1] * zeros.shape[2])
return zeros
def unpack_weight_from_cuda_old_format(self):
weight = torch.bitwise_right_shift(
torch.unsqueeze(self.qweight, 1).expand(-1, 32 // self.bits, -1),
self.wf.unsqueeze(-1),
).to(torch.int16 if self.bits == 8 else torch.int8)
weight = torch.bitwise_and(weight, (2**self.bits) - 1)
weight = weight.reshape((weight.shape[0] * weight.shape[1], weight.shape[2]))
return weight
def _preprocessing(self):
self.qweight = self.qweight.cpu()
weight = self.unpack_weight_from_cuda_old_format()
new_qweight = pack_tensor(weight)
self.qweight = new_qweight.to("hpu")
# TODO: Support group indexing and remove the check
columns = self.qweight.shape[0]
g_idx_trivial = [i // self.group_size for i in range(columns)]
g_idx_trivial = torch.tensor(g_idx_trivial, dtype=torch.int32)
assert torch.equal(
self.g_idx, g_idx_trivial
), "Non-trivial tensor g_idx is not supported"
zeros = self.unpack_zeros_from_cuda_old_format().cpu()
new_qzeros = pack_tensor(zeros)
self.qzeros = new_qzeros.to("hpu")
@classmethod
def new(cls, bits, groupsize, infeatures, outfeatures, bias):
if bits not in [4]:
raise NotImplementedError("Only 4 bits are supported.")
qweight = torch.zeros((infeatures // 32 * bits, outfeatures), dtype=torch.int32)
qzeros = torch.zeros(
(math.ceil(infeatures / groupsize), outfeatures // 32 * bits),
dtype=torch.int32,
)
scales = torch.zeros(
(math.ceil(infeatures / groupsize), outfeatures), dtype=torch.float16
)
g_idx = torch.tensor(
[i // groupsize for i in range(infeatures)], dtype=torch.int32
)
if bias:
bias = torch.zeros((outfeatures), dtype=torch.float16)
else:
bias = None
return cls(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
def pack(self, linear, scales, zeros, g_idx=None):
self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx
scales = scales.t().contiguous()
zeros = zeros.t().contiguous()
scale_zeros = zeros * scales
self.scales = scales.clone().half()
if linear.bias is not None:
self.bias = linear.bias.clone().half()
intweight = []
for idx in range(self.infeatures):
intweight.append(
torch.round(
(linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]])
/ self.scales[self.g_idx[idx]]
).to(torch.int)[:, None]
)
intweight = torch.cat(intweight, dim=1)
intweight = intweight.t().contiguous()
intweight = intweight.numpy().astype(np.uint32)
qweight = np.zeros(
(intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32
)
i = 0
row = 0
while row < qweight.shape[0]:
if self.bits in [4]:
for j in range(i, i + (32 // self.bits)):
qweight[row] |= intweight[j] << (self.bits * (j - i))
i += 32 // self.bits
row += 1
else:
raise NotImplementedError("Only 4 bits are supported.")
qweight = qweight.astype(np.int32)
self.qweight = torch.from_numpy(qweight)
zeros -= 1
zeros = zeros.numpy().astype(np.uint32)
qzeros = np.zeros(
(zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32
)
i = 0
col = 0
while col < qzeros.shape[1]:
if self.bits in [4]:
for j in range(i, i + (32 // self.bits)):
qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
i += 32 // self.bits
col += 1
else:
raise NotImplementedError("Only 4 bits are supported.")
qzeros = qzeros.astype(np.int32)
self.qzeros = torch.from_numpy(qzeros)
def forward(self, x):
out_shape = x.shape[:-1] + (self.outfeatures,)
x = x.reshape(-1, x.shape[-1])
weight = convert_from_uint4(self.qweight, self.scales, self.qzeros, x.dtype)
out = torch.matmul(x, weight)
out = out.reshape(out_shape)
out = out + self.bias if self.bias is not None else out
return out

View File

@ -1,15 +0,0 @@
from text_generation_server.layers.marlin.fp8 import GPTQMarlinFP8Linear
from text_generation_server.layers.marlin.gptq import (
GPTQMarlinWeightsLoader,
can_use_gptq_marlin,
repack_gptq_for_marlin,
)
from text_generation_server.layers.marlin.marlin import MarlinWeightsLoader
__all__ = [
"GPTQMarlinFP8Linear",
"GPTQMarlinWeightsLoader",
"MarlinWeightsLoader",
"can_use_gptq_marlin",
"repack_gptq_for_marlin",
]

View File

@ -1,141 +0,0 @@
from typing import Optional
import torch
import torch.nn as nn
from text_generation_server.layers.fp8 import fp8_quantize
from text_generation_server.layers.marlin.gptq import _check_valid_shape
from text_generation_server.layers.marlin.util import (
_check_marlin_kernels,
permute_scales,
)
quantization = None
MARLIN_TILE_SIZE = 16
class GPTQMarlinFP8Linear(nn.Module):
"""
FP8 GPTQ-Marlin linear layer.
"""
def __init__(
self,
qweight: torch.Tensor,
scales: torch.Tensor,
bias: Optional[torch.Tensor],
) -> None:
super().__init__()
_check_marlin_kernels()
assert quantization is not None
scales = scales.unsqueeze(0)
if scales.shape[1] == 1:
out_features, in_features = qweight.shape
scales = scales.repeat(1, out_features)
qweight, scales = repack_fp8_for_marlin(qweight, scales)
in_features = qweight.shape[0] * MARLIN_TILE_SIZE
out_features = scales.shape[1]
_check_valid_shape(in_features=in_features, out_features=out_features)
self.qweight = qweight
self.scales = scales
self.bias = bias if bias is not None else None
self.workspace = torch.zeros(
out_features // 64 * 16, dtype=torch.int, device=qweight.device
)
@classmethod
def from_unquant(cls, weight, bias, dtype):
qweight, scales = fp8_quantize(weight)
return cls(qweight=qweight, scales=scales.to(dtype), bias=bias)
@classmethod
def from_fp8(
cls,
weight: torch.Tensor,
scale: torch.Tensor,
bias: torch.Tensor,
dtype: torch.dtype,
**kwargs,
):
return cls(qweight=weight, scales=scale.to(dtype), bias=bias)
def forward(self, A: torch.Tensor) -> torch.Tensor:
assert quantization is not None
A_flat = A.view(-1, A.shape[-1])
C = quantization.fp8_marlin_gemm(
A_flat,
self.qweight,
self.scales,
self.workspace,
8,
A_flat.shape[0],
self.scales.shape[1],
A_flat.shape[1],
)
C = C.reshape(A.shape[:-1] + (self.scales.shape[1],))
if self.bias is not None:
C += self.bias
return C
def pack_fp8_as_int32(fp8_tensor: torch.Tensor) -> torch.Tensor:
"""
Repack FP8 weights to gptq format (packed int32 elements).
"""
assert fp8_tensor.dtype == torch.float8_e4m3fn
if fp8_tensor.shape[0] % 4 != 0:
raise ValueError(
f"Leading tensor dimension is not divisable by 4: {fp8_tensor.shape[0]}"
)
# Reshape to prepare for packing
reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:])
# Convert fp8 to uint8 (byte) representation
byte_tensor = reshaped.view(torch.uint8)
# Pack 4 uint8 values into one int32
packed = torch.zeros(
fp8_tensor.shape[0] // 4,
fp8_tensor.shape[1],
dtype=torch.int32,
device=fp8_tensor.device,
)
for i in range(4):
packed.bitwise_or_(byte_tensor[:, i].to(torch.int32) << i * 8)
return packed
def repack_fp8_for_marlin(weight: torch.Tensor, scales: torch.Tensor):
"""
Repack FP8 tensor for GPTQ-Marlin.
"""
out_features, in_features = weight.shape
# Torch linear layers weights with shape [out_features, in_features],
# GPTQ-quantized weights use [in_feateres/pack_factor, in_features],
# so transpose before packing.
qweight = pack_fp8_as_int32(weight.t())
perm = torch.empty(0, dtype=torch.int, device=qweight.device)
repacked = quantization.gptq_marlin_repack(
qweight, perm, in_features, out_features, 8
)
scales = permute_scales(scales)
return repacked, scales

View File

@ -1,465 +0,0 @@
from dataclasses import dataclass
from typing import List, Optional, Union
import numpy
import torch
import torch.nn as nn
from loguru import logger
from text_generation_server.layers.marlin.util import (
_check_marlin_kernels,
marlin_zero_points,
permute_scales,
unpack_cols,
)
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
quantization = None
try:
major, _minor = torch.cuda.get_device_capability()
has_sm_8_0 = major >= 8
except Exception:
has_sm_8_0 = False
GPTQ_MARLIN_BITS = [4, 8]
GPTQ_MARLIN_GROUP_SIZES = [-1, 32, 64, 128]
MARLIN_TILE_SIZE = 16
def can_use_gptq_marlin(
*, bits: int, groupsize: int, quant_method: str, quantize: str, sym: bool
) -> bool:
return False
class GPTQMarlinWeightsLoader(WeightsLoader):
"""
Loader for using GPTQ- and AWQ-quantized weights with Marlin kernels.
"""
def __init__(
self,
*,
bits: int,
desc_act: bool,
groupsize: int,
quant_method: str,
quantize: str,
sym: bool,
):
self.bits = bits
self.desc_act = desc_act
self.groupsize = groupsize
self.quant_method = quant_method
self.quantize = quantize
self.sym = sym
def get_weights(self, weights: Weights, prefix: str):
log_once(logger.info, "Using GPTQ-Marlin kernels")
try:
qweight = weights.get_tensor(f"{prefix}.qweight")
except RuntimeError:
raise RuntimeError(
f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
)
if not self.sym:
qzeros = weights.get_tensor(f"{prefix}.qzeros")
else:
qzeros = None
if self.quant_method == "awq":
g_idx = None
else:
g_idx = weights.get_tensor(f"{prefix}.g_idx")
scales = weights.get_tensor(f"{prefix}.scales")
return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
qzeros=qzeros,
g_idx=g_idx,
bits=self.bits,
desc_act=self.desc_act,
groupsize=self.groupsize,
quant_method=self.quant_method,
sym=self.sym,
sharded_infeatures=False,
)
def get_weights_col_packed(
self,
weights: Weights,
prefix: str,
block_sizes: Union[int, List[int]],
):
try:
qweight = weights.get_packed_sharded(
f"{prefix}.qweight", dim=1, block_sizes=block_sizes
)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{self.quantize}` weight, make sure the model is already quantized."
)
scales = weights.get_packed_sharded(
f"{prefix}.scales", dim=1, block_sizes=block_sizes
)
scales = scales.to(dtype=weights.dtype)
if not self.sym:
qzeros = weights.get_packed_sharded(
f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
)
else:
qzeros = None
if self.quant_method == "awq":
g_idx = None
else:
g_idx = weights.get_tensor(f"{prefix}.g_idx")
return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
qzeros=qzeros,
g_idx=g_idx,
bits=self.bits,
desc_act=self.desc_act,
groupsize=self.groupsize,
quant_method=self.quant_method,
sym=self.sym,
sharded_infeatures=False,
)
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
try:
qweight = torch.cat(
[weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{self.quantize}` weight, make sure the model is already quantized"
)
scales = torch.cat(
[weights.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
)
if not self.sym:
qzeros = torch.cat(
[weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
)
else:
qzeros = None
if self.quant_method == "awq":
g_idx = None
else:
w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
for w2 in w[1:]:
torch.testing.assert_close(w2, w[0])
g_idx = w[0]
return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
qzeros=qzeros,
g_idx=g_idx,
bits=self.bits,
desc_act=self.desc_act,
groupsize=self.groupsize,
quant_method=self.quant_method,
sym=self.sym,
sharded_infeatures=False,
)
def get_weights_row(self, weights: Weights, prefix: str):
log_once(logger.info, "Using GPTQ-Marlin kernels")
try:
qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
)
if not self.sym:
if self.desc_act or self.groupsize == -1:
qzeros = weights.get_tensor(f"{prefix}.qzeros")
else:
qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0)
else:
qzeros = None
if self.quant_method == "awq":
g_idx = None
else:
g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0)
if self.desc_act or self.groupsize == -1:
scales = weights.get_tensor(f"{prefix}.scales")
else:
scales = weights.get_sharded(f"{prefix}.scales", dim=0)
sharded_in_features = weights.process_group.size() > 1
return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
qzeros=qzeros,
g_idx=g_idx,
bits=self.bits,
desc_act=self.desc_act,
groupsize=self.groupsize,
quant_method=self.quant_method,
sym=self.sym,
sharded_infeatures=sharded_in_features,
)
def _get_gptq_params(self, weights: Weights):
if weights.has_tensor("gptq_bits") and weights.has_tensor("gptq_groupsize"):
self.bits = weights.get_tensor("gptq_bits").item()
self.groupsize = weights.get_tensor("gptq_groupsize").item()
self.desc_act = False
# `server quantize` used asymmetric quantization unconditionally
# before the `gptq_sym` setting tensor was added.
self.sym = (
weights.get_tensor("gptq_sym").item()
if weights.has_tensor("gptq_sym")
else False
)
self.quant_method = "gptq"
@dataclass
class GPTQMarlinWeight(Weight):
"""
Repacked GPTQ Marlin weights.
"""
qweight: torch.Tensor
qzeros: torch.Tensor
scales: torch.Tensor
g_idx: torch.Tensor
perm: torch.Tensor
bits: int
is_full_k: bool
def __post_init__(self):
assert self.qweight.dtype == torch.int32
assert self.scales.dtype in (torch.float16, torch.bfloat16)
assert self.g_idx.dtype == torch.int32
assert self.perm.dtype == torch.int32
def get_linear(self, bias: torch.Tensor):
return GPTQMarlinLinear(
weight=self,
bias=bias,
)
def repack_gptq_for_marlin(
*,
qweight: torch.Tensor,
qzeros: Optional[torch.Tensor],
scales: torch.Tensor,
g_idx: Optional[torch.Tensor],
bits: int,
desc_act: bool,
groupsize: int,
quant_method: str,
sym: bool,
sharded_infeatures: bool,
) -> GPTQMarlinWeight:
"""Convert GPTQ weights to a layout that's compatible with GPTQ-Marlin kernels."""
_check_marlin_kernels()
assert quantization is not None
if bits not in GPTQ_MARLIN_BITS:
supported_bits = ", ".join(str(b) for b in GPTQ_MARLIN_BITS)
raise RuntimeError(
f"Repacking {bits}-bit GPTQ weights as Marlin is not supported, must be one of: {supported_bits}"
)
if groupsize not in GPTQ_MARLIN_GROUP_SIZES:
supported_sizes = ", ".join(str(b) for b in GPTQ_MARLIN_GROUP_SIZES)
raise RuntimeError(
f"Repacking GPTQ weights with group size {groupsize} as Marlin is not supported, must be one of: {supported_sizes}"
)
if not (sym or quant_method == "awq" or quant_method == "compressed-tensors"):
raise RuntimeError(
"Repacking GPTQ weights with asymmetric quantization as Marlin is not supported."
)
log_once(logger.info, f"Converting {quant_method} model to Marlin packing format.")
weights_per_int = 32 // bits
in_features = qweight.shape[0]
out_features = qweight.shape[1]
# AWQ uses column packing, GPTQ uses row packing
if quant_method == "awq":
out_features *= weights_per_int
else:
in_features *= weights_per_int
if in_features % groupsize != 0:
raise ValueError(
f"Number of input features ({in_features}) not divisible by group size ({groupsize})"
)
if g_idx is not None and desc_act and groupsize != -1:
perm = torch.argsort(g_idx).to(torch.int)
g_idx = g_idx[perm]
else:
perm = torch.empty(0, dtype=torch.int, device=qweight.device)
g_idx = torch.empty(0, dtype=torch.int, device=qweight.device)
if quant_method == "awq":
repacked = quantization.awq_marlin_repack(
qweight, in_features, out_features, bits
)
if qzeros is not None:
qzeros = awq_to_marlin_zero_points(
qzeros,
in_features // groupsize,
out_features,
bits,
)
else:
repacked = quantization.gptq_marlin_repack(
qweight, perm, in_features, out_features, bits
)
if qzeros is None:
qzeros = torch.empty(0, dtype=torch.int, device=qweight.device)
scales = permute_scales(scales)
is_full_k = not (desc_act and groupsize != -1 and sharded_infeatures)
return GPTQMarlinWeight(
qweight=repacked,
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
perm=perm,
bits=bits,
is_full_k=is_full_k,
)
class GPTQMarlinLinear(nn.Module):
"""
Linear layer for GPTQ weights that were converted for the GPTQ-Marlin
kernels.
"""
def __init__(
self,
*,
weight: GPTQMarlinWeight,
bias: Optional[torch.Tensor],
):
super().__init__()
_check_marlin_kernels()
assert quantization is not None
in_features = weight.qweight.shape[0] * MARLIN_TILE_SIZE
out_features = weight.scales.shape[1]
_check_valid_shape(in_features=in_features, out_features=out_features)
if weight.bits not in (4, 8):
raise ValueError("GPTQMarlinLinear only supports 4 and 8-bit quantization")
if weight.qzeros.numel() > 0:
if weight.bits == 4:
self.quant_type = quantization.scalar_types.uint4
else:
self.quant_type = quantization.scalar_types.uint8
else:
if weight.bits == 4:
self.quant_type = quantization.scalar_types.uint4b8
else:
self.quant_type = quantization.scalar_types.uint8b128
self.is_full_k = weight.is_full_k
self.qweight = weight.qweight
self.qzeros = weight.qzeros
self.scales = weight.scales
self.g_idx = weight.g_idx
self.perm = weight.perm
if bias is not None:
self.bias = bias
else:
self.bias = None
self.workspace = torch.zeros(
out_features // 64 * 16, dtype=torch.int, device=weight.qweight.device
)
def forward(self, A: torch.Tensor) -> torch.Tensor:
assert quantization is not None
A_flat = A.view(-1, A.shape[-1])
C = quantization.gptq_marlin_gemm(
A_flat,
self.qweight,
self.scales,
self.qzeros,
self.g_idx,
self.perm,
self.workspace,
self.quant_type,
A_flat.shape[0],
self.scales.shape[1],
A_flat.shape[1],
self.is_full_k,
self.qzeros.numel() > 0,
True,
)
C = C.reshape(A.shape[:-1] + (self.scales.shape[1],))
if self.bias is not None:
C += self.bias
return C
def awq_to_marlin_zero_points(
q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
) -> torch.Tensor:
# AWQ zero-points are quantized and packed on the column dim.
# In addition, the values are permuted based on dequantizer.
# Here we undo both of these, and then apply marlin permutation
# and pack it back.
q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n)
# Undo interleaving (use argsort(..) to get inverse perm)
if num_bits == 4:
undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7]))
elif num_bits == 8:
undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3]))
else:
raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel()
q_zp = q_zp.reshape((-1, size_n)).contiguous()
marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits)
return marlin_zp
def _check_valid_shape(in_features: int, out_features: int):
if (in_features % 128 != 0 or out_features % 64 != 0) and (
in_features % 64 != 0 or out_features % 128 != 0
):
raise ValueError(
f"The GPTQ Marlin kernel does not have a valid thread configuration for weight matrix with shape ({out_features}, {in_features})."
" The shape elements must be divisible by (128, 64) or (64, 128)."
)

View File

@ -1,359 +0,0 @@
from dataclasses import dataclass
from typing import List, Optional, Union
import torch
import torch.nn as nn
from text_generation_server.layers.marlin.util import _check_marlin_kernels
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
quantization = None
class MarlinWeightsLoader(WeightsLoader):
"""Loader for Marlin-quantized weights."""
def __init__(self, *, bits: int, is_marlin_24: bool):
self.bits = bits
self.is_marlin_24 = is_marlin_24
def get_weights(self, weights: "Weights", prefix: str):
"""
Get weights at the given prefix and apply without tensor paralllism.
"""
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
if is_marlin_24:
try:
B = weights.get_tensor(f"{prefix}.B_24")
except RuntimeError:
raise RuntimeError(
"Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized."
)
B_meta = weights.get_tensor(f"{prefix}.B_meta")
s = weights.get_tensor(f"{prefix}.s")
weight = GPTQMarlin24Weight(
weight_packed=B, meta=B_meta, scale_packed=s, bits=self.bits
)
else:
try:
B = weights.get_tensor(f"{prefix}.B")
except RuntimeError:
raise RuntimeError(
"Cannot load `marlin` weight, make sure the model is already quantized."
)
s = weights.get_tensor(f"{prefix}.s")
weight = MarlinWeight(B=B, s=s)
return weight
def get_weights_col_packed(
self,
weights: Weights,
prefix: str,
block_sizes: Union[int, List[int]],
):
if self.is_marlin_24:
B = weights.get_packed_sharded(
f"{prefix}.B_24", dim=1, block_sizes=block_sizes
)
B_meta = weights.get_packed_sharded(
f"{prefix}.B_meta", dim=1, block_sizes=block_sizes
)
s = weights.get_packed_sharded(
f"{prefix}.s", dim=1, block_sizes=block_sizes
)
weight = GPTQMarlin24Weight(
weight_packed=B, meta=B_meta, scale_packed=s, bits=self.bits
)
else:
B = weights.get_packed_sharded(
f"{prefix}.B", dim=1, block_sizes=block_sizes
)
s = weights.get_packed_sharded(
f"{prefix}.s", dim=1, block_sizes=block_sizes
)
weight = MarlinWeight(B=B, s=s)
return weight
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
if self.is_marlin_24:
try:
B = torch.cat(
[weights.get_sharded(f"{p}.B_24", dim=1) for p in prefixes], dim=1
)
except RuntimeError:
raise RuntimeError(
"Cannot load `marlin` weight, make sure the model is already quantized"
)
B_meta = torch.cat(
[weights.get_sharded(f"{p}.B_meta", dim=1) for p in prefixes], dim=1
)
s = torch.cat(
[weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1
)
weight = GPTQMarlin24Weight(
weight_packed=B, meta=B_meta, scale_packed=s, bits=self.bits
)
else:
try:
B = torch.cat(
[weights.get_sharded(f"{p}.B", dim=1) for p in prefixes], dim=1
)
except RuntimeError:
raise RuntimeError(
"Cannot load `marlin` weight, make sure the model is already quantized"
)
s = torch.cat(
[weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1
)
weight = MarlinWeight(B=B, s=s)
return weight
def get_weights_row(self, weights: Weights, prefix: str):
if self.is_marlin_24:
try:
B = weights.get_sharded(f"{prefix}.B_24", dim=0)
except RuntimeError:
raise RuntimeError(
"Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized."
)
B_meta = weights.get_sharded(f"{prefix}.B_meta", dim=0)
num_groups = weights._get_slice(f"{prefix}.s").get_shape()[0]
if num_groups == 1:
# The number of groups is 1 when groupsize == -1. share
# scales between all shards in this case.
s = weights.get_tensor(f"{prefix}.s")
else:
s = weights.get_sharded(f"{prefix}.s", dim=0)
weight = GPTQMarlin24Weight(
weight_packed=B, meta=B_meta, scale_packed=s, bits=self.bits
)
else:
try:
B = weights.get_sharded(f"{prefix}.B", dim=0)
except RuntimeError:
raise RuntimeError(
"Cannot load `marlin` weight, make sure the model is already quantized."
)
num_groups = weights._get_slice(f"{prefix}.s").get_shape()[0]
if num_groups == 1:
# The number of groups is 1 when groupsize == -1. share
# scales between all shards in this case.
s = weights.get_tensor(f"{prefix}.s")
else:
s = weights.get_sharded(f"{prefix}.s", dim=0)
weight = MarlinWeight(B=B, s=s)
return weight
@dataclass
class MarlinWeight(Weight):
"""
Marlin weights.
Attributes:
B (torch.Tensor): int4-quantized weights packed into int32.
s (torch.Tensor): bfloat16/float16 scales.
"""
B: torch.Tensor
s: torch.Tensor
def __post_init__(self):
assert self.B.dtype == torch.int32
assert self.s.dtype in [torch.float16, torch.bfloat16]
def get_linear(self, bias: torch.Tensor):
return MarlinLinear(weight=self, bias=bias)
class MarlinLinear(nn.Module):
def __init__(self, *, weight: MarlinWeight, bias: Optional[torch.Tensor]):
super().__init__()
_check_marlin_kernels()
assert quantization is not None
in_features = weight.B.shape[0] * MARLIN_TILE_SIZE
out_features = weight.s.shape[1]
assert (
in_features % 128 == 0
), f"Number of input features ({in_features}) not divisable by 128"
assert (
out_features % 256 == 0
), f"Number of output features ({out_features}) not divisable by 256"
groupsize = -1 if weight.s.shape[0] == 1 else in_features // weight.s.shape[0]
assert groupsize in {
-1,
128,
}, f"Group size must be -1 or 128, was {groupsize}"
self.B = weight.B
self.s = weight.s
if bias is not None:
self.bias = bias
else:
self.bias = None
self.workspace = torch.zeros(
out_features // 64 * 16, dtype=torch.int, device=weight.B.device
)
def forward(self, A: torch.Tensor) -> torch.Tensor:
assert quantization is not None
C = quantization.marlin_gemm(
A.view(-1, A.shape[-1]),
self.B,
self.s,
self.workspace,
A.shape[0],
self.s.shape[1],
A.shape[1],
)
C = C.reshape(A.shape[:-1] + (self.s.shape[1],))
if self.bias is not None:
C += self.bias
return C
GPTQ_MARLIN_24_MIN_THREAD_N = 128
GPTQ_MARLIN_24_MIN_THREAD_K = 128
GPTQ_MARLIN_24_MAX_PARALLEL = 64
GPTQ_MARLIN_24_SUPPORTED_NUM_BITS = [4, 8]
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]
MARLIN_TILE_SIZE = 16
@dataclass
class GPTQMarlin24Weight:
"""
GPTQ-Marlin 2:4 weights.
Attributes:
B (torch.Tensor): int4-quantized weights packed into int32.
B_meta (torch.Tensor): metadata for 2:4 sparsity.
s (torch.Tensor): float16 scales.
bits: quantized weight size.
"""
weight_packed: torch.Tensor
meta: torch.Tensor
scale_packed: torch.Tensor
bits: int
def __post_init__(self):
assert self.weight_packed.dtype == torch.int32
assert self.meta.dtype == torch.int16
assert self.scale_packed.dtype == torch.float16
def get_linear(self, bias: torch.Tensor):
return GPTQMarlin24Linear(
weight=self,
bias=bias,
)
class GPTQMarlin24Linear(nn.Module):
def __init__(self, *, weight: GPTQMarlin24Weight, bias: Optional[torch.Tensor]):
super().__init__()
_check_marlin_kernels()
assert quantization is not None
if weight.bits not in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS:
supported_bits = ", ".join(
str(b) for b in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS
)
raise RuntimeError(
f"{weight.bits}-bit GPTQ Sparse 2:4 Marlin is not supported, must be one of: {supported_bits}"
)
in_features = weight.weight_packed.shape[0] * MARLIN_TILE_SIZE * 2
out_features = weight.scale_packed.shape[1]
groupsize = (
-1
if weight.scale_packed.shape[0] == 1
else in_features // weight.scale_packed.shape[0]
)
if groupsize not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES:
supported_sizes = ", ".join(
str(b) for b in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
)
raise RuntimeError(
f"Group size {groupsize} is not supported, must be one of: {supported_sizes}"
)
if weight.bits == 4:
self.quant_type = quantization.scalar_types.uint4b8
else:
self.quant_type = quantization.scalar_types.uint8b128
weights_per_int32 = 32 // weight.bits
assert (
out_features % GPTQ_MARLIN_24_MIN_THREAD_N == 0
), f"Number of output features ({out_features}) not divisable by {GPTQ_MARLIN_24_MIN_THREAD_N} threads"
assert (
out_features % weights_per_int32 == 0
), f"Number of output features ({out_features}) not divisable by weights per int32 ({weights_per_int32})"
assert (
in_features % GPTQ_MARLIN_24_MIN_THREAD_K == 0
), f"Number of output features ({out_features}) not divisable by {GPTQ_MARLIN_24_MIN_THREAD_K} threads"
if groupsize != -1 and in_features % groupsize != 0:
raise ValueError(
f"Number of input features ({in_features}) not divisable by group size ({groupsize})"
)
self.weight_packed = weight.weight_packed
self.meta = weight.meta
self.scale_packed = weight.scale_packed
if bias is not None:
self.bias = bias
else:
self.bias = None
self.workspace = torch.zeros(
(out_features // GPTQ_MARLIN_24_MIN_THREAD_N) * GPTQ_MARLIN_24_MAX_PARALLEL,
dtype=torch.int,
device=weight.weight_packed.device,
)
def forward(self, A: torch.Tensor) -> torch.Tensor:
assert quantization is not None
C = quantization.gptq_marlin_24_gemm(
A.view(-1, A.shape[-1]),
self.weight_packed,
self.meta,
self.scale_packed,
self.workspace,
self.quant_type,
A.shape[0],
self.scale_packed.shape[1],
A.shape[1],
)
C = C.reshape(A.shape[:-1] + (self.scale_packed.shape[1],))
if self.bias is not None:
C += self.bias
return C

View File

@ -1,137 +0,0 @@
import functools
from typing import List, Tuple
import numpy
import torch
quantization = None
try:
major, _minor = torch.cuda.get_device_capability()
has_sm_8_0 = major >= 8
except Exception:
has_sm_8_0 = False
def _check_marlin_kernels():
raise NotImplementedError(
"Using quantized Marlin models requires a GPU with CUDA capability 8.0 or later."
)
if quantization is None:
raise NotImplementedError(
"marlin is not installed, install it with: pip install server/marlin"
)
# https://github.com/IST-DASLab/marlin/blob/2f6d7c10e124b3c5fa29ff8d77d568bd7af3274c/marlin/__init__.py#L40C1-L68C54
@functools.cache
def get_perms() -> Tuple[List[int], List[int]]:
scale_perm = []
for i in range(8):
scale_perm.extend([i + 8 * j for j in range(8)])
scale_perm_single = []
for i in range(4):
scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
return scale_perm, scale_perm_single
def permute_scales(scales: torch.Tensor):
scale_perm, scale_perm_single = get_perms()
out_features = scales.shape[1]
if scales.shape[0] == 1:
scales = scales.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
else:
scales = scales.reshape((-1, len(scale_perm)))[:, scale_perm]
return scales.reshape((-1, out_features)).contiguous()
# Functions below are from vLLM
def get_pack_factor(bits: int) -> int:
if 32 % bits != 0:
raise ValueError(f"Cannot {bits} bit values into uint32")
return 32 // bits
def pack_cols(
q_w: torch.Tensor,
num_bits: int,
size_k: int,
size_n: int,
):
assert q_w.shape == (size_k, size_n)
pack_factor = get_pack_factor(num_bits)
assert size_n % pack_factor == 0
orig_device = q_w.device
q_w = q_w.cpu().numpy().astype(numpy.uint32)
q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)
for i in range(pack_factor):
q_res |= q_w[:, i::pack_factor] << num_bits * i
q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
q_res = q_res.contiguous()
return q_res
def unpack_cols(
packed_q_w: torch.Tensor,
num_bits: int,
size_k: int,
size_n: int,
):
pack_factor = get_pack_factor(num_bits)
assert size_n % pack_factor == 0
assert packed_q_w.shape == (
size_k,
size_n // pack_factor,
), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format(
packed_q_w.shape, size_k, size_n, pack_factor
)
orig_device = packed_q_w.device
packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32)
q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32)
mask = (1 << num_bits) - 1
for i in range(pack_factor):
vals = packed_q_w_cpu & mask
packed_q_w_cpu >>= num_bits
q_res[:, i::pack_factor] = vals
q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
q_res = q_res.contiguous()
return q_res
def marlin_zero_points(
zp: torch.Tensor, size_k: int, size_n: int, num_bits: int
) -> torch.Tensor:
scale_perm, _ = get_perms()
# Permute zero-points in a similar way to scales, but do not use the
# "single" permutation, since zero-points are applied on every MMA
zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm]
# Interleave column dim (for the dequantize code) and pack it to int32
if num_bits == 4:
interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
elif num_bits == 8:
interleave = numpy.array([0, 2, 1, 3])
else:
raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel()
zp = zp.reshape((-1, size_n)).contiguous()
zp = pack_cols(zp, num_bits, size_k, size_n)
return zp

View File

@ -398,6 +398,7 @@ class FlashGemmaModel(torch.nn.Module):
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
adapter_data: Optional[torch.Tensor],
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
@ -479,6 +480,7 @@ class FlashGemmaForCausalLM(torch.nn.Module):
block_tables,
slots,
seqlen,
adapter_data,
prefill_cache_indices,
hpu_attention_meta,
)

View File

@ -47,10 +47,6 @@ def load_qkv(config, prefix: str, weights, head_size, num_heads):
prefix,
weights,
)
elif config.quantize == "marlin":
raise RuntimeError(
"GPT-2 models with marlin quantization are not yet supported"
)
else:
return _load_qkv(config, prefix, weights, head_size, num_heads)

View File

@ -111,6 +111,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
seqlen=seqlen,
hpu_attention_meta=hpu_attention_meta,
prefill_cache_indices=None,
adapter_data=adapter_data,
)
if lm_head_indices is not None:

View File

@ -90,7 +90,7 @@ def _load_gqa(config, prefix: str, weights):
dim=0,
)
if config.quantize not in ["gptq", "awq", "marlin"]:
if config.quantize not in ["gptq", "awq"]:
weight = weight.to(dtype=weights.dtype).to(device=weights.device)
head_size = config.hidden_size // config.num_attention_heads

View File

@ -32,10 +32,6 @@ def load_multi_mqa(
return _load_multi_mqa_gptq(
config, prefix, weights, bias, head_size, num_heads, hidden_size
)
elif config.quantize == "marlin":
raise RuntimeError(
"santacoder models with marlin quantization are not yet supported"
)
else:
return _load_multi_mqa(
config, prefix, weights, bias, head_size, num_heads, hidden_size

View File

@ -588,7 +588,7 @@ class Seq2SeqLM(Model):
aliases=aliases,
weights_loader=weights_loader,
)
if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
if config.quantize in ["awq", "gptq"]:
weights._set_gptq_params(model_id, revision)
model = model_class(config, weights)

View File

@ -4,9 +4,7 @@ from dataclasses import dataclass
from typing import Optional
from huggingface_hub import hf_hub_download
from text_generation_server.layers.marlin.gptq import can_use_gptq_marlin
from text_generation_server.utils.weights import (
DefaultWeightsLoader,
WeightsLoader,
)
@ -129,64 +127,13 @@ def get_loader(
f"Quantize is set to `{quantize}` but received a `{quantizer_config.__class__.__name__}` config."
)
if can_use_gptq_marlin(
return GPTQWeightsLoader(
bits=quantizer_config.bits,
desc_act=quantizer_config.desc_act,
groupsize=quantizer_config.groupsize,
quant_method=quantizer_config.quant_method,
quantize=quantize,
sym=quantizer_config.sym,
):
from text_generation_server.layers.marlin import GPTQMarlinWeightsLoader
return GPTQMarlinWeightsLoader(
bits=quantizer_config.bits,
desc_act=quantizer_config.desc_act,
groupsize=quantizer_config.groupsize,
quant_method=quantizer_config.quant_method,
quantize=quantize,
sym=quantizer_config.sym,
)
else:
return GPTQWeightsLoader(
bits=quantizer_config.bits,
desc_act=quantizer_config.desc_act,
groupsize=quantizer_config.groupsize,
quant_method=quantizer_config.quant_method,
quantize=quantize,
sym=quantizer_config.sym,
)
elif quantize == "bitsandbytes":
from text_generation_server.layers.bnb import BNBWeight
return DefaultWeightsLoader(BNBWeight)
elif quantize == "bitsandbytes-fp4":
from text_generation_server.layers.bnb import BNBFP4Weight
return DefaultWeightsLoader(BNBFP4Weight)
elif quantize == "bitsandbytes-nf4":
from text_generation_server.layers.bnb import BNBNF4Weight
return DefaultWeightsLoader(BNBNF4Weight)
elif quantize == "eetq":
from text_generation_server.layers.eetq import EETQWeight
return DefaultWeightsLoader(EETQWeight)
elif quantize == "exl2":
from text_generation_server.layers.exl2 import Exl2WeightsLoader
return Exl2WeightsLoader()
elif quantize == "marlin":
from text_generation_server.layers.marlin import MarlinWeightsLoader
# TODO: improve check once we have one config type per quantize value
if not isinstance(quantizer_config, _QuantizerConfig):
raise ValueError(
f"Quantize is set to `{quantize}` but received a `{quantizer_config.__class__.__name__}` config."
)
return MarlinWeightsLoader(
bits=quantizer_config.bits,
is_marlin_24=quantizer_config.checkpoint_format == "marlin_24",
)
elif quantize == "fp8" or quantize is None:
from text_generation_server.layers.fp8 import HybridFP8UnquantLoader