Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-24 16:32:12 +00:00)

Commit 0b02d45a05 (parent e4201f44cf)

add gptq and awq int4 support in intel platform

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
@@ -103,12 +103,21 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
     curl \
     ca-certificates \
     make \
-    g++ \
+    g++-12 \
+    gcc-12 \
     git \
     wget \
     cmake \
     libnuma-dev

+RUN update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-12 12
+RUN update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 12
+RUN update-alternatives --install /usr/bin/cc cc /usr/bin/gcc 30
+RUN update-alternatives --set cc /usr/bin/gcc
+
+RUN update-alternatives --install /usr/bin/c++ c++ /usr/bin/g++ 30
+RUN update-alternatives --set c++ /usr/bin/g++
+
 ENV HUGGINGFACE_HUB_CACHE=/data \
     HF_HUB_ENABLE_HF_TRANSFER=1 \
     PORT=80
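The toolchain switch above matters because later stages build intel-extension-for-pytorch and torch-ccl from source, and the update-alternatives calls make gcc-12/g++-12 the default cc/c++. A minimal sanity check one could run inside the image; the helper and the "major version at least 12" rule are illustrative assumptions, not part of this commit:

import re
import subprocess


def compiler_major(binary: str) -> int:
    # Parse the first x.y.z version string out of `<binary> --version`.
    out = subprocess.run([binary, "--version"], capture_output=True, text=True, check=True).stdout
    match = re.search(r"\b(\d+)\.\d+\.\d+\b", out)
    if match is None:
        raise RuntimeError(f"could not parse a version from: {out.splitlines()[0]}")
    return int(match.group(1))


if __name__ == "__main__":
    for tool in ("gcc", "g++", "cc", "c++"):
        major = compiler_major(tool)
        print(f"{tool} resolves to major version {major}")
        assert major >= 12, f"{tool} is older than the gcc-12/g++-12 the image just installed"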
@@ -133,21 +142,19 @@ RUN chmod +x ~/mambaforge.sh && \

 RUN conda install -c conda-forge gperftools mkl

-RUN pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
-RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.19.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
-RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
+RUN pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.5.0.dev20240815%2Bcpu-cp310-cp310-linux_x86_64.whl
+RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.20.0.dev20240815%2Bcpu-cp310-cp310-linux_x86_64.whl
+RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.4.0.dev20240815%2Bcpu-cp310-cp310-linux_x86_64.whl
 RUN pip install triton numa

 WORKDIR /usr/src

-RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout eda7a7c42df6f9a64e0de9c2b69304ee02f2c32a
-RUN git clone https://github.com/intel/torch-ccl.git && cd torch-ccl && git checkout ccl_torch_dev_0131
+RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout f86e93e4890dc2c989024d148d415c9aa8a1649f
+RUN git clone https://github.com/intel/torch-ccl.git && cd torch-ccl && git checkout v2.4.0+cpu+rc0

 RUN cd intel-extension-for-pytorch && git submodule sync && git submodule update --init --recursive && python setup.py install

 RUN cd torch-ccl && git submodule sync && git submodule update --init --recursive && pip install .

 ENV LD_PRELOAD=/opt/conda/lib/libtcmalloc.so
 ENV CCL_ROOT=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch
 ENV I_MPI_ROOT=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch
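With the wheels bumped to the 20240815 nightlies and both Intel repos pinned to new revisions, a quick import check inside the finished image is a cheap way to confirm the stack is consistent. A small sketch, assuming the packages expose their usual import names (torch, intel_extension_for_pytorch, oneccl_bindings_for_pytorch):

# Confirm the CPU stack assembled above: nightly torch, intel-extension-for-pytorch
# built from the pinned commit, and the oneCCL bindings from torch-ccl.
import torch
import intel_extension_for_pytorch as ipex
import oneccl_bindings_for_pytorch as torch_ccl  # registers the "ccl" backend for torch.distributed

print("torch          :", torch.__version__)
print("ipex           :", ipex.__version__)
print("oneccl bindings:", torch_ccl.__file__)

# The entry point the quantized layers below rely on should be importable as well.
assert hasattr(ipex.llm.quantization, "IPEXWeightOnlyQuantizedLinear")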
@@ -3,7 +3,12 @@
 from typing import Optional
 import torch
 import torch.nn as nn
-import awq_inference_engine  # with CUDA kernels
+from text_generation_server.utils.import_utils import SYSTEM
+
+if SYSTEM == "ipex":
+    import intel_extension_for_pytorch as ipex
+else:
+    import awq_inference_engine  # with CUDA kernels


 # class ScaledActivation(nn.Module):
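The import change above is the usual guarded-backend pattern: awq_inference_engine ships CUDA kernels only, so it is now imported just when SYSTEM reports a non-IPEX platform, and the module stays importable on Intel hardware. A stripped-down sketch of the same idea, using importlib rather than this repo's SYSTEM helper, purely for illustration:

# Pick a quantized-GEMM backend at import time and keep one name for the rest
# of the module to use. The real module keys off SYSTEM, not find_spec.
import importlib.util

if importlib.util.find_spec("intel_extension_for_pytorch") is not None:
    BACKEND = "ipex"
else:
    BACKEND = "cuda"  # stand-in for `import awq_inference_engine` in the real file

print(f"quantized GEMM backend: {BACKEND}")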
@@ -38,12 +43,29 @@ class WQLinear(nn.Module):
         self.qzeros = qzeros
         self.scales = scales
         self.bias = bias
+        if SYSTEM == "ipex":
+            self.woq_linear = (
+                ipex.llm.quantization.IPEXWeightOnlyQuantizedLinear.from_weight(
+                    self.qweight,
+                    self.scales,
+                    self.qzeros,
+                    self.in_features,
+                    self.out_features,
+                    bias=self.bias,
+                    group_size=self.group_size,
+                    quant_method=ipex.llm.quantization.QuantMethod.AWQ_GEMM,
+                    dtype=ipex.llm.quantization.QuantDtype.INT4,
+                )
+            )

     @torch.no_grad()
     def forward(self, x):
         out_shape = x.shape[:-1] + (self.out_features,)
-        out = awq_inference_engine.gemm_forward_cuda(
-            x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8
-        )
+        if SYSTEM == "ipex":
+            out = self.woq_linear(x.reshape(-1, x.shape[-1]))
+        else:
+            out = awq_inference_engine.gemm_forward_cuda(
+                x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8
+            )
         out = out + self.bias if self.bias is not None else out
         return out.reshape(out_shape)
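Both branches of the new forward keep the same contract: collapse every leading dimension into one batch axis, run a 2-D quantized GEMM (self.woq_linear on IPEX, gemm_forward_cuda otherwise), then restore the original leading shape. The reshape bookkeeping can be checked with a plain float matmul standing in for the quantized kernel:

import torch

in_features, out_features = 64, 128
weight = torch.randn(in_features, out_features)  # stand-in for the int4 GEMM

x = torch.randn(2, 5, in_features)               # (batch, seq_len, in_features)
out_shape = x.shape[:-1] + (out_features,)       # same expression as WQLinear.forward

out = x.reshape(-1, x.shape[-1]) @ weight        # 2-D GEMM on (batch * seq_len, in_features)
out = out.reshape(out_shape)

assert out.shape == (2, 5, out_features)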
@@ -298,6 +298,7 @@ class GPTQWeightsLoader(WeightsLoader):
         self._get_gptq_params(weights)

         use_exllama = True
+        desc_act = self.desc_act
         if self.bits != 4:
             use_exllama = False

@@ -321,7 +322,7 @@ class GPTQWeightsLoader(WeightsLoader):
         if g_idx is not None:
             if (
                 not torch.equal(
-                    g_idx.cpu(),
+                    (g_idx - g_idx[0]).cpu(),
                     torch.tensor(
                         [i // self.groupsize for i in range(g_idx.shape[0])],
                         dtype=torch.int32,
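The only change here is comparing (g_idx - g_idx[0]) rather than g_idx against the trivial [i // groupsize] pattern. On a row-sharded checkpoint each rank sees a slice of g_idx that starts at some non-zero group id, so the raw comparison wrongly concludes the model uses act-order; re-basing the slice first fixes that. A self-contained check of the reasoning:

import torch

groupsize, in_features = 128, 1024
world_size, rank = 2, 1

# g_idx of a GPTQ checkpoint without act-order: monotonically increasing group ids.
full_g_idx = torch.tensor([i // groupsize for i in range(in_features)], dtype=torch.int32)

# Row-parallel slice seen by this rank (same contiguous split as get_sharded on dim 0).
shard = in_features // world_size
g_idx = full_g_idx[rank * shard : (rank + 1) * shard]

canonical = torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32)

assert not torch.equal(g_idx.cpu(), canonical)           # raw slice starts at group 4
assert torch.equal((g_idx - g_idx[0]).cpu(), canonical)  # re-based slice matches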
@@ -332,6 +333,7 @@ class GPTQWeightsLoader(WeightsLoader):
                 # Exllama implementation does not support row tensor parallelism with act-order, as
                 # it would require to reorder input activations that are split unto several GPUs
                 use_exllama = False
+                desc_act = True

         from text_generation_server.layers.gptq import (
             CAN_EXLLAMA,
@@ -350,16 +352,16 @@ class GPTQWeightsLoader(WeightsLoader):
             else:
                 log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}")

-        if use_exllama and self.groupsize != -1:
+        if not desc_act and self.groupsize != -1:
             qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0)
             scales = weights.get_sharded(f"{prefix}.scales", dim=0)
+            if g_idx is not None:
+                # qzeros, scales sharded, and g_idx must be adjusted accordingly
+                g_idx = g_idx - g_idx[0]
         else:
             qzeros = weights.get_tensor(f"{prefix}.qzeros")
             scales = weights.get_tensor(f"{prefix}.scales")

-        if use_exllama and g_idx is not None:
-            g_idx = g_idx - g_idx[0]
-
         if self.quantize == "gptq" and self.quant_method == "awq":
             log_once(
                 logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
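The sharded branch is now keyed on not desc_act instead of use_exllama, so whenever the checkpoint has a real group size and no act-order, qzeros and scales are split along dim 0 (the group axis) and the rank's g_idx slice is re-based as shown above. A rough sketch of that slicing under assumed 4-bit GPTQ shapes (scales: groups x out_features, qzeros packed eight 4-bit zeros per int32 along the output axis):

import torch

in_features, out_features, groupsize, bits = 1024, 512, 128, 4
n_groups = in_features // groupsize

scales = torch.randn(n_groups, out_features, dtype=torch.float16)
qzeros = torch.zeros(n_groups, out_features * bits // 32, dtype=torch.int32)

world_size, rank = 2, 0
per_rank = n_groups // world_size  # get_sharded(..., dim=0): contiguous block of groups per rank

scales_shard = scales[rank * per_rank : (rank + 1) * per_rank]
qzeros_shard = qzeros[rank * per_rank : (rank + 1) * per_rank]

print(scales_shard.shape, qzeros_shard.shape)  # each rank keeps half of the groups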
@@ -7,6 +7,10 @@ from torch.cuda.amp import custom_fwd
 import triton
 import triton.language as tl
 from . import custom_autotune
+from text_generation_server.utils.import_utils import SYSTEM
+
+if SYSTEM == "ipex":
+    import intel_extension_for_pytorch as ipex


 # code based https://github.com/fpgaminer/GPTQ-triton
@@ -264,6 +268,21 @@ class QuantLinear(nn.Module):

         self.outfeatures = qweight.shape[1]
         self.infeatures = qweight.shape[0] * 32 // bits
+        if SYSTEM == "ipex" and bits == 4:
+            self.woq_linear = (
+                ipex.llm.quantization.IPEXWeightOnlyQuantizedLinear.from_weight(
+                    self.qweight,
+                    self.scales,
+                    self.qzeros,
+                    self.infeatures,
+                    self.outfeatures,
+                    bias=self.bias,
+                    group_size=self.groupsize,
+                    g_idx=g_idx,
+                    quant_method=ipex.llm.quantization.QuantMethod.GPTQ_GEMM,
+                    dtype=ipex.llm.quantization.QuantDtype.INT4,
+                )
+            )

     @classmethod
     def new(cls, bits, groupsize, infeatures, outfeatures, bias):
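The IPEX path reuses the sizes this constructor already derives from the packed GPTQ weight: qweight stores 32 // bits quantized values per int32 along the input dimension, so infeatures = qweight.shape[0] * 32 // bits and outfeatures = qweight.shape[1]. A toy check of that arithmetic for 4-bit packing:

import torch

bits, in_features, out_features = 4, 1024, 512

# GPTQ packs 32 // bits = 8 values into each int32 row along the input dimension.
qweight = torch.zeros(in_features * bits // 32, out_features, dtype=torch.int32)

infeatures = qweight.shape[0] * 32 // bits
outfeatures = qweight.shape[1]

assert (infeatures, outfeatures) == (in_features, out_features)
print(f"packed qweight {tuple(qweight.shape)} describes a {infeatures} x {outfeatures} layer")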
@@ -346,14 +365,17 @@ class QuantLinear(nn.Module):

     def forward(self, x):
         out_shape = x.shape[:-1] + (self.outfeatures,)
-        out = QuantLinearFunction.apply(
-            x.reshape(-1, x.shape[-1]),
-            self.qweight,
-            self.scales,
-            self.qzeros,
-            self.g_idx,
-            self.bits,
-            self.maxq,
-        )
+        if SYSTEM == "ipex" and self.bits == 4:
+            out = self.woq_linear(x.reshape(-1, x.shape[-1]))
+        else:
+            out = QuantLinearFunction.apply(
+                x.reshape(-1, x.shape[-1]),
+                self.qweight,
+                self.scales,
+                self.qzeros,
+                self.g_idx,
+                self.bits,
+                self.maxq,
+            )
         out = out + self.bias if self.bias is not None else out
         return out.reshape(out_shape)
@@ -885,6 +885,8 @@ class FlashCausalLM(Model):
             device = torch.device("cpu")
             # Float16 doesn't exist on target.
             dtype = torch.bfloat16 if dtype is None else dtype
+            if quantize in ["awq", "exl2", "gptq", "marlin"]:
+                dtype = torch.bfloat16
             init_cpu_threads_env(rank_id=rank, world_size=world_size)
         else:
             raise NotImplementedError(f"{model_class} is only available on GPU")
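On the CPU/IPEX branch the dtype is now forced to bfloat16 whenever one of these quantizers is requested, presumably because the int4 paths added in this commit target bfloat16 activations on that hardware. The selection logic, pulled out as a standalone sketch (the function name is illustrative, not from the repo):

from typing import Optional

import torch


def resolve_cpu_dtype(dtype: Optional[torch.dtype], quantize: Optional[str]) -> torch.dtype:
    # Mirror the CPU branch above: default to bfloat16, and force it for quantized models.
    resolved = torch.bfloat16 if dtype is None else dtype
    if quantize in ["awq", "exl2", "gptq", "marlin"]:
        resolved = torch.bfloat16
    return resolved


assert resolve_cpu_dtype(None, None) is torch.bfloat16
assert resolve_cpu_dtype(torch.float32, None) is torch.float32
assert resolve_cpu_dtype(torch.float32, "gptq") is torch.bfloat16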