diff --git a/Dockerfile_intel b/Dockerfile_intel
index e4cd95ee..012cf722 100644
--- a/Dockerfile_intel
+++ b/Dockerfile_intel
@@ -145,6 +145,7 @@ RUN update-alternatives --set cc /usr/bin/gcc
 RUN update-alternatives --install /usr/bin/c++ c++ /usr/bin/g++ 30
 RUN update-alternatives --set c++ /usr/bin/g++
 
+
 ENV HUGGINGFACE_HUB_CACHE=/data \
     HF_HUB_ENABLE_HF_TRANSFER=1 \
     PORT=80
@@ -176,6 +177,7 @@ RUN case ${TARGETPLATFORM} in \
 
 RUN conda install -c conda-forge gperftools mkl
 
+RUN pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.5.0.dev20240815%2Bcpu-cp311-cp311-linux_x86_64.whl
 RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.20.0.dev20240815%2Bcpu-cp311-cp311-linux_x86_64.whl
 RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.4.0.dev20240815%2Bcpu-cp311-cp311-linux_x86_64.whl
 
diff --git a/server/text_generation_server/layers/awq/quantize/ipex.py b/server/text_generation_server/layers/awq/quantize/ipex.py
new file mode 100644
index 00000000..84cd7a21
--- /dev/null
+++ b/server/text_generation_server/layers/awq/quantize/ipex.py
@@ -0,0 +1,48 @@
+from typing import Optional
+import torch
+import torch.nn as nn
+import intel_extension_for_pytorch as ipex
+
+
+class WQLinear(nn.Module):
+    def __init__(
+        self, w_bit, group_size, qweight, qzeros, scales, bias: Optional[torch.Tensor]
+    ):
+        super().__init__()
+
+        if w_bit not in [4]:
+            raise NotImplementedError("Only 4-bit are supported for now.")
+
+        self.in_features = qweight.shape[0]
+        self.out_features = qweight.shape[1] * 32 // w_bit
+
+        self.w_bit = w_bit
+        self.group_size = group_size if group_size != -1 else self.in_features
+        # quick sanity check (make sure alignment)
+        assert self.in_features % self.group_size == 0
+        assert self.out_features % (32 // self.w_bit) == 0
+
+        self.qweight = qweight
+        self.qzeros = qzeros
+        self.scales = scales
+        self.bias = bias
+        self.woq_linear = (
+            ipex.llm.quantization.IPEXWeightOnlyQuantizedLinear.from_weight(
+                self.qweight,
+                self.scales,
+                self.qzeros,
+                self.in_features,
+                self.out_features,
+                bias=self.bias,
+                group_size=self.group_size,
+                quant_method=ipex.llm.quantization.QuantMethod.AWQ_GEMM,
+                dtype=ipex.llm.quantization.QuantDtype.INT4,
+            )
+        )
+
+    @torch.no_grad()
+    def forward(self, x):
+        out_shape = x.shape[:-1] + (self.out_features,)
+        out = self.woq_linear(x.reshape(-1, x.shape[-1]))
+        out = out + self.bias if self.bias is not None else out
+        return out.reshape(out_shape)
diff --git a/server/text_generation_server/layers/awq/quantize/qmodule.py b/server/text_generation_server/layers/awq/quantize/qmodule.py
index d59b1f18..391371a5 100644
--- a/server/text_generation_server/layers/awq/quantize/qmodule.py
+++ b/server/text_generation_server/layers/awq/quantize/qmodule.py
@@ -3,12 +3,7 @@ from typing import Optional
 
 import torch
 import torch.nn as nn
-from text_generation_server.utils.import_utils import SYSTEM
-
-if SYSTEM == "ipex":
-    import intel_extension_for_pytorch as ipex
-else:
-    import awq_inference_engine  # with CUDA kernels
+import awq_inference_engine  # with CUDA kernels
 
 
 # class ScaledActivation(nn.Module):
@@ -43,29 +38,12 @@ class WQLinear(nn.Module):
         self.qzeros = qzeros
         self.scales = scales
         self.bias = bias
-        if SYSTEM == "ipex":
-            self.woq_linear = (
-                ipex.llm.quantization.IPEXWeightOnlyQuantizedLinear.from_weight(
-                    self.qweight,
-                    self.scales,
-                    self.qzeros,
-                    self.in_features,
-                    self.out_features,
-                    bias=self.bias,
-                    group_size=self.group_size,
-                    quant_method=ipex.llm.quantization.QuantMethod.AWQ_GEMM,
-                    dtype=ipex.llm.quantization.QuantDtype.INT4,
-                )
-            )
 
     @torch.no_grad()
     def forward(self, x):
         out_shape = x.shape[:-1] + (self.out_features,)
-        if SYSTEM == "ipex":
-            out = self.woq_linear(x.reshape(-1, x.shape[-1]))
-        else:
-            out = awq_inference_engine.gemm_forward_cuda(
-                x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8
-            )
+        out = awq_inference_engine.gemm_forward_cuda(
+            x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8
+        )
         out = out + self.bias if self.bias is not None else out
         return out.reshape(out_shape)
diff --git a/server/text_generation_server/layers/gptq/__init__.py b/server/text_generation_server/layers/gptq/__init__.py
index 20db6565..020467f2 100644
--- a/server/text_generation_server/layers/gptq/__init__.py
+++ b/server/text_generation_server/layers/gptq/__init__.py
@@ -36,7 +36,12 @@ class GPTQWeight(Weight):
                     "to use Exllama/GPTQ kernels for AWQ inference."
                 )
             try:
-                from text_generation_server.layers.awq.quantize.qmodule import WQLinear
+                if SYSTEM == "ipex":
+                    from text_generation_server.layers.awq.quantize.ipex import WQLinear
+                else:
+                    from text_generation_server.layers.awq.quantize.qmodule import (
+                        WQLinear,
+                    )
 
                 return WQLinear(
                     w_bit=self.bits,
@@ -60,7 +65,10 @@ class GPTQWeight(Weight):
 
             return ExllamaQuantLinear(self, bias)
         else:
-            from text_generation_server.layers.gptq.quant_linear import QuantLinear
+            if SYSTEM == "ipex":
+                from text_generation_server.layers.gptq.ipex import QuantLinear
+            else:
+                from text_generation_server.layers.gptq.quant_linear import QuantLinear
 
             return QuantLinear(
                 self.qweight,
diff --git a/server/text_generation_server/layers/gptq/ipex.py b/server/text_generation_server/layers/gptq/ipex.py
new file mode 100644
index 00000000..ab9c9e24
--- /dev/null
+++ b/server/text_generation_server/layers/gptq/ipex.py
@@ -0,0 +1,126 @@
+import math
+import numpy as np
+import torch
+import torch.nn as nn
+
+import intel_extension_for_pytorch as ipex
+
+
+class QuantLinear(nn.Module):
+    def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
+        super().__init__()
+        self.register_buffer("qweight", qweight)
+        self.register_buffer("qzeros", qzeros)
+        self.register_buffer("scales", scales)
+        self.register_buffer("g_idx", g_idx)
+        if bias is not None:
+            self.register_buffer("bias", bias)
+        else:
+            self.bias = None
+        if bits not in [4]:
+            raise NotImplementedError("Only 4 bits are supported.")
+        self.bits = bits
+        self.maxq = 2**self.bits - 1
+        self.groupsize = groupsize
+
+        self.outfeatures = qweight.shape[1]
+        self.infeatures = qweight.shape[0] * 32 // bits
+        self.woq_linear = (
+            ipex.llm.quantization.IPEXWeightOnlyQuantizedLinear.from_weight(
+                self.qweight,
+                self.scales,
+                self.qzeros,
+                self.infeatures,
+                self.outfeatures,
+                bias=self.bias,
+                group_size=self.groupsize,
+                g_idx=g_idx,
+                quant_method=ipex.llm.quantization.QuantMethod.GPTQ_GEMM,
+                dtype=ipex.llm.quantization.QuantDtype.INT4,
+            )
+        )
+
+    @classmethod
+    def new(cls, bits, groupsize, infeatures, outfeatures, bias):
+        if bits not in [4]:
+            raise NotImplementedError("Only 4 bits are supported.")
+
+        qweight = torch.zeros((infeatures // 32 * bits, outfeatures), dtype=torch.int32)
+        qzeros = torch.zeros(
+            (math.ceil(infeatures / groupsize), outfeatures // 32 * bits),
+            dtype=torch.int32,
+        )
+        scales = torch.zeros(
+            (math.ceil(infeatures / groupsize), outfeatures), dtype=torch.float16
+        )
+        g_idx = torch.tensor(
+            [i // groupsize for i in range(infeatures)], dtype=torch.int32
+        )
+        if bias:
+            bias = torch.zeros((outfeatures), dtype=torch.float16)
+        else:
+            bias = None
+        return cls(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
+
+    def pack(self, linear, scales, zeros, g_idx=None):
+        self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx
+
+        scales = scales.t().contiguous()
+        zeros = zeros.t().contiguous()
+        scale_zeros = zeros * scales
+        self.scales = scales.clone().half()
+        if linear.bias is not None:
+            self.bias = linear.bias.clone().half()
+
+        intweight = []
+        for idx in range(self.infeatures):
+            intweight.append(
+                torch.round(
+                    (linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]])
+                    / self.scales[self.g_idx[idx]]
+                ).to(torch.int)[:, None]
+            )
+        intweight = torch.cat(intweight, dim=1)
+        intweight = intweight.t().contiguous()
+        intweight = intweight.numpy().astype(np.uint32)
+        qweight = np.zeros(
+            (intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32
+        )
+        i = 0
+        row = 0
+        while row < qweight.shape[0]:
+            if self.bits in [4]:
+                for j in range(i, i + (32 // self.bits)):
+                    qweight[row] |= intweight[j] << (self.bits * (j - i))
+                i += 32 // self.bits
+                row += 1
+            else:
+                raise NotImplementedError("Only 4 bits are supported.")
+
+        qweight = qweight.astype(np.int32)
+        self.qweight = torch.from_numpy(qweight)
+
+        zeros -= 1
+        zeros = zeros.numpy().astype(np.uint32)
+        qzeros = np.zeros(
+            (zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32
+        )
+        i = 0
+        col = 0
+        while col < qzeros.shape[1]:
+            if self.bits in [4]:
+                for j in range(i, i + (32 // self.bits)):
+                    qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
+                i += 32 // self.bits
+                col += 1
+            else:
+                raise NotImplementedError("Only 4 bits are supported.")
+
+        qzeros = qzeros.astype(np.int32)
+        self.qzeros = torch.from_numpy(qzeros)
+
+    def forward(self, x):
+        out_shape = x.shape[:-1] + (self.outfeatures,)
+        out = self.woq_linear(x.reshape(-1, x.shape[-1]))
+        out = out + self.bias if self.bias is not None else out
+        return out.reshape(out_shape)
diff --git a/server/text_generation_server/layers/gptq/quant_linear.py b/server/text_generation_server/layers/gptq/quant_linear.py
index 9dc7615e..736c357b 100644
--- a/server/text_generation_server/layers/gptq/quant_linear.py
+++ b/server/text_generation_server/layers/gptq/quant_linear.py
@@ -7,10 +7,6 @@ from torch.cuda.amp import custom_fwd
 import triton
 import triton.language as tl
 from . import custom_autotune
-from text_generation_server.utils.import_utils import SYSTEM
-
-if SYSTEM == "ipex":
-    import intel_extension_for_pytorch as ipex
 
 
 # code based https://github.com/fpgaminer/GPTQ-triton
@@ -268,21 +264,6 @@ class QuantLinear(nn.Module):
 
         self.outfeatures = qweight.shape[1]
         self.infeatures = qweight.shape[0] * 32 // bits
-        if SYSTEM == "ipex" and bits == 4:
-            self.woq_linear = (
-                ipex.llm.quantization.IPEXWeightOnlyQuantizedLinear.from_weight(
-                    self.qweight,
-                    self.scales,
-                    self.qzeros,
-                    self.infeatures,
-                    self.outfeatures,
-                    bias=self.bias,
-                    group_size=self.groupsize,
-                    g_idx=g_idx,
-                    quant_method=ipex.llm.quantization.QuantMethod.GPTQ_GEMM,
-                    dtype=ipex.llm.quantization.QuantDtype.INT4,
-                )
-            )
 
     @classmethod
     def new(cls, bits, groupsize, infeatures, outfeatures, bias):
@@ -365,17 +346,14 @@ class QuantLinear(nn.Module):
 
     def forward(self, x):
         out_shape = x.shape[:-1] + (self.outfeatures,)
-        if SYSTEM == "ipex" and self.bits == 4:
-            out = self.woq_linear(x.reshape(-1, x.shape[-1]))
-        else:
-            out = QuantLinearFunction.apply(
-                x.reshape(-1, x.shape[-1]),
-                self.qweight,
-                self.scales,
-                self.qzeros,
-                self.g_idx,
-                self.bits,
-                self.maxq,
-            )
+        out = QuantLinearFunction.apply(
+            x.reshape(-1, x.shape[-1]),
+            self.qweight,
+            self.scales,
+            self.qzeros,
+            self.g_idx,
+            self.bits,
+            self.maxq,
+        )
         out = out + self.bias if self.bias is not None else out
         return out.reshape(out_shape)
diff --git a/server/text_generation_server/layers/gptq/quantize.py b/server/text_generation_server/layers/gptq/quantize.py
index b0086ea0..d87df5f2 100644
--- a/server/text_generation_server/layers/gptq/quantize.py
+++ b/server/text_generation_server/layers/gptq/quantize.py
@@ -12,7 +12,12 @@ from huggingface_hub import HfApi
 from accelerate import init_empty_weights
 from text_generation_server.utils import initialize_torch_distributed, Weights
 from text_generation_server.utils.hub import weight_files
-from text_generation_server.layers.gptq.quant_linear import QuantLinear
+from text_generation_server.utils.import_utils import SYSTEM
+
+if SYSTEM == "ipex":
+    from text_generation_server.layers.gptq.ipex import QuantLinear
+else:
+    from text_generation_server.layers.gptq.quant_linear import QuantLinear
 from loguru import logger
 from typing import Optional
 from text_generation_server.layers.gptq.utils import torch_snr_error
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index de0c66e7..0860e9ee 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -400,8 +400,11 @@ def get_model(
 
     if dtype is None:
         if quantize in ["awq", "exl2", "gptq", "marlin"]:
-            # These quantizers only work with float16 params.
-            dtype = torch.float16
+            if SYSTEM == "ipex" and not hasattr(torch, "xpu"):
+                dtype = torch.bfloat16
+            else:
+                # These quantizers only work with float16 params.
+                dtype = torch.float16
     elif quantize == "fp8":
         from text_generation_server.layers.fp8 import FBGEMM_DYN_AVAILABLE
 
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 07d65b77..bce459e3 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -1123,13 +1123,6 @@ class FlashCausalLM(Model):
             else:
                 device = torch.device("cpu")
                 dtype = torch.bfloat16 if dtype is None else dtype
-                if (
-                    quantize in ["awq", "exl2", "gptq", "marlin"]
-                    and dtype == torch.float16
-                ):
-                    # Float16 doesn't exist on target.
-                    dtype = torch.bfloat16
-                    kv_cache_dtype = torch.bfloat16
                 init_cpu_threads_env(rank_id=rank, world_size=world_size)
         else:
            raise NotImplementedError(f"{model_class} is only available on GPU")
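For orientation only, a minimal sketch (not part of the patch) of how the dispatch introduced above is expected to resolve: on IPEX systems the new gptq/ipex.py and awq/quantize/ipex.py backends are imported, otherwise the existing Triton/CUDA paths are used. The import paths and the QuantLinear(qweight, qzeros, scales, g_idx, bias, bits, groupsize) signature come from the diff; the build_gptq_linear helper and its defaults are hypothetical names used purely for illustration.

# Illustrative sketch under the assumptions stated above.
from text_generation_server.utils.import_utils import SYSTEM

if SYSTEM == "ipex":
    # CPU/XPU path: 4-bit weight-only quantized linear backed by IPEX.
    from text_generation_server.layers.gptq.ipex import QuantLinear
else:
    # CUDA/ROCm path: Triton-based QuantLinear, unchanged by this patch.
    from text_generation_server.layers.gptq.quant_linear import QuantLinear


def build_gptq_linear(qweight, qzeros, scales, g_idx, bias, bits=4, groupsize=128):
    # Hypothetical helper: both backends expose the same constructor, so callers
    # such as GPTQWeight.get_linear() do not need to know which one was selected.
    return QuantLinear(qweight, qzeros, scales, g_idx, bias, bits, groupsize)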