Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-23 16:02:10 +00:00)
* Improve the handling of quantized weights

  Handling of quantized weights was split between two mechanisms:

  - For quantized checkpoints, we used the new weight loader infrastructure.
  - For quantization while loading (EETQ, FP8, bitsandbytes), we instead relied on conditionals in `get_linear`.

  Weight loaders support context managers to selectively load particular layers with different weight loaders (a minimal sketch follows this message). This is useful for models like Idefics2 AWQ, which uses a quantized text model but unquantized vision and connector models. However, the context manager would be overridden by `get_linear`, which string-checks `quantizer`. Also, the context manager would not work with EETQ, FP8, and bitsandbytes.

  This change migrates all quantizers to the weight loader infrastructure, which has several benefits:

  - We can use context managers with all quantizers.
  - All the implementation details move down to the quantizer layers; `get_linear` does not need to know how to handle quantized linear layers.
  - All quantizer weights are strongly typed; we don't pass around raw tensors.
  - We don't have to pass the `quantizer` string around everywhere.

* Exclude non-MLP layers when using FP8 quantization with Llama
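As an illustration of the per-layer loader idea described above, here is a minimal sketch of loading an Idefics2-style model whose text model is quantized but whose vision tower is not. The `use_loader` context-manager name, the `get_weights_col` call on `Weights`, and the tensor prefixes are assumptions for the sketch, not necessarily the exact TGI API.

# Sketch only (not the exact TGI API): selectively loading layers with
# different weight loaders via a context manager.
# `use_loader` and the prefixes below are assumed names for illustration.
from text_generation_server.utils.weights import Weights, WeightsLoader


def load_idefics2_mixed(weights: Weights, quantized_loader: WeightsLoader, plain_loader: WeightsLoader):
    # Text model: quantized, so load its layers with the quantized-weight loader.
    with weights.use_loader(quantized_loader):  # assumed context-manager name
        q_proj = weights.get_weights_col("model.text_model.layers.0.self_attn.q_proj")

    # Vision tower and connector: unquantized, so switch to a plain loader.
    with weights.use_loader(plain_loader):  # assumed context-manager name
        fc1 = weights.get_weights_col("model.vision_model.encoder.layers.0.fc1")

    return q_proj, fc1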
89 lines
2.7 KiB
Python
from dataclasses import dataclass
from typing import List, Union

import torch

from text_generation_server.utils.weights import Weight, Weights, WeightsLoader


@dataclass
class Exl2Weight(Weight):
    """
    Exllama2 exl2 quantized weights.
    """

    q_weight: torch.Tensor
    q_scale: torch.Tensor
    q_invperm: torch.Tensor
    q_scale_max: torch.Tensor
    q_groups: torch.Tensor

    def __post_init__(self):
        # Normalize the scales and use int16 permutation indices, as the
        # exllamav2 kernels expect.
        self.q_scale_max /= 256
        self.q_invperm = self.q_invperm.short()

    @property
    def device(self) -> torch.device:
        return self.q_weight.device

    def get_linear(self, bias: torch.Tensor):
        from text_generation_server.layers.gptq import ExllamaQuantLinear

        return ExllamaQuantLinear(self, bias)


class Exl2WeightsLoader(WeightsLoader):
    """Loader for exl2-quantized weights."""

    def get_weights_col_packed(
        self,
        weights: Weights,
        prefix: str,
        block_sizes: Union[int, List[int]],
    ):
        raise RuntimeError("Column-packed weights are not supported for exl")

    def get_weights_col(self, weights: Weights, prefix: str):
        try:
            q_weight = weights.get_tensor(f"{prefix}.q_weight")
        except RuntimeError:
            raise RuntimeError(
                "Cannot load `exl2`-quantized weight, make sure the model is already quantized."
            )

        q_scale = weights.get_tensor(f"{prefix}.q_scale")
        q_invperm = weights.get_tensor(f"{prefix}.q_invperm")
        q_scale_max = weights.get_tensor(f"{prefix}.q_scale_max")
        q_groups = weights.get_tensor(f"{prefix}.q_groups")

        return Exl2Weight(
            q_weight=q_weight,
            q_scale=q_scale,
            q_invperm=q_invperm,
            q_scale_max=q_scale_max,
            q_groups=q_groups,
        )

    def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
        raise ValueError("get_multi_weights_col is not supported for exl2")

    def get_weights_row(self, weights: Weights, prefix: str):
        try:
            q_weight = weights.get_tensor(f"{prefix}.q_weight")
        except RuntimeError:
            raise RuntimeError(
                "Cannot load `exl2`-quantized weight, make sure the model is already quantized."
            )

        q_scale = weights.get_tensor(f"{prefix}.q_scale")
        q_invperm = weights.get_tensor(f"{prefix}.q_invperm")
        q_scale_max = weights.get_tensor(f"{prefix}.q_scale_max")
        q_groups = weights.get_tensor(f"{prefix}.q_groups")

        return Exl2Weight(
            q_weight=q_weight,
            q_scale=q_scale,
            q_invperm=q_invperm,
            q_scale_max=q_scale_max,
            q_groups=q_groups,
        )
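For context, a short usage sketch of the loader defined above. The module path, the checkpoint prefix, and how the `Weights` instance is obtained are assumptions; only `Exl2WeightsLoader.get_weights_row` and `Exl2Weight.get_linear` come from the file itself.

# Usage sketch; the import path and prefix below are assumptions.
import torch

from text_generation_server.utils.weights import Weights
from text_generation_server.layers.exl2 import Exl2WeightsLoader


def build_q_proj(weights: Weights):
    loader = Exl2WeightsLoader()
    # Load the exl2 tensors (q_weight, q_scale, q_invperm, ...) for one projection.
    weight = loader.get_weights_row(weights, "model.layers.0.self_attn.q_proj")
    # The strongly typed Exl2Weight builds its own linear layer, so callers
    # no longer branch on a quantizer string.
    return weight.get_linear(bias=None)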