diff --git a/server/text_generation_server/layers/gptq/__init__.py b/server/text_generation_server/layers/gptq/__init__.py
index 9c9b69d1a..df5f59892 100644
--- a/server/text_generation_server/layers/gptq/__init__.py
+++ b/server/text_generation_server/layers/gptq/__init__.py
@@ -1,76 +1,12 @@
 import os
-from dataclasses import dataclass
-from typing import List, Optional, Union
+from typing import List, Union
 
 import torch
 from loguru import logger
 
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.log import log_once
-from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
-
-
-@dataclass
-class GPTQWeight(Weight):
-    qweight: torch.Tensor
-    qzeros: torch.Tensor
-    scales: torch.Tensor
-    g_idx: Optional[torch.Tensor]
-    bits: int
-    groupsize: int
-    use_awq_kernel: bool
-    use_exllama: bool
-
-    def __post_init__(self):
-        if self.scales.dtype == torch.float:
-            self.scales = self.scales.half()
-
-    @property
-    def device(self) -> torch.device:
-        return self.qweight.device
-
-    def get_linear(self, bias: torch.Tensor):
-        if self.use_awq_kernel:
-            if SYSTEM == "rocm":
-                raise NotImplementedError(
-                    "AWQ GEMM kernel can't be used on ROCm systems, please use `--quantize gptq` instead "
-                    "to use Exllama/GPTQ kernels for AWQ inference."
-                )
-            try:
-                from text_generation_server.layers.awq.quantize.qmodule import WQLinear
-
-                return WQLinear(
-                    w_bit=self.bits,
-                    group_size=self.groupsize,
-                    qweight=self.qweight,
-                    qzeros=self.qzeros,
-                    scales=self.scales,
-                    bias=bias,
-                )
-            except ImportError:
-                raise NotImplementedError(
-                    "You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `---quantize gptq` a conversion AWQ->GPTQ will happen on the fly"
-                )
-        elif self.use_exllama:
-            try:
-                from text_generation_server.layers.gptq import ExllamaQuantLinear
-            except ImportError:
-                raise NotImplementedError(
-                    "Exllama gptq kernels are not installed. Install them `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`"
-                )
-
-            return ExllamaQuantLinear(self, bias)
-        else:
-            from text_generation_server.layers.gptq.quant_linear import QuantLinear
-
-            return QuantLinear(
-                self.qweight,
-                self.qzeros,
-                self.scales,
-                self.g_idx,
-                bias,
-                self.bits,
-                self.groupsize,
-            )
+from text_generation_server.utils.weights import Weights, WeightsLoader
+from text_generation_server.layers.gptq.gptq_types import GPTQWeight
 
 
 class GPTQWeightsLoader(WeightsLoader):
diff --git a/server/text_generation_server/layers/gptq/gptq_types.py b/server/text_generation_server/layers/gptq/gptq_types.py
new file mode 100644
index 000000000..9ac75385a
--- /dev/null
+++ b/server/text_generation_server/layers/gptq/gptq_types.py
@@ -0,0 +1,69 @@
+from dataclasses import dataclass
+from typing import Optional
+import torch
+from text_generation_server.utils.weights import Weight
+from text_generation_server.utils.import_utils import SYSTEM
+
+
+@dataclass
+class GPTQWeight(Weight):
+    qweight: torch.Tensor
+    qzeros: torch.Tensor
+    scales: torch.Tensor
+    g_idx: Optional[torch.Tensor]
+    bits: int
+    groupsize: int
+    use_awq_kernel: bool
+    use_exllama: bool
+
+    def __post_init__(self):
+        if self.scales.dtype == torch.float:
+            self.scales = self.scales.half()
+
+    @property
+    def device(self) -> torch.device:
+        return self.qweight.device
+
+    def get_linear(self, bias: torch.Tensor):
+        if self.use_awq_kernel:
+            if SYSTEM == "rocm":
+                raise NotImplementedError(
+                    "AWQ GEMM kernel can't be used on ROCm systems, please use `--quantize gptq` instead "
+                    "to use Exllama/GPTQ kernels for AWQ inference."
+                )
+            try:
+                from text_generation_server.layers.awq.quantize.qmodule import WQLinear
+
+                return WQLinear(
+                    w_bit=self.bits,
+                    group_size=self.groupsize,
+                    qweight=self.qweight,
+                    qzeros=self.qzeros,
+                    scales=self.scales,
+                    bias=bias,
+                )
+            except ImportError:
+                raise NotImplementedError(
+                    "You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `--quantize gptq` a conversion AWQ->GPTQ will happen on the fly"
+                )
+        elif self.use_exllama:
+            try:
+                from text_generation_server.layers.gptq import ExllamaQuantLinear
+            except ImportError:
+                raise NotImplementedError(
+                    "Exllama gptq kernels are not installed. Install them `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`"
+                )
+
+            return ExllamaQuantLinear(self, bias)
+        else:
+            from text_generation_server.layers.gptq.quant_linear import QuantLinear
+
+            return QuantLinear(
+                self.qweight,
+                self.qzeros,
+                self.scales,
+                self.g_idx,
+                bias,
+                self.bits,
+                self.groupsize,
+            )
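Note on the resulting import path: after this patch, GPTQWeight lives in gptq_types.py rather than the package __init__, so callers can construct it without pulling in the loader machinery; get_linear() still imports ExllamaQuantLinear lazily from the package __init__, presumably to keep the modules free of an import cycle. Below is a minimal sketch of the new usage, assuming a working text_generation_server install; the tensor shapes are placeholders, not real packed GPTQ data:

import torch

# New location after this patch: the dataclass is imported from gptq_types,
# not from text_generation_server.layers.gptq itself.
from text_generation_server.layers.gptq.gptq_types import GPTQWeight

weight = GPTQWeight(
    qweight=torch.zeros((512, 128), dtype=torch.int32),  # placeholder packed weights
    qzeros=torch.zeros((32, 16), dtype=torch.int32),     # placeholder packed zero-points
    scales=torch.ones((32, 128), dtype=torch.float),     # fp32 on purpose, see below
    g_idx=None,
    bits=4,
    groupsize=128,
    use_awq_kernel=False,
    use_exllama=False,
)

# __post_init__ downcasts fp32 scales to fp16 automatically.
assert weight.scales.dtype == torch.float16
# The device property mirrors qweight's device.
print(weight.device)

# get_linear(bias) dispatches on the flags: WQLinear when use_awq_kernel is
# set, ExllamaQuantLinear when use_exllama is set, else the fallback
# QuantLinear. Not called here because every branch needs its quantization
# kernels installed.
# linear = weight.get_linear(bias=None)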