Mirror of https://github.com/huggingface/text-generation-inference.git
Quantized weights were loaded in the `Weights` class, but this was getting quite unwieldy: every higher-level method for loading weights was a long conditional covering all the different quantizers.

This change moves the loading of quantized weights out of the `Weights` class by defining a simple `WeightsLoader` interface, implemented by `Exl2WeightsLoader`, `GPTQWeightsLoader`, and `MarlinWeightsLoader` in the quantizers' respective modules. The `Weights` class still provides the low-level load operations (such as loading tensors or sharded tensors), but delegates loads that need quantizer-specific weight processing to a loader. The loaders in turn use the low-level functionality provided by `Weights`.

I initially tried a class hierarchy in which a class like `GPTQWeights` would inherit from `Weights`, but that approach is not very flexible (e.g. it does not work well with the new weight storage mock used in tests), and the implicit indirection made the code harder to follow.
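For orientation, here is a minimal sketch of what the `WeightsLoader` interface described above might look like. The method names and parameters are taken from the `Exl2WeightsLoader` implementation below; the abstract-base-class framing and the docstrings are assumptions for illustration, not necessarily the exact upstream definition.

# Minimal sketch of the `WeightsLoader` interface (the ABC framing and
# docstrings are assumptions; method names follow the exl2 loader below).
from abc import ABC, abstractmethod
from typing import List, Union


class WeightsLoader(ABC):
    """Quantizer-specific weight loading, delegated to by `Weights`."""

    @abstractmethod
    def get_weights_col_packed(
        self, weights, prefix: str, block_sizes: Union[int, List[int]]
    ):
        """Load a column-parallel weight that packs several projections."""

    @abstractmethod
    def get_weights_col(self, weights, prefix: str):
        """Load a column-parallel weight."""

    @abstractmethod
    def get_multi_weights_col(self, weights, prefixes: List[str], dim: int):
        """Load several column-parallel weights and concatenate along `dim`."""

    @abstractmethod
    def get_weights_row(self, weights, prefix: str):
        """Load a row-parallel weight."""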
from dataclasses import dataclass
from typing import List, Union

import torch

from text_generation_server.utils.weights import WeightsLoader, Weights


@dataclass
class Exl2Weight:
    """
    Exllama2 exl2 quantized weights.
    """

    q_weight: torch.Tensor
    q_scale: torch.Tensor
    q_invperm: torch.Tensor
    q_scale_max: torch.Tensor
    q_groups: torch.Tensor

    def __post_init__(self):
        # One-time normalization of the raw checkpoint tensors.
        self.q_scale_max /= 256
        self.q_invperm = self.q_invperm.short()

    @property
    def device(self) -> torch.device:
        return self.q_weight.device


class Exl2WeightsLoader(WeightsLoader):
    """Loader for exl2-quantized weights."""

    def get_weights_col_packed(
        self,
        weights: Weights,
        prefix: str,
        block_sizes: Union[int, List[int]],
    ):
        raise RuntimeError("Column-packed weights are not supported for exl2")

    def get_weights_col(self, weights: Weights, prefix: str):
        # exl2 tensors are loaded whole (`get_tensor` rather than a sharded
        # load), since exl2 checkpoints are not sharded.
        try:
            q_weight = weights.get_tensor(f"{prefix}.q_weight")
        except RuntimeError:
            raise RuntimeError(
                "Cannot load `exl2`-quantized weight, make sure the model is already quantized."
            )

        q_scale = weights.get_tensor(f"{prefix}.q_scale")
        q_invperm = weights.get_tensor(f"{prefix}.q_invperm")
        q_scale_max = weights.get_tensor(f"{prefix}.q_scale_max")
        q_groups = weights.get_tensor(f"{prefix}.q_groups")

        return Exl2Weight(
            q_weight=q_weight,
            q_scale=q_scale,
            q_invperm=q_invperm,
            q_scale_max=q_scale_max,
            q_groups=q_groups,
        )

    def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
        raise ValueError("get_multi_weights_col is not supported for exl2")

    def get_weights_row(self, weights: Weights, prefix: str):
        # Row-parallel and column-parallel loads are identical for exl2,
        # because the weights are never sharded; reuse the column path.
        return self.get_weights_col(weights, prefix)
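A hypothetical usage sketch follows. The exact `Weights` constructor signature is an assumption (the names `filenames`, `process_group`, and `weights_loader` are illustrative), but it shows the delegation the commit message describes: high-level loads on `Weights` are routed to the quantizer's loader.

# Hypothetical usage (the `Weights` constructor arguments here are
# illustrative assumptions, not a confirmed upstream signature).
loader = Exl2WeightsLoader()
weights = Weights(
    filenames,                      # safetensors files of an exl2 checkpoint
    device=torch.device("cuda:0"),
    dtype=torch.float16,
    process_group=process_group,
    weights_loader=loader,          # quantizer-specific loads delegate here
)

# `Weights` provides low-level tensor loads itself, but a call like this
# is delegated to `Exl2WeightsLoader.get_weights_col`, returning an
# `Exl2Weight` carrying the quantizer's tensors.
w = weights.get_weights_col("model.layers.0.self_attn.q_proj")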