Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-20 22:32:07 +00:00)
This change adds support for WNA16 INT checkpoints with 2:4 sparsity, using Marlin 2:4 kernels.
197 lines · 7.2 KiB · Python
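"""Weights loader for checkpoints in the compressed-tensors format.

Reads the checkpoint's quantization configuration and dispatches each
parameter prefix to a quantization-specific loader (FP8, packed INT4/INT8,
INT with 2:4 sparsity, or INT8). As a hypothetical illustration (not taken
from a real checkpoint), a `quantization_config` section of `config.json`
that would route `Linear` weights to the Marlin 2:4 loader could look like:

    "quantization_config": {
        "format": "marlin-24",
        "quantization_status": "compressed",
        "ignore": ["lm_head"],
        "config_groups": {
            "group_0": {
                "targets": ["Linear"],
                "weights": {"type": "int", "num_bits": 4, "symmetric": true}
            }
        }
    }
"""
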
from typing import Any, Dict, List, Union

from compressed_tensors import QuantizationConfig, QuantizationStatus
from compressed_tensors.config import CompressionFormat
from compressed_tensors.quantization import (
    QuantizationScheme,
    QuantizationType,
    find_name_or_class_matches,
)
from loguru import logger
from pydantic import ValidationError
from torch import nn

from text_generation_server.layers.compressed_tensors.w8an_fp import W8ANFpLoader
from text_generation_server.layers.compressed_tensors.w8a8_int import W8A8IntLoader
from text_generation_server.layers.compressed_tensors.wna16_int_24 import (
    WNA16Int24Loader,
)
from text_generation_server.layers.compressed_tensors.wna16_int import WNA16IntLoader
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import (
    DefaultWeightsLoader,
    UnquantizedWeight,
    Weights,
    WeightsLoader,
)

# compressed-tensors can match modules as quantization targets. However,
# they need to be objects rather than classes or class names. Since we
# need to match `Linear` targets, make an instance that can be re-used.
_EMPTY_LINEAR: nn.Module = nn.Linear(0, 0)


class CompressedTensorsLoader(WeightsLoader):
    """Loader for checkpoints stored in the compressed-tensors format."""

    def __init__(self, config: Dict[str, Any]):
        quantization_config_raw = config.get("quantization_config")
        if quantization_config_raw is None:
            # `compression_config` was renamed to `quantization_config`; support
            # is retained for backward compatibility.
            quantization_config_raw = config.get("compression_config")
        if quantization_config_raw is None:
            raise ValueError(
                "Checkpoint does not have a compressed-tensors configuration"
            )

        try:
            quantization_config = QuantizationConfig.model_validate(
                quantization_config_raw
            )
        except ValidationError as e:
            raise ValueError("Cannot parse compressed-tensors configuration") from e

        if quantization_config.quantization_status not in (
            QuantizationStatus.COMPRESSED,
            QuantizationStatus.FROZEN,
        ):
            raise ValueError(
                f"Model quantization was not finished, status was: {quantization_config.quantization_status}"
            )

        self.ignore = (
            quantization_config.ignore if quantization_config.ignore is not None else []
        )
        self.loaders = self._get_target_loaders(quantization_config)

        for target, loader in self.loaders.items():
            log_once(
                logger.info,
                f"Using {loader} for compressed-tensors target '{target}'",
            )

    def get_weights(self, weights: Weights, prefix: str):
        loader = self._lookup_loader(prefix)
        return loader.get_weights(weights, prefix)

    def get_weights_col_packed(
        self,
        weights: Weights,
        prefix: str,
        block_sizes: Union[int, List[int]],
    ):
        loader = self._lookup_loader(prefix)
        return loader.get_weights_col_packed(weights, prefix, block_sizes)

    def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
        loader = self._lookup_loader(prefixes[0])
        return loader.get_multi_weights_col(weights, prefixes, dim)

    def get_weights_row(self, weights: Weights, prefix: str):
        loader = self._lookup_loader(prefix)
        return loader.get_weights_row(weights, prefix)

    def _get_target_loaders(
        self, quantization_config: QuantizationConfig
    ) -> Dict[str, WeightsLoader]:
        """
        A compressed-tensors checkpoint can use different quantizations
        for different targets. This method returns a dictionary with a
        loader per target.
        """
        loaders: Dict[str, WeightsLoader] = {}

        format = quantization_config.format

        for group_name, group in quantization_config.config_groups.items():
            # The group configuration can be a string, but does that ever
            # happen in a serialized quantization config?
            assert isinstance(group, QuantizationScheme)

            loader = self._create_loader_for_group(format, group_name, group)

            # A quantized parameter group can have multiple targets; add the
            # loader for all of them.
            for target in group.targets:
                if target in loaders:
                    raise ValueError(
                        f"Target '{target}' has multiple configured loaders"
                    )
                loaders[target] = loader

        return loaders

    def _create_loader_for_group(
        self, format: str, group_name: str, group: QuantizationScheme
    ) -> WeightsLoader:
        """
        Find and create a loader for the group with the given quantization
        scheme.
        """
        # NOTE: we ignore group.output_activations because we don't support
        # output quantization yet.
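        # Dispatch on (format, weight type, bit width); the branches below map:
        #   float-quantized / naive-quantized, 8-bit float -> W8ANFpLoader
        #   pack-quantized, 4/8-bit int                    -> WNA16IntLoader
        #   marlin-24, 4/8-bit int (2:4 sparse)            -> WNA16Int24Loader
        #   int-quantized / naive-quantized, 8-bit int     -> W8A8IntLoader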

        input_activations = group.input_activations
        weights = group.weights
        if (
            format
            in {
                CompressionFormat.float_quantized.value,
                CompressionFormat.naive_quantized.value,
            }
            and weights is not None
            and weights.type == QuantizationType.FLOAT
            and weights.num_bits == 8
        ):
            # FP W8A8 or W8A16.
            return W8ANFpLoader(input_activations=input_activations, weights=weights)
        elif (
            format == CompressionFormat.pack_quantized.value
            and weights is not None
            and weights.type == QuantizationType.INT
            and weights.num_bits in (4, 8)
        ):
            # INT W4A16 or W8A16 (GPTQ/AWQ-like).
            return WNA16IntLoader(weights)
        elif (
            format == CompressionFormat.marlin_24.value
            and weights is not None
            and weights.type == QuantizationType.INT
            and weights.num_bits in (4, 8)
        ):
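            # INT W4A16 or W8A16 with 2:4 sparsity (Marlin 2:4 kernels).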
            return WNA16Int24Loader(weights)
        elif (
            format
            in {
                CompressionFormat.int_quantized.value,
                CompressionFormat.naive_quantized.value,
            }
            and weights is not None
            and weights.type == QuantizationType.INT
            and weights.num_bits == 8
        ):
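            # INT W8A8.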
            return W8A8IntLoader(input_args=input_activations, weight_args=weights)
        else:
            raise ValueError(
                f"Group '{group_name}' has an unsupported compressed-tensors configuration"
            )

    def _lookup_loader(self, prefix: str) -> WeightsLoader:
        """
        Look up the loader to use for a given parameter name (prefix).
        """

        if len(find_name_or_class_matches(prefix, _EMPTY_LINEAR, self.ignore)) > 0:
            return DefaultWeightsLoader(UnquantizedWeight)

        # We currently only handle linear layers, so unconditionally pass
        # a `Linear` instance.
        targets = find_name_or_class_matches(prefix, _EMPTY_LINEAR, self.loaders.keys())
        if len(targets) == 0:
            raise ValueError(
                f"Cannot find compressed-tensors target for prefix: {prefix}"
            )
        return self.loaders[targets[0]]
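

# A minimal usage sketch (hypothetical file name; assumes a Hugging Face-style
# config.json whose top level carries the `quantization_config` section):
#
#     import json
#
#     with open("config.json") as f:
#         config = json.load(f)
#     loader = CompressedTensorsLoader(config)
#     # `loader` is then used wherever a `WeightsLoader` is expected.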