Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-24 16:32:12 +00:00)

Remove unused quantization code and enable AWQ/GPTQ int4

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Parent: fdf0733f56
Commit: 9914ffe1f1
@@ -7,10 +7,6 @@ from text_generation_server.utils.weights import (
|
||||
)
|
||||
from text_generation_server.layers.gptq import GPTQWeight, GPTQWeightsLoader
|
||||
from text_generation_server.layers.exl2 import Exl2Weight, Exl2WeightsLoader
|
||||
from text_generation_server.layers.marlin.marlin import (
|
||||
MarlinWeight,
|
||||
MarlinWeightsLoader,
|
||||
)
|
||||
from types import SimpleNamespace
|
||||
from typing import List, Optional, Dict, Union
|
||||
from pathlib import Path
|
||||
@@ -40,11 +36,6 @@ def gptq_weights_loader_awq():
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def marlin_weights_loader():
|
||||
return MarlinWeightsLoader(bits=4, is_marlin_24=False)
|
||||
|
||||
|
||||
dummy_file_system = {
|
||||
"test_weights": {
|
||||
"layer.0.weight": torch.tensor(
|
||||
@@ -125,10 +116,6 @@ dummy_file_system = {
|
||||
"gptq_bits": torch.tensor([8], dtype=torch.float32),
|
||||
"gptq_groupsize": torch.tensor([2], dtype=torch.float32),
|
||||
},
|
||||
"test_get_weights_col_marlin": {
|
||||
"weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
|
||||
"weight.s": torch.tensor([[0.5000], [0.2500]], dtype=torch.float16),
|
||||
},
|
||||
"test_get_weights_row_gptq": {
|
||||
"weight.qweight": torch.tensor(
|
||||
[
|
||||
@@ -273,18 +260,6 @@ dummy_file_system = {
|
||||
"weight.q_scale_max": torch.tensor([100], dtype=torch.float16),
|
||||
"weight.q_groups": torch.tensor([4], dtype=torch.int16),
|
||||
},
|
||||
"test_get_weights_row_marlin": {
|
||||
"weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
|
||||
"weight.s": torch.tensor([[0.5], [0.25]], dtype=torch.float16),
|
||||
},
|
||||
"test_get_multi_weights_col_marlin": {
|
||||
"weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
|
||||
"weight.s": torch.tensor([[0.5], [0.25]], dtype=torch.float16),
|
||||
},
|
||||
"test_get_weights_col_packed_marlin": {
|
||||
"weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
|
||||
"weight.s": torch.tensor([[0.5], [0.25]], dtype=torch.float16),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -718,33 +693,6 @@ def test_get_weights_col_exl2():
|
||||
assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch"
|
||||
|
||||
|
||||
def test_get_weights_col_marlin(marlin_weights_loader):
|
||||
weights = MockWeights(
|
||||
[
|
||||
"test_get_weights_col_marlin",
|
||||
],
|
||||
device="cpu",
|
||||
dtype=torch.float16,
|
||||
process_group=dummy_process_group,
|
||||
dummy_fs=dummy_file_system,
|
||||
weights_loader=marlin_weights_loader,
|
||||
)
|
||||
|
||||
prefix = "weight"
|
||||
|
||||
w = weights.get_weights_col(
|
||||
prefix=prefix,
|
||||
)
|
||||
|
||||
expected_weight = MarlinWeight(
|
||||
B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
|
||||
s=torch.tensor([[0.5000], [0.2500]], dtype=torch.float16),
|
||||
)
|
||||
|
||||
assert torch.allclose(w.B, expected_weight.B), "B mismatch"
|
||||
assert torch.allclose(w.s, expected_weight.s), "s mismatch"
|
||||
|
||||
|
||||
# test_get_weights_col_packed
|
||||
|
||||
|
||||
@@ -868,36 +816,6 @@ def test_get_weights_col_packed_gptq(gptq_weights_loader):
|
||||
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
|
||||
|
||||
|
||||
def test_get_weights_col_packed_marlin(marlin_weights_loader):
|
||||
weights = MockWeights(
|
||||
[
|
||||
"test_get_weights_col_packed_marlin",
|
||||
],
|
||||
device="cpu",
|
||||
dtype=torch.float16,
|
||||
process_group=dummy_process_group,
|
||||
dummy_fs=dummy_file_system,
|
||||
weights_loader=marlin_weights_loader,
|
||||
)
|
||||
|
||||
prefix = "weight"
|
||||
|
||||
w = weights.get_multi_weights_col(
|
||||
prefixes=[prefix],
|
||||
dim=0,
|
||||
)
|
||||
|
||||
expected_weight = MarlinWeight(
|
||||
B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
|
||||
s=torch.tensor([[0.5000], [0.2500]], dtype=torch.float16),
|
||||
)
|
||||
|
||||
print(expected_weight)
|
||||
|
||||
assert torch.allclose(w.B, expected_weight.B), "B mismatch"
|
||||
assert torch.allclose(w.s, expected_weight.s), "s mismatch"
|
||||
|
||||
|
||||
# test_get_multi_weights_col
|
||||
|
||||
|
||||
@@ -1004,34 +922,6 @@ def test_get_multi_weights_col_gptq(gptq_weights_loader):
|
||||
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
|
||||
|
||||
|
||||
def test_get_multi_weights_col_marlin(marlin_weights_loader):
|
||||
weights = MockWeights(
|
||||
[
|
||||
"test_get_multi_weights_col_marlin",
|
||||
],
|
||||
device="cpu",
|
||||
dtype=torch.float16,
|
||||
process_group=dummy_process_group,
|
||||
dummy_fs=dummy_file_system,
|
||||
weights_loader=marlin_weights_loader,
|
||||
)
|
||||
|
||||
prefix = "weight"
|
||||
|
||||
w = weights.get_multi_weights_col(
|
||||
prefixes=[prefix],
|
||||
dim=0,
|
||||
)
|
||||
|
||||
expected_weight = MarlinWeight(
|
||||
B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
|
||||
s=torch.tensor([[0.5000], [0.2500]], dtype=torch.float16),
|
||||
)
|
||||
|
||||
assert torch.allclose(w.B, expected_weight.B), "B mismatch"
|
||||
assert torch.allclose(w.s, expected_weight.s), "s mismatch"
|
||||
|
||||
|
||||
# test_get_weights_row
|
||||
|
||||
|
||||
@@ -1148,30 +1038,3 @@ def test_get_weights_row_gptq(gptq_weights_loader):
|
||||
assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
|
||||
assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch"
|
||||
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
|
||||
|
||||
|
||||
def test_get_weights_row_marlin(marlin_weights_loader):
|
||||
weights = MockWeights(
|
||||
[
|
||||
"test_get_weights_row_marlin",
|
||||
],
|
||||
device="cpu",
|
||||
dtype=torch.float16,
|
||||
process_group=dummy_process_group,
|
||||
dummy_fs=dummy_file_system,
|
||||
weights_loader=marlin_weights_loader,
|
||||
)
|
||||
|
||||
prefix = "weight"
|
||||
|
||||
w = weights.get_weights_row(
|
||||
prefix=prefix,
|
||||
)
|
||||
|
||||
expected_weight = MarlinWeight(
|
||||
B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
|
||||
s=torch.tensor([[0.5000], [0.2500]], dtype=torch.float16),
|
||||
)
|
||||
|
||||
assert torch.allclose(w.B, expected_weight.B), "B mismatch"
|
||||
assert torch.allclose(w.s, expected_weight.s), "s mismatch"
|
||||
|
@@ -16,15 +16,9 @@ app = typer.Typer()
|
||||
|
||||
|
||||
class Quantization(str, Enum):
|
||||
bitsandbytes = "bitsandbytes"
|
||||
bitsandbytes_nf4 = "bitsandbytes-nf4"
|
||||
bitsandbytes_fp4 = "bitsandbytes-fp4"
|
||||
gptq = "gptq"
|
||||
awq = "awq"
|
||||
eetq = "eetq"
|
||||
exl2 = "exl2"
|
||||
fp8 = "fp8"
|
||||
marlin = "marlin"
|
||||
|
||||
|
||||
class Dtype(str, Enum):
|
||||
|
@@ -1,19 +1,93 @@
|
||||
# Copied logic from https://github.com/mit-han-lab/llm-awq/blob/f084f40bd996f3cf3a0633c1ad7d9d476c318aaa/awq/quantize/qmodule.py
|
||||
|
||||
from typing import Optional
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import awq_inference_engine # with CUDA kernels
|
||||
|
||||
try:
|
||||
import habana_frameworks.torch.hpu # noqa: F401
|
||||
|
||||
convert_from_uint4 = torch.ops.hpu.convert_from_uint4
|
||||
except Exception as e:
|
||||
hpu_import_exception = e
|
||||
|
||||
def error_raiser_hpu(*args, **kwargs):
|
||||
raise ValueError(
|
||||
f"Trying to use HPU, but could not import the HPU framework with the following error: {hpu_import_exception}"
|
||||
)
|
||||
|
||||
convert_from_uint4 = error_raiser_hpu
|
||||
|
||||
AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
|
||||
|
||||
|
||||
# class ScaledActivation(nn.Module):
|
||||
# def __init__(self, module, scales):
|
||||
# super().__init__()
|
||||
# self.act = module
|
||||
# self.scales = nn.Parameter(scales.data)
|
||||
#
|
||||
# def forward(self, x):
|
||||
# return self.act(x) / self.scales.view(1, 1, -1).to(x.device)
|
||||
def unpack_awq(qweight: torch.Tensor, qzeros: torch.Tensor, bits: int):
|
||||
shifts = torch.arange(0, 32, bits, device=qzeros.device)
|
||||
|
||||
# unpacking columnwise
|
||||
iweights = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to(
|
||||
torch.int8 # smallest dtype available
|
||||
)
|
||||
iweights = iweights.view(iweights.shape[0], -1)
|
||||
|
||||
# unpacking zeros columnwise
|
||||
if qzeros is not None:
|
||||
izeros = torch.bitwise_right_shift(
|
||||
qzeros[:, :, None], shifts[None, None, :]
|
||||
).to(
|
||||
torch.int8 # smallest dtype available
|
||||
)
|
||||
izeros = izeros.view(izeros.shape[0], -1)
|
||||
else:
|
||||
izeros = qzeros
|
||||
|
||||
return iweights, izeros
|
||||
|
||||
|
||||
def reverse_awq_order(iweights: torch.Tensor, izeros: torch.Tensor, bits: int):
|
||||
reverse_order_tensor = torch.arange(
|
||||
iweights.shape[-1],
|
||||
dtype=torch.int32,
|
||||
device=izeros.device,
|
||||
)
|
||||
reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits)
|
||||
reverse_order_tensor = reverse_order_tensor[:, AWQ_REVERSE_ORDER]
|
||||
reverse_order_tensor = reverse_order_tensor.view(-1)
|
||||
|
||||
if izeros is not None:
|
||||
izeros = izeros[:, reverse_order_tensor]
|
||||
iweights = iweights[:, reverse_order_tensor]
|
||||
|
||||
return iweights, izeros
|
||||
|
||||
|
||||
def unpack_weight_and_zeros(qweight, qzeros, bits):
|
||||
# Unpack the qweight and qzeros tensors
|
||||
iweight, izeros = unpack_awq(qweight, qzeros, bits)
|
||||
# Reverse the order of the iweight and izeros tensors
|
||||
iweight, izeros = reverse_awq_order(iweight, izeros, bits)
|
||||
|
||||
# overflow checks
|
||||
iweight = torch.bitwise_and(iweight, (2**bits) - 1)
|
||||
izeros = torch.bitwise_and(izeros, (2**bits) - 1)
|
||||
|
||||
return iweight, izeros
|
||||
|
||||
|
||||
def pack_tensor(input, bits=4):
|
||||
normal = input.to(torch.int32)
|
||||
q = torch.zeros(
|
||||
(normal.shape[0], normal.shape[1] // 32 * bits),
|
||||
dtype=torch.int32,
|
||||
device=input.device,
|
||||
)
|
||||
i = 0
|
||||
col = 0
|
||||
while col < q.shape[1]:
|
||||
for j in range(i, i + (32 // bits)):
|
||||
q[:, col] |= normal[:, j] << (bits * (j - i))
|
||||
i += 32 // bits
|
||||
col += 1
|
||||
q = q.to(torch.int32)
|
||||
return q
|
||||
|
||||
|
||||
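# Editor's sketch (not part of the diff): a minimal round-trip check of the 4-bit
# packing used above. Values are packed nibble-by-nibble into an int32 column the
# same way `pack_tensor` does, then recovered with the right-shift-and-mask scheme
# used by `unpack_awq`. Real AWQ checkpoints additionally permute nibbles within
# each group of eight (AWQ_REVERSE_ORDER), which `reverse_awq_order` undoes.
import torch

bits = 4
vals = torch.tensor([[3, 1, 4, 1, 5, 2, 6, 7]], dtype=torch.int32)  # 32 // bits values

packed = torch.zeros((1, 1), dtype=torch.int32)
for j in range(32 // bits):
    packed[:, 0] |= vals[:, j] << (bits * j)

shifts = torch.arange(0, 32, bits)
unpacked = torch.bitwise_right_shift(packed[:, :, None], shifts[None, None, :]).to(torch.int8)
unpacked = torch.bitwise_and(unpacked, (2**bits) - 1).view(1, -1)

assert torch.equal(unpacked.to(torch.int32), vals)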
class WQLinear(nn.Module):
|
||||
@@ -38,12 +112,23 @@ class WQLinear(nn.Module):
|
||||
self.qzeros = qzeros
|
||||
self.scales = scales
|
||||
self.bias = bias
|
||||
self._preprocessing()
|
||||
|
||||
def _preprocessing(self):
|
||||
device = self.qweight.device
|
||||
weight, zeros = unpack_weight_and_zeros(
|
||||
self.qweight.cpu(), self.qzeros.cpu(), self.w_bit
|
||||
)
|
||||
self.qweight = pack_tensor(weight).to(device)
|
||||
self.qzeros = pack_tensor(zeros).to(device)
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, x):
|
||||
out_shape = x.shape[:-1] + (self.out_features,)
|
||||
out = awq_inference_engine.gemm_forward_cuda(
|
||||
x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8
|
||||
)
|
||||
out = out + self.bias if self.bias is not None else out
|
||||
return out.reshape(out_shape)
|
||||
x = x.reshape(-1, x.shape[-1])
|
||||
weights = convert_from_uint4(self.qweight, self.scales, self.qzeros, x.dtype)
|
||||
outputs = torch.matmul(x, weights)
|
||||
|
||||
outputs = outputs + self.bias if self.bias is not None else outputs
|
||||
outputs = outputs.reshape(out_shape)
|
||||
return outputs
|
||||
|
@@ -1,3 +0,0 @@
|
||||
from .loader import CompressedTensorsLoader
|
||||
|
||||
__all__ = ["CompressedTensorsLoader"]
|
@@ -1,196 +0,0 @@
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
from compressed_tensors import QuantizationConfig, QuantizationStatus
|
||||
from compressed_tensors.config import CompressionFormat
|
||||
from compressed_tensors.quantization import (
|
||||
QuantizationScheme,
|
||||
QuantizationType,
|
||||
find_name_or_class_matches,
|
||||
)
|
||||
from loguru import logger
|
||||
from pydantic import ValidationError
|
||||
from torch import nn
|
||||
|
||||
from text_generation_server.layers.compressed_tensors.w8an_fp import W8ANFpLoader
|
||||
from text_generation_server.layers.compressed_tensors.w8a8_int import W8A8IntLoader
|
||||
from text_generation_server.layers.compressed_tensors.wna16_int_24 import (
|
||||
WNA16Int24Loader,
|
||||
)
|
||||
from text_generation_server.layers.compressed_tensors.wna16_int import WNA16IntLoader
|
||||
from text_generation_server.utils.log import log_once
|
||||
from text_generation_server.utils.weights import (
|
||||
DefaultWeightsLoader,
|
||||
UnquantizedWeight,
|
||||
Weights,
|
||||
WeightsLoader,
|
||||
)
|
||||
|
||||
# compressed-tensors can match modules as quantization targets. However,
|
||||
# they need to be objects rather than classes or class names. Since we
|
||||
# need to match `Linear` targets, make an instance that can be re-used.
|
||||
_EMPTY_LINEAR: nn.Module = nn.Linear(0, 0)
|
||||
|
||||
|
||||
class CompressedTensorsLoader(WeightsLoader):
|
||||
"""Loader for checkpoints stored in the compressed-tensors format."""
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
quantization_config_raw = config.get("quantization_config")
|
||||
if quantization_config_raw is None:
|
||||
# `compression_config` was renamed to `quantization_config`; support
|
||||
# retained for backward compatibility.
|
||||
quantization_config_raw = config.get("compression_config")
|
||||
if quantization_config_raw is None:
|
||||
raise ValueError(
|
||||
"Checkpoint does not have compressed-tensors configuration"
|
||||
)
|
||||
|
||||
try:
|
||||
quantization_config = QuantizationConfig.model_validate(
|
||||
quantization_config_raw
|
||||
)
|
||||
except ValidationError as e:
|
||||
raise ValueError("Cannot parse compressed-tensors configuration") from e
|
||||
|
||||
if quantization_config.quantization_status not in (
|
||||
QuantizationStatus.COMPRESSED,
|
||||
QuantizationStatus.FROZEN,
|
||||
):
|
||||
raise ValueError(
|
||||
f"Model quantization was not finished, status was: {quantization_config.quantization_status}"
|
||||
)
|
||||
|
||||
self.ignore = (
|
||||
quantization_config.ignore if quantization_config.ignore is not None else []
|
||||
)
|
||||
self.loaders = self._get_target_loaders(quantization_config)
|
||||
|
||||
for target, loader in self.loaders.items():
|
||||
log_once(
|
||||
logger.info,
|
||||
f"Using {loader} for compressed-tensors target '{target}'",
|
||||
)
|
||||
|
||||
def get_weights(self, weights: Weights, prefix: str):
|
||||
loader = self._lookup_loader(prefix)
|
||||
return loader.get_weights(weights, prefix)
|
||||
|
||||
def get_weights_col_packed(
|
||||
self,
|
||||
weights: "Weights",
|
||||
prefix: str,
|
||||
block_sizes: Union[int, List[int]],
|
||||
):
|
||||
loader = self._lookup_loader(prefix)
|
||||
return loader.get_weights_col_packed(weights, prefix, block_sizes)
|
||||
|
||||
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
|
||||
loader = self._lookup_loader(prefixes[0])
|
||||
return loader.get_multi_weights_col(weights, prefixes, dim)
|
||||
|
||||
def get_weights_row(self, weights: Weights, prefix: str):
|
||||
loader = self._lookup_loader(prefix)
|
||||
return loader.get_weights_row(weights, prefix)
|
||||
|
||||
def _get_target_loaders(
|
||||
self, quantization_config: QuantizationConfig
|
||||
) -> Dict[str, WeightsLoader]:
|
||||
"""
|
||||
A compressed-tensors checkpoint can use different quantizations
|
||||
for different targets. This method returns a dictionary with a
|
||||
loader per target.
|
||||
"""
|
||||
|
||||
loaders: Dict[str, WeightsLoader] = {}
|
||||
|
||||
format = quantization_config.format
|
||||
|
||||
for group_name, group in quantization_config.config_groups.items():
|
||||
# The group configuration can be a string, but does that ever
|
||||
# happen in a serialized quantization config?
|
||||
assert isinstance(group, QuantizationScheme)
|
||||
|
||||
loader = self._create_loader_for_group(format, group_name, group)
|
||||
|
||||
# A quantized parameter group can have multiple targets, add the
|
||||
# loader for all the targets.
|
||||
for target in group.targets:
|
||||
if target in loaders:
|
||||
raise ValueError(
|
||||
f"Target '{target} has multiple configured loaders'"
|
||||
)
|
||||
loaders[target] = loader
|
||||
|
||||
return loaders
|
||||
|
||||
def _create_loader_for_group(
|
||||
self, format: str, group_name: str, group: QuantizationScheme
|
||||
) -> WeightsLoader:
|
||||
"""
|
||||
Find and create a loader for the group with the given quantization
|
||||
scheme.
|
||||
"""
|
||||
# NOTE: we ignore group.output_activations because we don't support
|
||||
# output quantization yet.
|
||||
|
||||
input_activations = group.input_activations
|
||||
weights = group.weights
|
||||
if (
|
||||
format
|
||||
in {
|
||||
CompressionFormat.float_quantized.value,
|
||||
CompressionFormat.naive_quantized.value,
|
||||
}
|
||||
and weights is not None
|
||||
and weights.type == QuantizationType.FLOAT
|
||||
and weights.num_bits == 8
|
||||
):
|
||||
# FP W8A8 or W8A16.
|
||||
return W8ANFpLoader(input_activations=input_activations, weights=weights)
|
||||
elif (
|
||||
format == CompressionFormat.pack_quantized.value
|
||||
and weights is not None
|
||||
and weights.type == QuantizationType.INT
|
||||
and weights.num_bits in (4, 8)
|
||||
):
|
||||
# INT W4A16 or W8A16 (GPTQ/AWQ-like).
|
||||
return WNA16IntLoader(weights)
|
||||
elif (
|
||||
format == CompressionFormat.marlin_24.value
|
||||
and weights is not None
|
||||
and weights.type == QuantizationType.INT
|
||||
and weights.num_bits in (4, 8)
|
||||
):
|
||||
return WNA16Int24Loader(weights)
|
||||
elif (
|
||||
format
|
||||
in {
|
||||
CompressionFormat.int_quantized.value,
|
||||
CompressionFormat.naive_quantized.value,
|
||||
}
|
||||
and weights is not None
|
||||
and weights.type == QuantizationType.INT
|
||||
and weights.num_bits == 8
|
||||
):
|
||||
return W8A8IntLoader(input_args=input_activations, weight_args=weights)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Group '{group_name}' has unsupported compressed-tensors configurtion"
|
||||
)
|
||||
|
||||
def _lookup_loader(self, prefix: str) -> WeightsLoader:
|
||||
"""
|
||||
Look up the loader to use for a given parameter name (prefix).
|
||||
"""
|
||||
|
||||
if len(find_name_or_class_matches(prefix, _EMPTY_LINEAR, self.ignore)) > 0:
|
||||
return DefaultWeightsLoader(UnquantizedWeight)
|
||||
|
||||
# We currently only handle linear layers, so unconditionally pass
|
||||
# a `Linear` instance.
|
||||
targets = find_name_or_class_matches(prefix, _EMPTY_LINEAR, self.loaders.keys())
|
||||
if len(targets) == 0:
|
||||
raise ValueError(
|
||||
f"Cannot find compressed-tensors target for prefix: {prefix}"
|
||||
)
|
||||
return self.loaders[targets[0]]
|
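# Editor's sketch (not part of the diff): a simplified stand-in for the routing
# done by `_lookup_loader` above. The real code delegates matching to
# `find_name_or_class_matches(prefix, _EMPTY_LINEAR, ...)`, which also matches on
# the module class ("Linear" matches any nn.Linear); here plain string checks
# stand in for that, and the loader names are hypothetical.
ignore = ["lm_head"]
loaders = {"Linear": "wna16-int-loader"}

def lookup(prefix: str) -> str:
    if any(prefix == i or prefix.endswith(f".{i}") for i in ignore):
        return "default-unquantized-loader"
    # Every remaining linear module falls under the "Linear" target.
    return loaders["Linear"]

assert lookup("model.layers.0.self_attn.q_proj") == "wna16-int-loader"
assert lookup("lm_head") == "default-unquantized-loader"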
@@ -1,239 +0,0 @@
|
||||
from typing import List, Optional, Union, TypeVar
|
||||
from dataclasses import dataclass
|
||||
|
||||
from loguru import logger
|
||||
import torch
|
||||
from compressed_tensors.quantization import QuantizationArgs, QuantizationType
|
||||
|
||||
from text_generation_server.layers.fp8 import _load_scalar_or_matrix_scale
|
||||
from text_generation_server.utils.log import log_once
|
||||
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
|
||||
|
||||
|
||||
quantization = None
|
||||
|
||||
|
||||
class W8A8IntLoader(WeightsLoader):
|
||||
"""
|
||||
Loader for w8a8 integer compressed-tensors parameters.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
input_args: Optional[QuantizationArgs],
|
||||
weight_args: QuantizationArgs,
|
||||
):
|
||||
if weight_args.type != QuantizationType.INT or weight_args.num_bits != 8:
|
||||
raise ValueError(
|
||||
f"{type(self).__name__} only supports w8a8 int checkpoints"
|
||||
)
|
||||
|
||||
if not weight_args.symmetric:
|
||||
raise ValueError("Checkpoints with asymmetric weights are not supported")
|
||||
|
||||
self.load_weight_scale = not weight_args.dynamic
|
||||
|
||||
if input_args is not None:
|
||||
self.input_symmetric = input_args.symmetric
|
||||
|
||||
if not input_args.dynamic:
|
||||
log_once(
|
||||
logger.warning,
|
||||
"Forcing dynamic input quantization for compressed_tensors w8a8 int checkpoint (for better accuracy).",
|
||||
)
|
||||
else:
|
||||
self.input_symmetric = True
|
||||
|
||||
def __str__(self) -> str:
|
||||
def scale_to_str(scale):
|
||||
return "static" if scale else "dynamic"
|
||||
|
||||
def symmetric_to_str(symmetric):
|
||||
return "symmetric" if symmetric else "asymmetric"
|
||||
|
||||
return f"{self.__class__.__name__} (w8a8 int, input: dynamic/{symmetric_to_str(self.input_symmetric)}, weight: {scale_to_str(self.load_weight_scale)}/symmetric))"
|
||||
|
||||
def get_weights(self, weights: "Weights", prefix: str):
|
||||
w = weights.get_tensor(f"{prefix}.weight", to_dtype=False)
|
||||
|
||||
weight_scale = None
|
||||
if self.load_weight_scale:
|
||||
weight_scale = weights.get_tensor(
|
||||
f"{prefix}.weight_scale", to_dtype=False
|
||||
).reshape(-1)
|
||||
|
||||
return Int8Weight(
|
||||
input_symmetric=self.input_symmetric,
|
||||
weight=w,
|
||||
weight_scale=weight_scale,
|
||||
)
|
||||
|
||||
def get_weights_col_packed(
|
||||
self,
|
||||
weights: Weights,
|
||||
prefix: str,
|
||||
block_sizes: Union[int, List[int]],
|
||||
):
|
||||
w = weights.get_packed_sharded(
|
||||
f"{prefix}.weight", dim=0, block_sizes=block_sizes, to_dtype=False
|
||||
)
|
||||
|
||||
weight_scale = None
|
||||
if self.load_weight_scale:
|
||||
weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
|
||||
if weight_scale.numel() > 1:
|
||||
weight_scale = weights.get_packed_sharded(
|
||||
f"{prefix}.weight_scale",
|
||||
dim=0,
|
||||
block_sizes=block_sizes,
|
||||
to_dtype=False,
|
||||
)
|
||||
weight_scale = weight_scale.reshape(-1)
|
||||
|
||||
return Int8Weight(
|
||||
input_symmetric=self.input_symmetric,
|
||||
weight=w,
|
||||
weight_scale=weight_scale,
|
||||
)
|
||||
|
||||
def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
|
||||
w = [
|
||||
weights.get_sharded(f"{p}.weight", dim=0, to_dtype=False) for p in prefixes
|
||||
]
|
||||
shapes = [x.shape for x in w]
|
||||
|
||||
w = torch.cat(w, dim=dim)
|
||||
|
||||
weight_scale = None
|
||||
if self.load_weight_scale:
|
||||
weight_scale = [
|
||||
_load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
|
||||
for p, shape in zip(prefixes, shapes)
|
||||
]
|
||||
weight_scale = torch.cat(weight_scale, dim=0).reshape(-1, 1)
|
||||
|
||||
return Int8Weight(
|
||||
input_symmetric=self.input_symmetric,
|
||||
weight=w,
|
||||
weight_scale=weight_scale,
|
||||
)
|
||||
|
||||
def get_weights_row(self, weights: "Weights", prefix: str):
|
||||
w = weights.get_sharded(f"{prefix}.weight", dim=1, to_dtype=False)
|
||||
|
||||
weight_scale = None
|
||||
if self.load_weight_scale:
|
||||
weight_scale = weights.get_tensor(
|
||||
f"{prefix}.weight_scale", to_dtype=False
|
||||
).reshape(-1)
|
||||
|
||||
return Int8Weight(
|
||||
input_symmetric=self.input_symmetric,
|
||||
weight=w,
|
||||
weight_scale=weight_scale,
|
||||
)
|
||||
|
||||
|
||||
OtherT = TypeVar("OtherT")
|
||||
|
||||
|
||||
def _get_tensor_or_else(
|
||||
weights: Weights, prefix: str, other: OtherT
|
||||
) -> Union[torch.Tensor, OtherT]:
|
||||
# Even if a checkpoint uses e.g. zero-points, they can be elided:
|
||||
# https://github.com/neuralmagic/compressed-tensors/blob/db6ccb25b265e8370813ecab5e95714a6728b5a6/src/compressed_tensors/compressors/quantized_compressors/base.py#L105
|
||||
if weights.has_tensor(prefix):
|
||||
return weights.get_tensor(prefix, to_dtype=False)
|
||||
else:
|
||||
return other
|
||||
|
||||
|
||||
@dataclass
|
||||
class Int8Weight(Weight):
|
||||
input_symmetric: bool
|
||||
weight: torch.Tensor
|
||||
weight_scale: Optional[torch.Tensor]
|
||||
|
||||
def get_linear(self, bias: torch.Tensor):
|
||||
if self.weight_scale is None:
|
||||
assert quantization is not None
|
||||
qweight, weight_scale, _ = quantization.scaled_int8_quant(self.weight)
|
||||
return W8A8IntLinear(
|
||||
bias=bias,
|
||||
input_symmetric=self.input_symmetric,
|
||||
weight=qweight,
|
||||
weight_scale=weight_scale,
|
||||
)
|
||||
else:
|
||||
return W8A8IntLinear(
|
||||
bias=bias,
|
||||
input_symmetric=self.input_symmetric,
|
||||
weight=self.weight,
|
||||
weight_scale=self.weight_scale,
|
||||
)
|
||||
|
||||
|
||||
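# Editor's sketch (not part of the diff): a rough per-tensor reference for the
# dynamic weight quantization that `Int8Weight.get_linear` falls back to when the
# checkpoint ships no `weight_scale`. The real `quantization.scaled_int8_quant`
# kernel may use per-row scales; this stand-in only illustrates the math.
import torch

def scaled_int8_quant_ref(w: torch.Tensor):
    scale = w.abs().amax() / 127.0
    q = torch.clamp(torch.round(w / scale), -128, 127).to(torch.int8)
    return q, scale

w = torch.randn(4, 8)
q, s = scaled_int8_quant_ref(w)
# Round-trip error is bounded by roughly scale / 2 per element.
assert (q.float() * s - w).abs().max() <= s / 2 + 1e-6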
class W8A8IntLinear(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
bias: Optional[torch.Tensor],
|
||||
input_symmetric: bool,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
weight_scale = weight_scale.to(torch.float32)
|
||||
|
||||
self.bias = bias
|
||||
self.input_symmetric = input_symmetric
|
||||
# cutlass kernels require transposed weights.
|
||||
self.weight = weight.t()
|
||||
self.weight_scale = weight_scale
|
||||
|
||||
if input_symmetric:
|
||||
self.zero_point_adj = None
|
||||
else:
|
||||
# https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md#scaledepilogueazp
|
||||
self.zero_point_adj = self.weight.sum(
|
||||
dim=0, keepdim=True, dtype=torch.int32
|
||||
)
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
assert quantization is not None
|
||||
|
||||
qinput, input_scale, input_zero_point = quantization.scaled_int8_quant(
|
||||
input=input,
|
||||
scale=None,
|
||||
azp=None,
|
||||
symmetric=self.input_symmetric,
|
||||
)
|
||||
|
||||
if self.input_symmetric:
|
||||
return quantization.cutlass_scaled_mm(
|
||||
a=qinput,
|
||||
b=self.weight,
|
||||
scale_a=input_scale,
|
||||
scale_b=self.weight_scale,
|
||||
out_dtype=input.dtype,
|
||||
bias=self.bias,
|
||||
)
|
||||
else:
|
||||
assert (
|
||||
self.zero_point_adj is not None
|
||||
and input_scale is not None
|
||||
and (self.input_symmetric or input_zero_point is not None)
|
||||
)
|
||||
|
||||
return quantization.cutlass_scaled_mm_azp(
|
||||
a=qinput,
|
||||
b=self.weight,
|
||||
scale_a=input_scale,
|
||||
scale_b=self.weight_scale,
|
||||
out_dtype=input.dtype,
|
||||
azp_adj=self.zero_point_adj,
|
||||
azp=input_zero_point,
|
||||
bias=self.bias,
|
||||
)
|
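# Editor's sketch (not part of the diff): an integer-only check of the zero-point
# adjustment used on the asymmetric path above. Subtracting the activation zero
# point before the matmul is equivalent to subtracting zp times the per-column
# weight sums afterwards, which is what `zero_point_adj` precomputes for
# `cutlass_scaled_mm_azp` (shapes and values here are made up).
import torch

torch.manual_seed(0)
w_q = torch.randint(-128, 128, (16, 8))   # int8-range weights, [in_features, out_features]
q_x = torch.randint(0, 256, (4, 16))      # uint8-range asymmetric activations
zp = 128                                  # activation zero point

reference = (q_x - zp) @ w_q              # dequantize-then-matmul form
azp_adj = w_q.sum(dim=0, keepdim=True)    # per-output-column adjustment
adjusted = q_x @ w_q - zp * azp_adj       # matmul-then-correct form

assert torch.equal(reference, adjusted)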
@@ -1,168 +0,0 @@
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
from compressed_tensors.quantization import QuantizationArgs, QuantizationType
|
||||
|
||||
from text_generation_server.layers.fp8 import (
|
||||
Fp8Weight,
|
||||
_load_scalar_or_matrix_scale,
|
||||
)
|
||||
from text_generation_server.utils.weights import Weights, WeightsLoader
|
||||
|
||||
|
||||
class W8ANFpLoader(WeightsLoader):
|
||||
"""
|
||||
Loader for W8A8/W8A16 FP compressed-tensors parameters.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
input_activations: Optional[QuantizationArgs],
|
||||
weights: QuantizationArgs,
|
||||
):
|
||||
assert weights.type == QuantizationType.FLOAT and weights.num_bits == 8
|
||||
|
||||
# We ignore the `strategy` option which sets the scales to be
|
||||
# per-tensor, per-channel or per-token. What scales are supported
|
||||
# is dependent on the kernels used (e.g. cutlass can do tokenwise,
|
||||
# Torch cannot, and FP8-Marlin does not quantize inputs at all).
|
||||
# So, instead we try to use the best-possible configuration.
|
||||
|
||||
self.load_weight_scale = not weights.dynamic
|
||||
self.load_input_scale = (
|
||||
input_activations is not None and not input_activations.dynamic
|
||||
)
|
||||
self.force_w8a16 = (
|
||||
input_activations is not None and input_activations.num_bits == 16
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
def scale_to_str(scale):
|
||||
return "static" if scale else "dynamic"
|
||||
|
||||
quantization_type = f"W8A{16 if self.force_w8a16 else 8}"
|
||||
|
||||
return f"{self.__class__.__name__} ({quantization_type}, weight: {scale_to_str(self.load_weight_scale)}, input: {scale_to_str(self.load_input_scale)})"
|
||||
|
||||
def get_weights(self, weights: "Weights", prefix: str):
|
||||
w = weights.get_tensor(f"{prefix}.weight")
|
||||
|
||||
weight_scale = None
|
||||
if self.load_weight_scale:
|
||||
weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
|
||||
|
||||
input_scale = None
|
||||
if self.load_input_scale:
|
||||
input_scale = weights.get_tensor(
|
||||
f"{prefix}.input_scale", to_dtype=False
|
||||
).reshape(-1)
|
||||
|
||||
return Fp8Weight(
|
||||
weight=w,
|
||||
weight_scale=weight_scale,
|
||||
input_scale=input_scale,
|
||||
dtype=weights.dtype,
|
||||
force_w8a16=self.force_w8a16,
|
||||
)
|
||||
|
||||
def get_weights_col_packed(
|
||||
self,
|
||||
weights: Weights,
|
||||
prefix: str,
|
||||
block_sizes: Union[int, List[int]],
|
||||
):
|
||||
w = weights.get_packed_sharded(
|
||||
f"{prefix}.weight", dim=0, block_sizes=block_sizes
|
||||
)
|
||||
|
||||
weight_scale = None
|
||||
if self.load_weight_scale:
|
||||
weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
|
||||
if weight_scale.numel() > 1:
|
||||
weight_scale = weights.get_packed_sharded(
|
||||
f"{prefix}.weight_scale",
|
||||
dim=0,
|
||||
block_sizes=block_sizes,
|
||||
to_dtype=False,
|
||||
)
|
||||
|
||||
input_scale = None
|
||||
if self.load_input_scale:
|
||||
input_scale = weights.get_tensor(f"{prefix}.input_scale", to_dtype=False)
|
||||
if input_scale.numel() > 1:
|
||||
input_scale = weights.get_packed_sharded(
|
||||
f"{prefix}.input_scale",
|
||||
dim=0,
|
||||
block_sizes=block_sizes,
|
||||
to_dtype=False,
|
||||
)
|
||||
input_scale = input_scale.reshape(-1).max()
|
||||
|
||||
return Fp8Weight(
|
||||
weight=w,
|
||||
weight_scale=weight_scale,
|
||||
input_scale=input_scale,
|
||||
dtype=weights.dtype,
|
||||
force_w8a16=self.force_w8a16,
|
||||
)
|
||||
|
||||
def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
|
||||
# FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet
|
||||
w = [
|
||||
weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes
|
||||
]
|
||||
shapes = [x.shape for x in w]
|
||||
|
||||
# Concat then send to the device
|
||||
w = torch.cat(w, dim=dim).to(weights.device)
|
||||
|
||||
weight_scale = None
|
||||
if self.load_weight_scale:
|
||||
weight_scale = [
|
||||
_load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
|
||||
for p, shape in zip(prefixes, shapes)
|
||||
]
|
||||
weight_scale = torch.cat(weight_scale, dim=0).reshape(-1)
|
||||
|
||||
input_scale = None
|
||||
if self.load_input_scale:
|
||||
input_scale = [
|
||||
_load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
|
||||
for p, shape in zip(prefixes, shapes)
|
||||
if weights.has_tensor(f"{p}.input_scale")
|
||||
]
|
||||
assert len(input_scale) == 0 or len(input_scale) == len(prefixes)
|
||||
input_scale = (
|
||||
torch.cat(input_scale, dim=0).reshape(-1).max()
|
||||
if len(input_scale) != 0
|
||||
else None
|
||||
)
|
||||
|
||||
return Fp8Weight(
|
||||
weight=w,
|
||||
weight_scale=weight_scale,
|
||||
input_scale=input_scale,
|
||||
dtype=weights.dtype,
|
||||
force_w8a16=self.force_w8a16,
|
||||
)
|
||||
|
||||
def get_weights_row(self, weights: "Weights", prefix: str):
|
||||
w = weights.get_sharded(f"{prefix}.weight", dim=1)
|
||||
weight_scale = None
|
||||
if self.load_weight_scale:
|
||||
weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
|
||||
|
||||
input_scale = None
|
||||
if self.load_input_scale:
|
||||
input_scale = weights.get_tensor(
|
||||
f"{prefix}.input_scale", to_dtype=False
|
||||
).reshape(-1)
|
||||
|
||||
return Fp8Weight(
|
||||
weight=w,
|
||||
weight_scale=weight_scale,
|
||||
input_scale=input_scale,
|
||||
dtype=weights.dtype,
|
||||
force_w8a16=self.force_w8a16,
|
||||
)
|
@@ -1,188 +0,0 @@
|
||||
from typing import List, Union
|
||||
|
||||
import torch
|
||||
from compressed_tensors.quantization import ActivationOrdering, QuantizationArgs
|
||||
from loguru import logger
|
||||
|
||||
from text_generation_server.layers.marlin.gptq import repack_gptq_for_marlin
|
||||
from text_generation_server.utils.log import log_once
|
||||
from text_generation_server.utils.weights import Weights, WeightsLoader
|
||||
|
||||
|
||||
class WNA16IntLoader(WeightsLoader):
|
||||
"""
|
||||
Loader for W4A16/W8A16 INT compressed-tensors parameters.
|
||||
"""
|
||||
|
||||
def __init__(self, weights: QuantizationArgs):
|
||||
self.weights = weights
|
||||
self.desc_act = self.weights.actorder == ActivationOrdering.GROUP
|
||||
self.groupsize = (
|
||||
-1 if self.weights.group_size is None else self.weights.group_size
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
quantization_type = f"W{self.weights.num_bits}A16"
|
||||
|
||||
return f"{self.__class__.__name__} ({quantization_type})"
|
||||
|
||||
def get_weights(self, weights: Weights, prefix: str):
|
||||
log_once(logger.info, "Using GPTQ-Marlin kernels")
|
||||
try:
|
||||
weight_packed = weights.get_tensor(f"{prefix}.weight_packed").t()
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
f"Cannot load w{self.weights.num_bits}a16 weight, make sure the model is already quantized"
|
||||
)
|
||||
|
||||
zero_point = None
|
||||
if not self.weights.symmetric:
|
||||
zero_point = weights.get_tensor(f"{prefix}.weight_zero_point").t()
|
||||
|
||||
g_idx = None
|
||||
if self.desc_act:
|
||||
g_idx = weights.get_tensor(f"{prefix}.weight_g_idx")
|
||||
|
||||
scales = weights.get_tensor(f"{prefix}.weight.scales").t()
|
||||
|
||||
return repack_gptq_for_marlin(
|
||||
qweight=weight_packed.contiguous(),
|
||||
scales=scales,
|
||||
qzeros=zero_point,
|
||||
g_idx=g_idx,
|
||||
bits=self.weights.num_bits,
|
||||
desc_act=self.desc_act,
|
||||
groupsize=self.groupsize,
|
||||
quant_method="compressed-tensors",
|
||||
sym=self.weights.symmetric,
|
||||
sharded_infeatures=False,
|
||||
)
|
||||
|
||||
def get_weights_col_packed(
|
||||
self,
|
||||
weights: Weights,
|
||||
prefix: str,
|
||||
block_sizes: Union[int, List[int]],
|
||||
):
|
||||
try:
|
||||
weight_packed = weights.get_packed_sharded(
|
||||
f"{prefix}.weight_packed", dim=0, block_sizes=block_sizes
|
||||
).t()
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
f"Cannot load w{self.weights.num_bits}a16 weight, make sure the model is already quantized"
|
||||
)
|
||||
scales = weights.get_packed_sharded(
|
||||
f"{prefix}.weight_scale", dim=0, block_sizes=block_sizes
|
||||
).t()
|
||||
scales = scales.to(dtype=weights.dtype)
|
||||
|
||||
zero_point = None
|
||||
if not self.weights.symmetric:
|
||||
zero_point = weights.get_packed_sharded(
|
||||
f"{prefix}.qzeros", dim=0, block_sizes=block_sizes
|
||||
).t()
|
||||
|
||||
g_idx = None
|
||||
if self.desc_act:
|
||||
g_idx = weights.get_tensor(f"{prefix}.g_idx")
|
||||
|
||||
return repack_gptq_for_marlin(
|
||||
qweight=weight_packed.contiguous(),
|
||||
scales=scales,
|
||||
qzeros=zero_point,
|
||||
g_idx=g_idx,
|
||||
bits=self.weights.num_bits,
|
||||
desc_act=self.desc_act,
|
||||
groupsize=self.groupsize,
|
||||
quant_method="compressed-tensors",
|
||||
sym=self.weights.symmetric,
|
||||
sharded_infeatures=False,
|
||||
)
|
||||
|
||||
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
|
||||
try:
|
||||
weight_packed = torch.cat(
|
||||
[
|
||||
weights.get_sharded(f"{p}.weight_packed", dim=0).t()
|
||||
for p in prefixes
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
f"Cannot load w{self.weights.num_bits}a16 weight, make sure the model is already quantized"
|
||||
)
|
||||
|
||||
scales = torch.cat(
|
||||
[weights.get_sharded(f"{p}.weight_scale", dim=0).t() for p in prefixes],
|
||||
dim=1,
|
||||
)
|
||||
|
||||
zero_point = None
|
||||
if not self.weights.symmetric:
|
||||
zero_point = torch.cat(
|
||||
[weights.get_sharded(f"{p}.qzeros", dim=0).t() for p in prefixes], dim=1
|
||||
).t()
|
||||
|
||||
g_idx = None
|
||||
if self.desc_act:
|
||||
w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
|
||||
for w2 in w[1:]:
|
||||
torch.testing.assert_close(w2, w[0])
|
||||
g_idx = w[0]
|
||||
|
||||
return repack_gptq_for_marlin(
|
||||
qweight=weight_packed.contiguous(),
|
||||
scales=scales,
|
||||
qzeros=zero_point,
|
||||
g_idx=g_idx,
|
||||
bits=self.weights.num_bits,
|
||||
desc_act=self.desc_act,
|
||||
groupsize=self.groupsize,
|
||||
quant_method="compressed-tensors",
|
||||
sym=self.weights.symmetric,
|
||||
sharded_infeatures=False,
|
||||
)
|
||||
|
||||
def get_weights_row(self, weights: Weights, prefix: str):
|
||||
log_once(logger.info, "Using GPTQ-Marlin kernels")
|
||||
try:
|
||||
weight_packed = weights.get_sharded(f"{prefix}.weight_packed", dim=1).t()
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
f"Cannot load `{self.quantize}` weight, make sure the model is already quantized."
|
||||
)
|
||||
|
||||
zero_point = None
|
||||
if not self.weights.symmetric:
|
||||
if self.desc_act or self.groupsize == -1:
|
||||
zero_point = weights.get_tensor(f"{prefix}.weight_zero_point").t()
|
||||
else:
|
||||
zero_point = weights.get_sharded(
|
||||
f"{prefix}.weight_zero_point", dim=1
|
||||
).t()
|
||||
|
||||
g_idx = None
|
||||
if self.desc_act:
|
||||
g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0)
|
||||
|
||||
if self.desc_act or self.groupsize == -1:
|
||||
scales = weights.get_tensor(f"{prefix}.weight_scale").t()
|
||||
else:
|
||||
scales = weights.get_sharded(f"{prefix}.weight_scale", dim=1).t()
|
||||
|
||||
sharded_in_features = weights.process_group.size() > 1
|
||||
|
||||
return repack_gptq_for_marlin(
|
||||
qweight=weight_packed.contiguous(),
|
||||
scales=scales,
|
||||
qzeros=zero_point,
|
||||
g_idx=g_idx,
|
||||
bits=self.weights.num_bits,
|
||||
desc_act=self.desc_act,
|
||||
groupsize=self.groupsize,
|
||||
quant_method="compressed-tensors",
|
||||
sym=self.weights.symmetric,
|
||||
sharded_infeatures=sharded_in_features,
|
||||
)
|
@@ -1,101 +0,0 @@
|
||||
from typing import List, Union
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
from compressed_tensors.quantization import QuantizationArgs, QuantizationType
|
||||
from text_generation_server.layers.marlin.marlin import GPTQMarlin24Weight
|
||||
from text_generation_server.utils.weights import Weights, WeightsLoader
|
||||
|
||||
|
||||
class WNA16Int24Loader(WeightsLoader):
|
||||
"""
|
||||
Loader for W4A16/W8A16 INT 2:4 sparsity compressed-tensors checkpoints.
|
||||
"""
|
||||
|
||||
def __init__(self, weight_args: QuantizationArgs):
|
||||
super().__init__()
|
||||
|
||||
if weight_args.type != QuantizationType.INT:
|
||||
raise ValueError(
|
||||
f"{type(self).__name__} only supports wNa8 int checkpoints"
|
||||
)
|
||||
|
||||
if weight_args.strategy == "group" and weight_args.group_size is None:
|
||||
raise ValueError("`group_size` must be set when `actorder` is `group`")
|
||||
|
||||
self.bits = weight_args.num_bits
|
||||
self.group_size = weight_args.group_size
|
||||
|
||||
def __str__(self) -> str:
|
||||
quantization_type = f"W{self.bits}A16 2:4 sparsity"
|
||||
|
||||
return f"{self.__class__.__name__} ({quantization_type})"
|
||||
|
||||
def get_weights(self, weights: Weights, prefix: str):
|
||||
"""
|
||||
Get weights at the given prefix and apply without tensor parallelism.
|
||||
"""
|
||||
weight_packed = weights.get_tensor(f"{prefix}.weight_packed")
|
||||
meta = weights.get_tensor(f"{prefix}.meta")
|
||||
scale_packed = weights.get_tensor(f"{prefix}.scale_packed")
|
||||
return GPTQMarlin24Weight(
|
||||
weight_packed=weight_packed,
|
||||
meta=meta,
|
||||
scale_packed=scale_packed,
|
||||
bits=self.bits,
|
||||
)
|
||||
|
||||
def get_weights_col_packed(
|
||||
self,
|
||||
weights: Weights,
|
||||
prefix: str,
|
||||
block_sizes: Union[int, List[int]],
|
||||
):
|
||||
weight_packed = weights.get_packed_sharded(
|
||||
f"{prefix}.weight_packed", dim=1, block_sizes=block_sizes
|
||||
)
|
||||
meta = weights.get_packed_sharded(
|
||||
f"{prefix}.meta", dim=1, block_sizes=block_sizes
|
||||
)
|
||||
scale_packed = weights.get_packed_sharded(
|
||||
f"{prefix}.scale_packed", dim=1, block_sizes=block_sizes
|
||||
)
|
||||
return GPTQMarlin24Weight(
|
||||
weight_packed=weight_packed,
|
||||
meta=meta,
|
||||
scale_packed=scale_packed,
|
||||
bits=self.bits,
|
||||
)
|
||||
|
||||
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
|
||||
weight_packed = torch.cat(
|
||||
[weights.get_sharded(f"{p}.weight_packed", dim=1) for p in prefixes], dim=1
|
||||
)
|
||||
meta = torch.cat(
|
||||
[weights.get_sharded(f"{p}.meta", dim=1) for p in prefixes], dim=1
|
||||
)
|
||||
scale_packed = torch.cat(
|
||||
[weights.get_sharded(f"{p}.scale_packed", dim=1) for p in prefixes], dim=1
|
||||
)
|
||||
return GPTQMarlin24Weight(
|
||||
weight_packed=weight_packed,
|
||||
meta=meta,
|
||||
scale_packed=scale_packed,
|
||||
bits=self.bits,
|
||||
)
|
||||
|
||||
def get_weights_row(self, weights: Weights, prefix: str):
|
||||
weight_packed = weights.get_sharded(f"{prefix}.weight_packed", dim=0)
|
||||
meta = weights.get_sharded(f"{prefix}.meta", dim=0)
|
||||
if self.group_size is None:
|
||||
scale_packed = weights.get_tensor(f"{prefix}.scale_packed")
|
||||
else:
|
||||
scale_packed = weights.get_sharded(f"{prefix}.scale_packed", dim=0)
|
||||
|
||||
return GPTQMarlin24Weight(
|
||||
weight_packed=weight_packed,
|
||||
meta=meta,
|
||||
scale_packed=scale_packed,
|
||||
bits=self.bits,
|
||||
)
|
@@ -7,7 +7,7 @@ from text_generation_server.utils.log import log_once
|
||||
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
|
||||
|
||||
|
||||
QuantLinear = None
|
||||
from .hpu import QuantLinear
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -215,14 +215,7 @@ class GPTQWeightsLoader(WeightsLoader):
|
||||
[weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
|
||||
)
|
||||
|
||||
from text_generation_server.layers.gptq import HAS_EXLLAMA
|
||||
|
||||
use_exllama = (
|
||||
self.bits == 4
|
||||
and HAS_EXLLAMA
|
||||
and self.quantize == "gptq"
|
||||
and not self.desc_act
|
||||
)
|
||||
use_exllama = self.bits == 4 and self.quantize == "gptq" and not self.desc_act
|
||||
|
||||
if self.quantize == "gptq" and self.quant_method == "gptq":
|
||||
w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
|
||||
@@ -362,6 +355,3 @@ class GPTQWeightsLoader(WeightsLoader):
|
||||
else False
|
||||
)
|
||||
self.quant_method = "gptq"
|
||||
|
||||
|
||||
HAS_EXLLAMA = False
|
||||
|
@@ -1,125 +1,181 @@
|
||||
import math
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import intel_extension_for_pytorch as ipex
|
||||
|
||||
|
||||
class QuantLinear(nn.Module):
|
||||
def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
|
||||
super().__init__()
|
||||
self.register_buffer("qweight", qweight)
|
||||
self.register_buffer("qzeros", qzeros)
|
||||
self.register_buffer("scales", scales)
|
||||
self.register_buffer("g_idx", g_idx)
|
||||
if bias is not None:
|
||||
self.register_buffer("bias", bias)
|
||||
else:
|
||||
self.bias = None
|
||||
if bits not in [4]:
|
||||
raise NotImplementedError("Only 4 bits are supported.")
|
||||
self.bits = bits
|
||||
self.maxq = 2**self.bits - 1
|
||||
self.groupsize = groupsize
|
||||
|
||||
self.outfeatures = qweight.shape[1]
|
||||
self.infeatures = qweight.shape[0] * 32 // bits
|
||||
self.woq_linear = (
|
||||
ipex.llm.quantization.IPEXWeightOnlyQuantizedLinear.from_weight(
|
||||
self.qweight,
|
||||
self.scales,
|
||||
self.qzeros,
|
||||
self.infeatures,
|
||||
self.outfeatures,
|
||||
bias=self.bias,
|
||||
group_size=self.groupsize,
|
||||
g_idx=g_idx,
|
||||
quant_method=ipex.llm.quantization.QuantMethod.GPTQ_GEMM,
|
||||
dtype=ipex.llm.quantization.QuantDtype.INT4,
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def new(cls, bits, groupsize, infeatures, outfeatures, bias):
|
||||
if bits not in [4]:
|
||||
raise NotImplementedError("Only 4 bits are supported.")
|
||||
|
||||
qweight = torch.zeros((infeatures // 32 * bits, outfeatures), dtype=torch.int32)
|
||||
qzeros = torch.zeros(
|
||||
(math.ceil(infeatures / groupsize), outfeatures // 32 * bits),
|
||||
dtype=torch.int32,
|
||||
)
|
||||
scales = torch.zeros(
|
||||
(math.ceil(infeatures / groupsize), outfeatures), dtype=torch.float16
|
||||
)
|
||||
g_idx = torch.tensor(
|
||||
[i // groupsize for i in range(infeatures)], dtype=torch.int32
|
||||
)
|
||||
if bias:
|
||||
bias = torch.zeros((outfeatures), dtype=torch.float16)
|
||||
else:
|
||||
bias = None
|
||||
return cls(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
|
||||
|
||||
def pack(self, linear, scales, zeros, g_idx=None):
|
||||
self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx
|
||||
|
||||
scales = scales.t().contiguous()
|
||||
zeros = zeros.t().contiguous()
|
||||
scale_zeros = zeros * scales
|
||||
self.scales = scales.clone().half()
|
||||
if linear.bias is not None:
|
||||
self.bias = linear.bias.clone().half()
|
||||
|
||||
intweight = []
|
||||
for idx in range(self.infeatures):
|
||||
intweight.append(
|
||||
torch.round(
|
||||
(linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]])
|
||||
/ self.scales[self.g_idx[idx]]
|
||||
).to(torch.int)[:, None]
|
||||
)
|
||||
intweight = torch.cat(intweight, dim=1)
|
||||
intweight = intweight.t().contiguous()
|
||||
intweight = intweight.numpy().astype(np.uint32)
|
||||
qweight = np.zeros(
|
||||
(intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32
|
||||
)
|
||||
i = 0
|
||||
row = 0
|
||||
while row < qweight.shape[0]:
|
||||
if self.bits in [4]:
|
||||
for j in range(i, i + (32 // self.bits)):
|
||||
qweight[row] |= intweight[j] << (self.bits * (j - i))
|
||||
i += 32 // self.bits
|
||||
row += 1
|
||||
else:
|
||||
raise NotImplementedError("Only 4 bits are supported.")
|
||||
|
||||
qweight = qweight.astype(np.int32)
|
||||
self.qweight = torch.from_numpy(qweight)
|
||||
|
||||
zeros -= 1
|
||||
zeros = zeros.numpy().astype(np.uint32)
|
||||
qzeros = np.zeros(
|
||||
(zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32
|
||||
)
|
||||
i = 0
|
||||
col = 0
|
||||
while col < qzeros.shape[1]:
|
||||
if self.bits in [4]:
|
||||
for j in range(i, i + (32 // self.bits)):
|
||||
qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
|
||||
i += 32 // self.bits
|
||||
col += 1
|
||||
else:
|
||||
raise NotImplementedError("Only 4 bits are supported.")
|
||||
|
||||
qzeros = qzeros.astype(np.int32)
|
||||
self.qzeros = torch.from_numpy(qzeros)
|
||||
|
||||
def forward(self, x):
|
||||
out_shape = x.shape[:-1] + (self.outfeatures,)
|
||||
out = self.woq_linear(x.reshape(-1, x.shape[-1]))
|
||||
return out.reshape(out_shape)
|
||||
import math
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
try:
|
||||
|
||||
convert_from_uint4 = torch.ops.hpu.convert_from_uint4
|
||||
except Exception as e:
|
||||
hpu_import_exception = e
|
||||
|
||||
def error_raiser_hpu(*args, **kwargs):
|
||||
raise ValueError(
|
||||
f"Trying to use HPU, but could not import the HPU framework with the following error: {hpu_import_exception}"
|
||||
)
|
||||
|
||||
convert_from_uint4 = error_raiser_hpu
|
||||
|
||||
|
||||
def pack_tensor(input, bits=4):
|
||||
normal = input.to(torch.int32)
|
||||
q = torch.zeros((normal.shape[0], normal.shape[1] // 32 * bits), dtype=torch.int32)
|
||||
i = 0
|
||||
col = 0
|
||||
while col < q.shape[1]:
|
||||
for j in range(i, i + (32 // bits)):
|
||||
q[:, col] |= normal[:, j] << (bits * (j - i))
|
||||
i += 32 // bits
|
||||
col += 1
|
||||
q = q.to(torch.int32)
|
||||
return q
|
||||
|
||||
|
||||
class QuantLinear(nn.Module):
|
||||
def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
|
||||
super().__init__()
|
||||
self.register_buffer("qweight", qweight)
|
||||
self.register_buffer("qzeros", qzeros)
|
||||
self.register_buffer("scales", scales)
|
||||
self.register_buffer("g_idx", g_idx)
|
||||
if bias is not None:
|
||||
self.register_buffer("bias", bias)
|
||||
else:
|
||||
self.bias = None
|
||||
if bits not in [4]:
|
||||
raise NotImplementedError("Only 4 bits are supported.")
|
||||
self.bits = bits
|
||||
self.maxq = 2**self.bits - 1
|
||||
self.groupsize = groupsize
# Bit-shift offsets used by the unpack_* helpers below (0, 4, ..., 28 for 4-bit).
self.wf = torch.tensor(list(range(0, 32, self.bits)), dtype=torch.int32).unsqueeze(0)
|
||||
|
||||
self.outfeatures = qweight.shape[1]
|
||||
self.infeatures = qweight.shape[0] * 32 // bits
|
||||
self._preprocessing()
|
||||
|
||||
def unpack_zeros_from_cuda_old_format(self):
|
||||
zeros = torch.bitwise_right_shift(
|
||||
torch.unsqueeze(self.qzeros, 2).expand(-1, -1, 32 // self.bits),
|
||||
self.wf.unsqueeze(0),
|
||||
).to(torch.int16 if self.bits == 8 else torch.int8)
|
||||
|
||||
zeros = zeros + 1
|
||||
zeros = torch.bitwise_and(zeros, (2**self.bits) - 1).to(
|
||||
self.scales.dtype
|
||||
) # NOTE: It appears that casting here after the `zeros = zeros + 1` is important.
|
||||
zeros = zeros.reshape(-1, zeros.shape[1] * zeros.shape[2])
|
||||
return zeros
|
||||
|
||||
def unpack_weight_from_cuda_old_format(self):
|
||||
weight = torch.bitwise_right_shift(
|
||||
torch.unsqueeze(self.qweight, 1).expand(-1, 32 // self.bits, -1),
|
||||
self.wf.unsqueeze(-1),
|
||||
).to(torch.int16 if self.bits == 8 else torch.int8)
|
||||
weight = torch.bitwise_and(weight, (2**self.bits) - 1)
|
||||
weight = weight.reshape((weight.shape[0] * weight.shape[1], weight.shape[2]))
|
||||
return weight
|
||||
|
||||
def _preprocessing(self):
|
||||
self.qweight = self.qweight.cpu()
|
||||
weight = self.unpack_weight_from_cuda_old_format()
|
||||
new_qweight = pack_tensor(weight)
|
||||
self.qweight = new_qweight.to("hpu")
|
||||
|
||||
# TODO: Support group indexing and remove the check
|
||||
columns = self.qweight.shape[0]
|
||||
g_idx_trivial = [i // self.group_size for i in range(columns)]
|
||||
g_idx_trivial = torch.tensor(g_idx_trivial, dtype=torch.int32)
|
||||
assert torch.equal(
|
||||
self.g_idx, g_idx_trivial
|
||||
), "Non-trivial tensor g_idx is not supported"
|
||||
|
||||
zeros = self.unpack_zeros_from_cuda_old_format().cpu()
|
||||
new_qzeros = pack_tensor(zeros)
|
||||
self.qzeros = new_qzeros.to("hpu")
|
||||
|
||||
@classmethod
|
||||
def new(cls, bits, groupsize, infeatures, outfeatures, bias):
|
||||
if bits not in [4]:
|
||||
raise NotImplementedError("Only 4 bits are supported.")
|
||||
|
||||
qweight = torch.zeros((infeatures // 32 * bits, outfeatures), dtype=torch.int32)
|
||||
qzeros = torch.zeros(
|
||||
(math.ceil(infeatures / groupsize), outfeatures // 32 * bits),
|
||||
dtype=torch.int32,
|
||||
)
|
||||
scales = torch.zeros(
|
||||
(math.ceil(infeatures / groupsize), outfeatures), dtype=torch.float16
|
||||
)
|
||||
g_idx = torch.tensor(
|
||||
[i // groupsize for i in range(infeatures)], dtype=torch.int32
|
||||
)
|
||||
if bias:
|
||||
bias = torch.zeros((outfeatures), dtype=torch.float16)
|
||||
else:
|
||||
bias = None
|
||||
return cls(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
|
||||
|
||||
def pack(self, linear, scales, zeros, g_idx=None):
|
||||
self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx
|
||||
|
||||
scales = scales.t().contiguous()
|
||||
zeros = zeros.t().contiguous()
|
||||
scale_zeros = zeros * scales
|
||||
self.scales = scales.clone().half()
|
||||
if linear.bias is not None:
|
||||
self.bias = linear.bias.clone().half()
|
||||
|
||||
intweight = []
|
||||
for idx in range(self.infeatures):
|
||||
intweight.append(
|
||||
torch.round(
|
||||
(linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]])
|
||||
/ self.scales[self.g_idx[idx]]
|
||||
).to(torch.int)[:, None]
|
||||
)
|
||||
intweight = torch.cat(intweight, dim=1)
|
||||
intweight = intweight.t().contiguous()
|
||||
intweight = intweight.numpy().astype(np.uint32)
|
||||
qweight = np.zeros(
|
||||
(intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32
|
||||
)
|
||||
i = 0
|
||||
row = 0
|
||||
while row < qweight.shape[0]:
|
||||
if self.bits in [4]:
|
||||
for j in range(i, i + (32 // self.bits)):
|
||||
qweight[row] |= intweight[j] << (self.bits * (j - i))
|
||||
i += 32 // self.bits
|
||||
row += 1
|
||||
else:
|
||||
raise NotImplementedError("Only 4 bits are supported.")
|
||||
|
||||
qweight = qweight.astype(np.int32)
|
||||
self.qweight = torch.from_numpy(qweight)
|
||||
|
||||
zeros -= 1
|
||||
zeros = zeros.numpy().astype(np.uint32)
|
||||
qzeros = np.zeros(
|
||||
(zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32
|
||||
)
|
||||
i = 0
|
||||
col = 0
|
||||
while col < qzeros.shape[1]:
|
||||
if self.bits in [4]:
|
||||
for j in range(i, i + (32 // self.bits)):
|
||||
qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
|
||||
i += 32 // self.bits
|
||||
col += 1
|
||||
else:
|
||||
raise NotImplementedError("Only 4 bits are supported.")
|
||||
|
||||
qzeros = qzeros.astype(np.int32)
|
||||
self.qzeros = torch.from_numpy(qzeros)
|
||||
|
||||
def forward(self, x):
|
||||
out_shape = x.shape[:-1] + (self.outfeatures,)
|
||||
x = x.reshape(-1, x.shape[-1])
|
||||
weight = convert_from_uint4(self.qweight, self.scales, self.qzeros, x.dtype)
|
||||
out = torch.matmul(x, weight)
|
||||
out = out.reshape(out_shape)
|
||||
out = out + self.bias if self.bias is not None else out
|
||||
return out
|
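# Editor's sketch (not part of the diff): what the "trivial g_idx" assertion in
# `_preprocessing` above requires — every input channel i must belong to group
# i // groupsize, i.e. no act-order (`desc_act`) permutation of channels.
import torch

groupsize, infeatures = 128, 512
g_idx = torch.tensor([i // groupsize for i in range(infeatures)], dtype=torch.int32)
g_idx_trivial = torch.arange(infeatures, dtype=torch.int32) // groupsize
assert torch.equal(g_idx, g_idx_trivial)
# A desc_act-reordered g_idx (channels shuffled across groups) would fail here.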
@@ -1,15 +0,0 @@
|
||||
from text_generation_server.layers.marlin.fp8 import GPTQMarlinFP8Linear
|
||||
from text_generation_server.layers.marlin.gptq import (
|
||||
GPTQMarlinWeightsLoader,
|
||||
can_use_gptq_marlin,
|
||||
repack_gptq_for_marlin,
|
||||
)
|
||||
from text_generation_server.layers.marlin.marlin import MarlinWeightsLoader
|
||||
|
||||
__all__ = [
|
||||
"GPTQMarlinFP8Linear",
|
||||
"GPTQMarlinWeightsLoader",
|
||||
"MarlinWeightsLoader",
|
||||
"can_use_gptq_marlin",
|
||||
"repack_gptq_for_marlin",
|
||||
]
|
@@ -1,141 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from text_generation_server.layers.fp8 import fp8_quantize
|
||||
from text_generation_server.layers.marlin.gptq import _check_valid_shape
|
||||
from text_generation_server.layers.marlin.util import (
|
||||
_check_marlin_kernels,
|
||||
permute_scales,
|
||||
)
|
||||
|
||||
|
||||
quantization = None
|
||||
|
||||
|
||||
MARLIN_TILE_SIZE = 16
|
||||
|
||||
|
||||
class GPTQMarlinFP8Linear(nn.Module):
|
||||
"""
|
||||
FP8 GPTQ-Marlin linear layer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
qweight: torch.Tensor,
|
||||
scales: torch.Tensor,
|
||||
bias: Optional[torch.Tensor],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
_check_marlin_kernels()
|
||||
assert quantization is not None
|
||||
|
||||
scales = scales.unsqueeze(0)
|
||||
if scales.shape[1] == 1:
|
||||
out_features, in_features = qweight.shape
|
||||
scales = scales.repeat(1, out_features)
|
||||
qweight, scales = repack_fp8_for_marlin(qweight, scales)
|
||||
|
||||
in_features = qweight.shape[0] * MARLIN_TILE_SIZE
|
||||
out_features = scales.shape[1]
|
||||
_check_valid_shape(in_features=in_features, out_features=out_features)
|
||||
|
||||
self.qweight = qweight
|
||||
self.scales = scales
|
||||
self.bias = bias if bias is not None else None
|
||||
|
||||
self.workspace = torch.zeros(
|
||||
out_features // 64 * 16, dtype=torch.int, device=qweight.device
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_unquant(cls, weight, bias, dtype):
|
||||
qweight, scales = fp8_quantize(weight)
|
||||
return cls(qweight=qweight, scales=scales.to(dtype), bias=bias)
|
||||
|
||||
@classmethod
|
||||
def from_fp8(
|
||||
cls,
|
||||
weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
dtype: torch.dtype,
|
||||
**kwargs,
|
||||
):
|
||||
return cls(qweight=weight, scales=scale.to(dtype), bias=bias)
|
||||
|
||||
def forward(self, A: torch.Tensor) -> torch.Tensor:
|
||||
assert quantization is not None
|
||||
|
||||
A_flat = A.view(-1, A.shape[-1])
|
||||
C = quantization.fp8_marlin_gemm(
|
||||
A_flat,
|
||||
self.qweight,
|
||||
self.scales,
|
||||
self.workspace,
|
||||
8,
|
||||
A_flat.shape[0],
|
||||
self.scales.shape[1],
|
||||
A_flat.shape[1],
|
||||
)
|
||||
C = C.reshape(A.shape[:-1] + (self.scales.shape[1],))
|
||||
|
||||
if self.bias is not None:
|
||||
C += self.bias
|
||||
|
||||
return C
|
||||
|
||||
|
||||
def pack_fp8_as_int32(fp8_tensor: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Repack FP8 weights to gptq format (packed int32 elements).
|
||||
"""
|
||||
assert fp8_tensor.dtype == torch.float8_e4m3fn
|
||||
|
||||
if fp8_tensor.shape[0] % 4 != 0:
|
||||
raise ValueError(
|
||||
f"Leading tensor dimension is not divisable by 4: {fp8_tensor.shape[0]}"
|
||||
)
|
||||
|
||||
# Reshape to prepare for packing
|
||||
reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:])
|
||||
|
||||
# Convert fp8 to uint8 (byte) representation
|
||||
byte_tensor = reshaped.view(torch.uint8)
|
||||
|
||||
# Pack 4 uint8 values into one int32
|
||||
packed = torch.zeros(
|
||||
fp8_tensor.shape[0] // 4,
|
||||
fp8_tensor.shape[1],
|
||||
dtype=torch.int32,
|
||||
device=fp8_tensor.device,
|
||||
)
|
||||
|
||||
for i in range(4):
|
||||
packed.bitwise_or_(byte_tensor[:, i].to(torch.int32) << i * 8)
|
||||
|
||||
return packed
|
||||
|
||||
|
||||
def repack_fp8_for_marlin(weight: torch.Tensor, scales: torch.Tensor):
|
||||
"""
|
||||
Repack FP8 tensor for GPTQ-Marlin.
|
||||
"""
|
||||
|
||||
out_features, in_features = weight.shape
|
||||
|
||||
# Torch linear layers weights with shape [out_features, in_features],
|
||||
# GPTQ-quantized weights use [in_feateres/pack_factor, in_features],
|
||||
# so transpose before packing.
|
||||
qweight = pack_fp8_as_int32(weight.t())
|
||||
|
||||
perm = torch.empty(0, dtype=torch.int, device=qweight.device)
|
||||
repacked = quantization.gptq_marlin_repack(
|
||||
qweight, perm, in_features, out_features, 8
|
||||
)
|
||||
|
||||
scales = permute_scales(scales)
|
||||
|
||||
return repacked, scales
|
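Not part of the diff: a small self-contained sketch of the shift-and-or packing scheme used by pack_fp8_as_int32 above, run on uint8 input so it works without float8 support. The name pack_bytes_as_int32 is hypothetical.

import torch

# Pack each group of 4 consecutive bytes along dim 0 into one int32,
# little-endian within the word, mirroring the loop in pack_fp8_as_int32.
def pack_bytes_as_int32(byte_tensor: torch.Tensor) -> torch.Tensor:
    assert byte_tensor.dtype == torch.uint8
    assert byte_tensor.shape[0] % 4 == 0
    reshaped = byte_tensor.reshape(-1, 4, *byte_tensor.shape[1:])
    packed = torch.zeros(
        byte_tensor.shape[0] // 4, *byte_tensor.shape[1:], dtype=torch.int32
    )
    for i in range(4):
        packed |= reshaped[:, i].to(torch.int32) << (i * 8)
    return packed

x = torch.arange(8, dtype=torch.uint8).reshape(8, 1)  # small values, no int32 overflow
print(pack_bytes_as_int32(x))  # [[50462976], [117835012]] == [0x03020100, 0x07060504]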
@ -1,465 +0,0 @@
from dataclasses import dataclass
from typing import List, Optional, Union

import numpy
import torch
import torch.nn as nn
from loguru import logger
from text_generation_server.layers.marlin.util import (
_check_marlin_kernels,
marlin_zero_points,
permute_scales,
unpack_cols,
)
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader

quantization = None

try:
major, _minor = torch.cuda.get_device_capability()
has_sm_8_0 = major >= 8
except Exception:
has_sm_8_0 = False

GPTQ_MARLIN_BITS = [4, 8]
GPTQ_MARLIN_GROUP_SIZES = [-1, 32, 64, 128]
MARLIN_TILE_SIZE = 16

def can_use_gptq_marlin(
*, bits: int, groupsize: int, quant_method: str, quantize: str, sym: bool
) -> bool:
return False

class GPTQMarlinWeightsLoader(WeightsLoader):
"""
Loader for using GPTQ- and AWQ-quantized weights with Marlin kernels.
"""

def __init__(
self,
*,
bits: int,
desc_act: bool,
groupsize: int,
quant_method: str,
quantize: str,
sym: bool,
):
self.bits = bits
self.desc_act = desc_act
self.groupsize = groupsize
self.quant_method = quant_method
self.quantize = quantize
self.sym = sym

def get_weights(self, weights: Weights, prefix: str):
log_once(logger.info, "Using GPTQ-Marlin kernels")
try:
qweight = weights.get_tensor(f"{prefix}.qweight")
except RuntimeError:
raise RuntimeError(
f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
)

if not self.sym:
qzeros = weights.get_tensor(f"{prefix}.qzeros")
else:
qzeros = None

if self.quant_method == "awq":
g_idx = None
else:
g_idx = weights.get_tensor(f"{prefix}.g_idx")
scales = weights.get_tensor(f"{prefix}.scales")

return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
qzeros=qzeros,
g_idx=g_idx,
bits=self.bits,
desc_act=self.desc_act,
groupsize=self.groupsize,
quant_method=self.quant_method,
sym=self.sym,
sharded_infeatures=False,
)

def get_weights_col_packed(
self,
weights: Weights,
prefix: str,
block_sizes: Union[int, List[int]],
):
try:
qweight = weights.get_packed_sharded(
f"{prefix}.qweight", dim=1, block_sizes=block_sizes
)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{self.quantize}` weight, make sure the model is already quantized."
)
scales = weights.get_packed_sharded(
f"{prefix}.scales", dim=1, block_sizes=block_sizes
)
scales = scales.to(dtype=weights.dtype)

if not self.sym:
qzeros = weights.get_packed_sharded(
f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
)
else:
qzeros = None

if self.quant_method == "awq":
g_idx = None
else:
g_idx = weights.get_tensor(f"{prefix}.g_idx")
return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
qzeros=qzeros,
g_idx=g_idx,
bits=self.bits,
desc_act=self.desc_act,
groupsize=self.groupsize,
quant_method=self.quant_method,
sym=self.sym,
sharded_infeatures=False,
)

def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
try:
qweight = torch.cat(
[weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{self.quantize}` weight, make sure the model is already quantized"
)

scales = torch.cat(
[weights.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
)

if not self.sym:
qzeros = torch.cat(
[weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
)
else:
qzeros = None

if self.quant_method == "awq":
g_idx = None
else:
w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
for w2 in w[1:]:
torch.testing.assert_close(w2, w[0])
g_idx = w[0]

return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
qzeros=qzeros,
g_idx=g_idx,
bits=self.bits,
desc_act=self.desc_act,
groupsize=self.groupsize,
quant_method=self.quant_method,
sym=self.sym,
sharded_infeatures=False,
)

def get_weights_row(self, weights: Weights, prefix: str):
log_once(logger.info, "Using GPTQ-Marlin kernels")
try:
qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
)

if not self.sym:
if self.desc_act or self.groupsize == -1:
qzeros = weights.get_tensor(f"{prefix}.qzeros")
else:
qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0)
else:
qzeros = None

if self.quant_method == "awq":
g_idx = None
else:
g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0)

if self.desc_act or self.groupsize == -1:
scales = weights.get_tensor(f"{prefix}.scales")
else:
scales = weights.get_sharded(f"{prefix}.scales", dim=0)

sharded_in_features = weights.process_group.size() > 1

return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
qzeros=qzeros,
g_idx=g_idx,
bits=self.bits,
desc_act=self.desc_act,
groupsize=self.groupsize,
quant_method=self.quant_method,
sym=self.sym,
sharded_infeatures=sharded_in_features,
)

def _get_gptq_params(self, weights: Weights):
if weights.has_tensor("gptq_bits") and weights.has_tensor("gptq_groupsize"):
self.bits = weights.get_tensor("gptq_bits").item()
self.groupsize = weights.get_tensor("gptq_groupsize").item()
self.desc_act = False
# `server quantize` used asymmetric quantization unconditionally
# before the `gptq_sym` setting tensor was added.
self.sym = (
weights.get_tensor("gptq_sym").item()
if weights.has_tensor("gptq_sym")
else False
)
self.quant_method = "gptq"

@dataclass
class GPTQMarlinWeight(Weight):
"""
Repacked GPTQ Marlin weights.
"""

qweight: torch.Tensor
qzeros: torch.Tensor
scales: torch.Tensor
g_idx: torch.Tensor
perm: torch.Tensor
bits: int
is_full_k: bool

def __post_init__(self):
assert self.qweight.dtype == torch.int32
assert self.scales.dtype in (torch.float16, torch.bfloat16)
assert self.g_idx.dtype == torch.int32
assert self.perm.dtype == torch.int32

def get_linear(self, bias: torch.Tensor):
return GPTQMarlinLinear(
weight=self,
bias=bias,
)

def repack_gptq_for_marlin(
*,
qweight: torch.Tensor,
qzeros: Optional[torch.Tensor],
scales: torch.Tensor,
g_idx: Optional[torch.Tensor],
bits: int,
desc_act: bool,
groupsize: int,
quant_method: str,
sym: bool,
sharded_infeatures: bool,
) -> GPTQMarlinWeight:
"""Convert GPTQ weights to a layout that's compatible with GPTQ-Marlin kernels."""
_check_marlin_kernels()
assert quantization is not None

if bits not in GPTQ_MARLIN_BITS:
supported_bits = ", ".join(str(b) for b in GPTQ_MARLIN_BITS)
raise RuntimeError(
f"Repacking {bits}-bit GPTQ weights as Marlin is not supported, must be one of: {supported_bits}"
)

if groupsize not in GPTQ_MARLIN_GROUP_SIZES:
supported_sizes = ", ".join(str(b) for b in GPTQ_MARLIN_GROUP_SIZES)
raise RuntimeError(
f"Repacking GPTQ weights with group size {groupsize} as Marlin is not supported, must be one of: {supported_sizes}"
)
if not (sym or quant_method == "awq" or quant_method == "compressed-tensors"):
raise RuntimeError(
"Repacking GPTQ weights with asymmetric quantization as Marlin is not supported."
)

log_once(logger.info, f"Converting {quant_method} model to Marlin packing format.")

weights_per_int = 32 // bits
in_features = qweight.shape[0]
out_features = qweight.shape[1]

# AWQ uses column packing, GPTQ uses row packing
if quant_method == "awq":
out_features *= weights_per_int
else:
in_features *= weights_per_int

if in_features % groupsize != 0:
raise ValueError(
f"Number of input features ({in_features}) not divisible by group size ({groupsize})"
)

if g_idx is not None and desc_act and groupsize != -1:
perm = torch.argsort(g_idx).to(torch.int)
g_idx = g_idx[perm]
else:
perm = torch.empty(0, dtype=torch.int, device=qweight.device)
g_idx = torch.empty(0, dtype=torch.int, device=qweight.device)

if quant_method == "awq":
repacked = quantization.awq_marlin_repack(
qweight, in_features, out_features, bits
)
if qzeros is not None:
qzeros = awq_to_marlin_zero_points(
qzeros,
in_features // groupsize,
out_features,
bits,
)

else:
repacked = quantization.gptq_marlin_repack(
qweight, perm, in_features, out_features, bits
)

if qzeros is None:
qzeros = torch.empty(0, dtype=torch.int, device=qweight.device)

scales = permute_scales(scales)

is_full_k = not (desc_act and groupsize != -1 and sharded_infeatures)

return GPTQMarlinWeight(
qweight=repacked,
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
perm=perm,
bits=bits,
is_full_k=is_full_k,
)

class GPTQMarlinLinear(nn.Module):
"""
Linear layer for GPTQ weights that were converted for the GPTQ-Marlin
kernels.
"""

def __init__(
self,
*,
weight: GPTQMarlinWeight,
bias: Optional[torch.Tensor],
):
super().__init__()

_check_marlin_kernels()
assert quantization is not None

in_features = weight.qweight.shape[0] * MARLIN_TILE_SIZE
out_features = weight.scales.shape[1]
_check_valid_shape(in_features=in_features, out_features=out_features)

if weight.bits not in (4, 8):
raise ValueError("GPTQMarlinLinear only supports 4 and 8-bit quantization")

if weight.qzeros.numel() > 0:
if weight.bits == 4:
self.quant_type = quantization.scalar_types.uint4
else:
self.quant_type = quantization.scalar_types.uint8
else:
if weight.bits == 4:
self.quant_type = quantization.scalar_types.uint4b8
else:
self.quant_type = quantization.scalar_types.uint8b128

self.is_full_k = weight.is_full_k

self.qweight = weight.qweight
self.qzeros = weight.qzeros
self.scales = weight.scales
self.g_idx = weight.g_idx
self.perm = weight.perm
if bias is not None:
self.bias = bias
else:
self.bias = None

self.workspace = torch.zeros(
out_features // 64 * 16, dtype=torch.int, device=weight.qweight.device
)

def forward(self, A: torch.Tensor) -> torch.Tensor:
assert quantization is not None

A_flat = A.view(-1, A.shape[-1])
C = quantization.gptq_marlin_gemm(
A_flat,
self.qweight,
self.scales,
self.qzeros,
self.g_idx,
self.perm,
self.workspace,
self.quant_type,
A_flat.shape[0],
self.scales.shape[1],
A_flat.shape[1],
self.is_full_k,
self.qzeros.numel() > 0,
True,
)
C = C.reshape(A.shape[:-1] + (self.scales.shape[1],))

if self.bias is not None:
C += self.bias

return C

def awq_to_marlin_zero_points(
q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
) -> torch.Tensor:
# AWQ zero-points are quantized and packed on the column dim.
# In addition, the values are permuted based on dequantizer.
# Here we undo both of these, and then apply marlin permutation
# and pack it back.
q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n)

# Undo interleaving (use argsort(..) to get inverse perm)
if num_bits == 4:
undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7]))
elif num_bits == 8:
undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3]))
else:
raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))

q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel()
q_zp = q_zp.reshape((-1, size_n)).contiguous()

marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits)
return marlin_zp

def _check_valid_shape(in_features: int, out_features: int):
if (in_features % 128 != 0 or out_features % 64 != 0) and (
in_features % 64 != 0 or out_features % 128 != 0
):
raise ValueError(
f"The GPTQ Marlin kernel does not have a valid thread configuration for weight matrix with shape ({out_features}, {in_features})."
" The shape elements must be divisible by (128, 64) or (64, 128)."
)
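Not part of the diff: a quick illustration of the shape arithmetic inside repack_gptq_for_marlin above, where each int32 stores 32 // bits quantized values, GPTQ packs along the input dimension and AWQ packs along the output dimension. The helper name unpacked_shape is hypothetical.

def unpacked_shape(qweight_shape, bits, quant_method):
    # GPTQ packs rows (in_features), AWQ packs columns (out_features);
    # every int32 element holds 32 // bits quantized weights.
    weights_per_int = 32 // bits
    in_features, out_features = qweight_shape
    if quant_method == "awq":
        out_features *= weights_per_int
    else:
        in_features *= weights_per_int
    return in_features, out_features

# A 4-bit GPTQ checkpoint stores a (512, 4096) qweight for a 4096x4096 layer:
print(unpacked_shape((512, 4096), bits=4, quant_method="gptq"))  # (4096, 4096)
print(unpacked_shape((4096, 512), bits=4, quant_method="awq"))   # (4096, 4096)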
@ -1,359 +0,0 @@
from dataclasses import dataclass
from typing import List, Optional, Union

import torch
import torch.nn as nn

from text_generation_server.layers.marlin.util import _check_marlin_kernels
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader

quantization = None

class MarlinWeightsLoader(WeightsLoader):
"""Loader for Marlin-quantized weights."""

def __init__(self, *, bits: int, is_marlin_24: bool):
self.bits = bits
self.is_marlin_24 = is_marlin_24

def get_weights(self, weights: "Weights", prefix: str):
"""
Get weights at the given prefix and apply without tensor paralllism.
"""
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
if is_marlin_24:
try:
B = weights.get_tensor(f"{prefix}.B_24")
except RuntimeError:
raise RuntimeError(
"Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized."
)

B_meta = weights.get_tensor(f"{prefix}.B_meta")
s = weights.get_tensor(f"{prefix}.s")
weight = GPTQMarlin24Weight(
weight_packed=B, meta=B_meta, scale_packed=s, bits=self.bits
)
else:
try:
B = weights.get_tensor(f"{prefix}.B")
except RuntimeError:
raise RuntimeError(
"Cannot load `marlin` weight, make sure the model is already quantized."
)

s = weights.get_tensor(f"{prefix}.s")
weight = MarlinWeight(B=B, s=s)

return weight

def get_weights_col_packed(
self,
weights: Weights,
prefix: str,
block_sizes: Union[int, List[int]],
):
if self.is_marlin_24:
B = weights.get_packed_sharded(
f"{prefix}.B_24", dim=1, block_sizes=block_sizes
)
B_meta = weights.get_packed_sharded(
f"{prefix}.B_meta", dim=1, block_sizes=block_sizes
)
s = weights.get_packed_sharded(
f"{prefix}.s", dim=1, block_sizes=block_sizes
)

weight = GPTQMarlin24Weight(
weight_packed=B, meta=B_meta, scale_packed=s, bits=self.bits
)
else:
B = weights.get_packed_sharded(
f"{prefix}.B", dim=1, block_sizes=block_sizes
)
s = weights.get_packed_sharded(
f"{prefix}.s", dim=1, block_sizes=block_sizes
)
weight = MarlinWeight(B=B, s=s)

return weight

def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
if self.is_marlin_24:
try:
B = torch.cat(
[weights.get_sharded(f"{p}.B_24", dim=1) for p in prefixes], dim=1
)
except RuntimeError:
raise RuntimeError(
"Cannot load `marlin` weight, make sure the model is already quantized"
)

B_meta = torch.cat(
[weights.get_sharded(f"{p}.B_meta", dim=1) for p in prefixes], dim=1
)

s = torch.cat(
[weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1
)

weight = GPTQMarlin24Weight(
weight_packed=B, meta=B_meta, scale_packed=s, bits=self.bits
)
else:
try:
B = torch.cat(
[weights.get_sharded(f"{p}.B", dim=1) for p in prefixes], dim=1
)
except RuntimeError:
raise RuntimeError(
"Cannot load `marlin` weight, make sure the model is already quantized"
)
s = torch.cat(
[weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1
)

weight = MarlinWeight(B=B, s=s)

return weight

def get_weights_row(self, weights: Weights, prefix: str):
if self.is_marlin_24:
try:
B = weights.get_sharded(f"{prefix}.B_24", dim=0)
except RuntimeError:
raise RuntimeError(
"Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized."
)

B_meta = weights.get_sharded(f"{prefix}.B_meta", dim=0)
num_groups = weights._get_slice(f"{prefix}.s").get_shape()[0]
if num_groups == 1:
# The number of groups is 1 when groupsize == -1. share
# scales between all shards in this case.
s = weights.get_tensor(f"{prefix}.s")
else:
s = weights.get_sharded(f"{prefix}.s", dim=0)

weight = GPTQMarlin24Weight(
weight_packed=B, meta=B_meta, scale_packed=s, bits=self.bits
)
else:
try:
B = weights.get_sharded(f"{prefix}.B", dim=0)
except RuntimeError:
raise RuntimeError(
"Cannot load `marlin` weight, make sure the model is already quantized."
)

num_groups = weights._get_slice(f"{prefix}.s").get_shape()[0]
if num_groups == 1:
# The number of groups is 1 when groupsize == -1. share
# scales between all shards in this case.
s = weights.get_tensor(f"{prefix}.s")
else:
s = weights.get_sharded(f"{prefix}.s", dim=0)
weight = MarlinWeight(B=B, s=s)

return weight

@dataclass
class MarlinWeight(Weight):
"""
Marlin weights.

Attributes:
B (torch.Tensor): int4-quantized weights packed into int32.
s (torch.Tensor): bfloat16/float16 scales.
"""

B: torch.Tensor
s: torch.Tensor

def __post_init__(self):
assert self.B.dtype == torch.int32
assert self.s.dtype in [torch.float16, torch.bfloat16]

def get_linear(self, bias: torch.Tensor):
return MarlinLinear(weight=self, bias=bias)

class MarlinLinear(nn.Module):
def __init__(self, *, weight: MarlinWeight, bias: Optional[torch.Tensor]):
super().__init__()

_check_marlin_kernels()
assert quantization is not None

in_features = weight.B.shape[0] * MARLIN_TILE_SIZE
out_features = weight.s.shape[1]
assert (
in_features % 128 == 0
), f"Number of input features ({in_features}) not divisable by 128"
assert (
out_features % 256 == 0
), f"Number of output features ({out_features}) not divisable by 256"

groupsize = -1 if weight.s.shape[0] == 1 else in_features // weight.s.shape[0]
assert groupsize in {
-1,
128,
}, f"Group size must be -1 or 128, was {groupsize}"

self.B = weight.B
self.s = weight.s
if bias is not None:
self.bias = bias
else:
self.bias = None

self.workspace = torch.zeros(
out_features // 64 * 16, dtype=torch.int, device=weight.B.device
)

def forward(self, A: torch.Tensor) -> torch.Tensor:
assert quantization is not None

C = quantization.marlin_gemm(
A.view(-1, A.shape[-1]),
self.B,
self.s,
self.workspace,
A.shape[0],
self.s.shape[1],
A.shape[1],
)
C = C.reshape(A.shape[:-1] + (self.s.shape[1],))

if self.bias is not None:
C += self.bias

return C

GPTQ_MARLIN_24_MIN_THREAD_N = 128
GPTQ_MARLIN_24_MIN_THREAD_K = 128
GPTQ_MARLIN_24_MAX_PARALLEL = 64
GPTQ_MARLIN_24_SUPPORTED_NUM_BITS = [4, 8]
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]
MARLIN_TILE_SIZE = 16

@dataclass
class GPTQMarlin24Weight:
"""
GPTQ-Marlin 2:4 weights.

Attributes:
B (torch.Tensor): int4-quantized weights packed into int32.
B_meta (torch.Tensor): metadata for 2:4 sparsity.
s (torch.Tensor): float16 scales.
bits: quantized weight size.
"""

weight_packed: torch.Tensor
meta: torch.Tensor
scale_packed: torch.Tensor
bits: int

def __post_init__(self):
assert self.weight_packed.dtype == torch.int32
assert self.meta.dtype == torch.int16
assert self.scale_packed.dtype == torch.float16

def get_linear(self, bias: torch.Tensor):
return GPTQMarlin24Linear(
weight=self,
bias=bias,
)

class GPTQMarlin24Linear(nn.Module):
def __init__(self, *, weight: GPTQMarlin24Weight, bias: Optional[torch.Tensor]):
super().__init__()

_check_marlin_kernels()
assert quantization is not None

if weight.bits not in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS:
supported_bits = ", ".join(
str(b) for b in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS
)
raise RuntimeError(
f"{weight.bits}-bit GPTQ Sparse 2:4 Marlin is not supported, must be one of: {supported_bits}"
)

in_features = weight.weight_packed.shape[0] * MARLIN_TILE_SIZE * 2
out_features = weight.scale_packed.shape[1]
groupsize = (
-1
if weight.scale_packed.shape[0] == 1
else in_features // weight.scale_packed.shape[0]
)

if groupsize not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES:
supported_sizes = ", ".join(
str(b) for b in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
)
raise RuntimeError(
f"Group size {groupsize} is not supported, must be one of: {supported_sizes}"
)

if weight.bits == 4:
self.quant_type = quantization.scalar_types.uint4b8
else:
self.quant_type = quantization.scalar_types.uint8b128
weights_per_int32 = 32 // weight.bits

assert (
out_features % GPTQ_MARLIN_24_MIN_THREAD_N == 0
), f"Number of output features ({out_features}) not divisable by {GPTQ_MARLIN_24_MIN_THREAD_N} threads"
assert (
out_features % weights_per_int32 == 0
), f"Number of output features ({out_features}) not divisable by weights per int32 ({weights_per_int32})"

assert (
in_features % GPTQ_MARLIN_24_MIN_THREAD_K == 0
), f"Number of output features ({out_features}) not divisable by {GPTQ_MARLIN_24_MIN_THREAD_K} threads"
if groupsize != -1 and in_features % groupsize != 0:
raise ValueError(
f"Number of input features ({in_features}) not divisable by group size ({groupsize})"
)

self.weight_packed = weight.weight_packed
self.meta = weight.meta
self.scale_packed = weight.scale_packed
if bias is not None:
self.bias = bias
else:
self.bias = None

self.workspace = torch.zeros(
(out_features // GPTQ_MARLIN_24_MIN_THREAD_N) * GPTQ_MARLIN_24_MAX_PARALLEL,
dtype=torch.int,
device=weight.weight_packed.device,
)

def forward(self, A: torch.Tensor) -> torch.Tensor:
assert quantization is not None

C = quantization.gptq_marlin_24_gemm(
A.view(-1, A.shape[-1]),
self.weight_packed,
self.meta,
self.scale_packed,
self.workspace,
self.quant_type,
A.shape[0],
self.scale_packed.shape[1],
A.shape[1],
)

C = C.reshape(A.shape[:-1] + (self.scale_packed.shape[1],))

if self.bias is not None:
C += self.bias

return C
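Not part of the diff: a short sketch of how MarlinLinear above derives the quantization group size from the scales tensor (one scale row means per-channel scales, i.e. groupsize == -1). The function name infer_groupsize is hypothetical.

def infer_groupsize(num_scale_rows: int, in_features: int) -> int:
    # Mirrors: groupsize = -1 if s.shape[0] == 1 else in_features // s.shape[0]
    return -1 if num_scale_rows == 1 else in_features // num_scale_rows

print(infer_groupsize(1, 4096))   # -1  (a single scale row covers the whole column)
print(infer_groupsize(32, 4096))  # 128 (4096 input features split into 32 groups)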
@ -1,137 +0,0 @@
import functools
from typing import List, Tuple

import numpy
import torch

quantization = None

try:
major, _minor = torch.cuda.get_device_capability()
has_sm_8_0 = major >= 8
except Exception:
has_sm_8_0 = False

def _check_marlin_kernels():
raise NotImplementedError(
"Using quantized Marlin models requires a GPU with CUDA capability 8.0 or later."
)

if quantization is None:
raise NotImplementedError(
"marlin is not installed, install it with: pip install server/marlin"
)

# https://github.com/IST-DASLab/marlin/blob/2f6d7c10e124b3c5fa29ff8d77d568bd7af3274c/marlin/__init__.py#L40C1-L68C54
@functools.cache
def get_perms() -> Tuple[List[int], List[int]]:
scale_perm = []
for i in range(8):
scale_perm.extend([i + 8 * j for j in range(8)])
scale_perm_single = []
for i in range(4):
scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
return scale_perm, scale_perm_single

def permute_scales(scales: torch.Tensor):
scale_perm, scale_perm_single = get_perms()
out_features = scales.shape[1]
if scales.shape[0] == 1:
scales = scales.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
else:
scales = scales.reshape((-1, len(scale_perm)))[:, scale_perm]
return scales.reshape((-1, out_features)).contiguous()

# Functions below are from vLLM

def get_pack_factor(bits: int) -> int:
if 32 % bits != 0:
raise ValueError(f"Cannot {bits} bit values into uint32")
return 32 // bits

def pack_cols(
q_w: torch.Tensor,
num_bits: int,
size_k: int,
size_n: int,
):
assert q_w.shape == (size_k, size_n)

pack_factor = get_pack_factor(num_bits)
assert size_n % pack_factor == 0

orig_device = q_w.device

q_w = q_w.cpu().numpy().astype(numpy.uint32)

q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)

for i in range(pack_factor):
q_res |= q_w[:, i::pack_factor] << num_bits * i

q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
q_res = q_res.contiguous()

return q_res

def unpack_cols(
packed_q_w: torch.Tensor,
num_bits: int,
size_k: int,
size_n: int,
):
pack_factor = get_pack_factor(num_bits)
assert size_n % pack_factor == 0
assert packed_q_w.shape == (
size_k,
size_n // pack_factor,
), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format(
packed_q_w.shape, size_k, size_n, pack_factor
)

orig_device = packed_q_w.device

packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32)
q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32)

mask = (1 << num_bits) - 1
for i in range(pack_factor):
vals = packed_q_w_cpu & mask
packed_q_w_cpu >>= num_bits
q_res[:, i::pack_factor] = vals

q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
q_res = q_res.contiguous()

return q_res

def marlin_zero_points(
zp: torch.Tensor, size_k: int, size_n: int, num_bits: int
) -> torch.Tensor:
scale_perm, _ = get_perms()
# Permute zero-points in a similar way to scales, but do not use the
# "single" permutation, since zero-points are applied on every MMA
zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm]

# Interleave column dim (for the dequantize code) and pack it to int32
if num_bits == 4:
interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
elif num_bits == 8:
interleave = numpy.array([0, 2, 1, 3])
else:
raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))

zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel()
zp = zp.reshape((-1, size_n)).contiguous()
zp = pack_cols(zp, num_bits, size_k, size_n)

return zp
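Not part of the diff: a self-contained numpy round-trip of the column-packing scheme implemented by pack_cols and unpack_cols above, assuming 4-bit values (pack factor 8).

import numpy as np

num_bits = 4
pack_factor = 32 // num_bits
q_w = np.random.randint(0, 2**num_bits, size=(16, 64)).astype(np.uint32)

# Pack: every pack_factor-th column shares one uint32, shifted by num_bits * i.
packed = np.zeros((16, 64 // pack_factor), dtype=np.uint32)
for i in range(pack_factor):
    packed |= q_w[:, i::pack_factor] << (num_bits * i)

# Unpack: peel num_bits at a time, writing back in the same column order.
unpacked = np.zeros_like(q_w)
tmp = packed.copy()
mask = (1 << num_bits) - 1
for i in range(pack_factor):
    unpacked[:, i::pack_factor] = tmp & mask
    tmp >>= num_bits

assert (unpacked == q_w).all()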
@ -398,6 +398,7 @@ class FlashGemmaModel(torch.nn.Module):
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
adapter_data: Optional[torch.Tensor],
prefill_cache_indices: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
@ -479,6 +480,7 @@ class FlashGemmaForCausalLM(torch.nn.Module):
block_tables,
slots,
seqlen,
adapter_data,
prefill_cache_indices,
hpu_attention_meta,
)
@ -47,10 +47,6 @@ def load_qkv(config, prefix: str, weights, head_size, num_heads):
prefix,
weights,
)
elif config.quantize == "marlin":
raise RuntimeError(
"GPT-2 models with marlin quantization are not yet supported"
)
else:
return _load_qkv(config, prefix, weights, head_size, num_heads)
@ -111,6 +111,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
seqlen=seqlen,
hpu_attention_meta=hpu_attention_meta,
prefill_cache_indices=None,
adapter_data=adapter_data,
)

if lm_head_indices is not None:
@ -90,7 +90,7 @@ def _load_gqa(config, prefix: str, weights):
dim=0,
)

if config.quantize not in ["gptq", "awq", "marlin"]:
if config.quantize not in ["gptq", "awq"]:
weight = weight.to(dtype=weights.dtype).to(device=weights.device)

head_size = config.hidden_size // config.num_attention_heads
@ -32,10 +32,6 @@ def load_multi_mqa(
return _load_multi_mqa_gptq(
config, prefix, weights, bias, head_size, num_heads, hidden_size
)
elif config.quantize == "marlin":
raise RuntimeError(
"santacoder models with marlin quantization are not yet supported"
)
else:
return _load_multi_mqa(
config, prefix, weights, bias, head_size, num_heads, hidden_size
@ -588,7 +588,7 @@ class Seq2SeqLM(Model):
aliases=aliases,
weights_loader=weights_loader,
)
if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
if config.quantize in ["awq", "gptq"]:
weights._set_gptq_params(model_id, revision)

model = model_class(config, weights)
@ -4,9 +4,7 @@ from dataclasses import dataclass
from typing import Optional

from huggingface_hub import hf_hub_download
from text_generation_server.layers.marlin.gptq import can_use_gptq_marlin
from text_generation_server.utils.weights import (
DefaultWeightsLoader,
WeightsLoader,
)
@ -129,64 +127,13 @@ def get_loader(
f"Quantize is set to `{quantize}` but received a `{quantizer_config.__class__.__name__}` config."
)

if can_use_gptq_marlin(
return GPTQWeightsLoader(
bits=quantizer_config.bits,
desc_act=quantizer_config.desc_act,
groupsize=quantizer_config.groupsize,
quant_method=quantizer_config.quant_method,
quantize=quantize,
sym=quantizer_config.sym,
):
from text_generation_server.layers.marlin import GPTQMarlinWeightsLoader

return GPTQMarlinWeightsLoader(
bits=quantizer_config.bits,
desc_act=quantizer_config.desc_act,
groupsize=quantizer_config.groupsize,
quant_method=quantizer_config.quant_method,
quantize=quantize,
sym=quantizer_config.sym,
)
else:
return GPTQWeightsLoader(
bits=quantizer_config.bits,
desc_act=quantizer_config.desc_act,
groupsize=quantizer_config.groupsize,
quant_method=quantizer_config.quant_method,
quantize=quantize,
sym=quantizer_config.sym,
)
elif quantize == "bitsandbytes":
from text_generation_server.layers.bnb import BNBWeight

return DefaultWeightsLoader(BNBWeight)
elif quantize == "bitsandbytes-fp4":
from text_generation_server.layers.bnb import BNBFP4Weight

return DefaultWeightsLoader(BNBFP4Weight)
elif quantize == "bitsandbytes-nf4":
from text_generation_server.layers.bnb import BNBNF4Weight

return DefaultWeightsLoader(BNBNF4Weight)
elif quantize == "eetq":
from text_generation_server.layers.eetq import EETQWeight

return DefaultWeightsLoader(EETQWeight)
elif quantize == "exl2":
from text_generation_server.layers.exl2 import Exl2WeightsLoader

return Exl2WeightsLoader()
elif quantize == "marlin":
from text_generation_server.layers.marlin import MarlinWeightsLoader

# TODO: improve check once we have one config type per quantize value
if not isinstance(quantizer_config, _QuantizerConfig):
raise ValueError(
f"Quantize is set to `{quantize}` but received a `{quantizer_config.__class__.__name__}` config."
)

return MarlinWeightsLoader(
bits=quantizer_config.bits,
is_marlin_24=quantizer_config.checkpoint_format == "marlin_24",
)
elif quantize == "fp8" or quantize is None:
from text_generation_server.layers.fp8 import HybridFP8UnquantLoader