refine the code according to the review command

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
Wang, Yi A 2024-10-14 21:01:54 -07:00
parent 7c6230c59a
commit b069d2c131
9 changed files with 210 additions and 69 deletions

View File

@ -145,6 +145,7 @@ RUN update-alternatives --set cc /usr/bin/gcc
RUN update-alternatives --install /usr/bin/c++ c++ /usr/bin/g++ 30 RUN update-alternatives --install /usr/bin/c++ c++ /usr/bin/g++ 30
RUN update-alternatives --set c++ /usr/bin/g++ RUN update-alternatives --set c++ /usr/bin/g++
ENV HUGGINGFACE_HUB_CACHE=/data \ ENV HUGGINGFACE_HUB_CACHE=/data \
HF_HUB_ENABLE_HF_TRANSFER=1 \ HF_HUB_ENABLE_HF_TRANSFER=1 \
PORT=80 PORT=80
@ -176,6 +177,7 @@ RUN case ${TARGETPLATFORM} in \
RUN conda install -c conda-forge gperftools mkl RUN conda install -c conda-forge gperftools mkl
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.5.0.dev20240815%2Bcpu-cp311-cp311-linux_x86_64.whl RUN pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.5.0.dev20240815%2Bcpu-cp311-cp311-linux_x86_64.whl
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.20.0.dev20240815%2Bcpu-cp311-cp311-linux_x86_64.whl RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.20.0.dev20240815%2Bcpu-cp311-cp311-linux_x86_64.whl
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.4.0.dev20240815%2Bcpu-cp311-cp311-linux_x86_64.whl RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.4.0.dev20240815%2Bcpu-cp311-cp311-linux_x86_64.whl

View File

@ -0,0 +1,48 @@
from typing import Optional
import torch
import torch.nn as nn
import intel_extension_for_pytorch as ipex
class WQLinear(nn.Module):
def __init__(
self, w_bit, group_size, qweight, qzeros, scales, bias: Optional[torch.Tensor]
):
super().__init__()
if w_bit not in [4]:
raise NotImplementedError("Only 4-bit are supported for now.")
self.in_features = qweight.shape[0]
self.out_features = qweight.shape[1] * 32 // w_bit
self.w_bit = w_bit
self.group_size = group_size if group_size != -1 else self.in_features
# quick sanity check (make sure aligment)
assert self.in_features % self.group_size == 0
assert self.out_features % (32 // self.w_bit) == 0
self.qweight = qweight
self.qzeros = qzeros
self.scales = scales
self.bias = bias
self.woq_linear = (
ipex.llm.quantization.IPEXWeightOnlyQuantizedLinear.from_weight(
self.qweight,
self.scales,
self.qzeros,
self.in_features,
self.out_features,
bias=self.bias,
group_size=self.group_size,
quant_method=ipex.llm.quantization.QuantMethod.AWQ_GEMM,
dtype=ipex.llm.quantization.QuantDtype.INT4,
)
)
@torch.no_grad()
def forward(self, x):
out_shape = x.shape[:-1] + (self.out_features,)
out = self.woq_linear(x.reshape(-1, x.shape[-1]))
out = out + self.bias if self.bias is not None else out
return out.reshape(out_shape)

View File

@ -3,12 +3,7 @@
from typing import Optional from typing import Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
from text_generation_server.utils.import_utils import SYSTEM import awq_inference_engine # with CUDA kernels
if SYSTEM == "ipex":
import intel_extension_for_pytorch as ipex
else:
import awq_inference_engine # with CUDA kernels
# class ScaledActivation(nn.Module): # class ScaledActivation(nn.Module):
@ -43,29 +38,12 @@ class WQLinear(nn.Module):
self.qzeros = qzeros self.qzeros = qzeros
self.scales = scales self.scales = scales
self.bias = bias self.bias = bias
if SYSTEM == "ipex":
self.woq_linear = (
ipex.llm.quantization.IPEXWeightOnlyQuantizedLinear.from_weight(
self.qweight,
self.scales,
self.qzeros,
self.in_features,
self.out_features,
bias=self.bias,
group_size=self.group_size,
quant_method=ipex.llm.quantization.QuantMethod.AWQ_GEMM,
dtype=ipex.llm.quantization.QuantDtype.INT4,
)
)
@torch.no_grad() @torch.no_grad()
def forward(self, x): def forward(self, x):
out_shape = x.shape[:-1] + (self.out_features,) out_shape = x.shape[:-1] + (self.out_features,)
if SYSTEM == "ipex": out = awq_inference_engine.gemm_forward_cuda(
out = self.woq_linear(x.reshape(-1, x.shape[-1])) x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8
else: )
out = awq_inference_engine.gemm_forward_cuda(
x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8
)
out = out + self.bias if self.bias is not None else out out = out + self.bias if self.bias is not None else out
return out.reshape(out_shape) return out.reshape(out_shape)

View File

@ -36,7 +36,12 @@ class GPTQWeight(Weight):
"to use Exllama/GPTQ kernels for AWQ inference." "to use Exllama/GPTQ kernels for AWQ inference."
) )
try: try:
from text_generation_server.layers.awq.quantize.qmodule import WQLinear if SYSTEM == "ipex":
from text_generation_server.layers.awq.quantize.ipex import WQLinear
else:
from text_generation_server.layers.awq.quantize.qmodule import (
WQLinear,
)
return WQLinear( return WQLinear(
w_bit=self.bits, w_bit=self.bits,
@ -60,7 +65,10 @@ class GPTQWeight(Weight):
return ExllamaQuantLinear(self, bias) return ExllamaQuantLinear(self, bias)
else: else:
from text_generation_server.layers.gptq.quant_linear import QuantLinear if SYSTEM == "ipex":
from text_generation_server.layers.gptq.ipex import QuantLinear
else:
from text_generation_server.layers.gptq.quant_linear import QuantLinear
return QuantLinear( return QuantLinear(
self.qweight, self.qweight,

View File

@ -0,0 +1,126 @@
import math
import numpy as np
import torch
import torch.nn as nn
import intel_extension_for_pytorch as ipex
class QuantLinear(nn.Module):
def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
super().__init__()
self.register_buffer("qweight", qweight)
self.register_buffer("qzeros", qzeros)
self.register_buffer("scales", scales)
self.register_buffer("g_idx", g_idx)
if bias is not None:
self.register_buffer("bias", bias)
else:
self.bias = None
if bits not in [4]:
raise NotImplementedError("Only 4 bits are supported.")
self.bits = bits
self.maxq = 2**self.bits - 1
self.groupsize = groupsize
self.outfeatures = qweight.shape[1]
self.infeatures = qweight.shape[0] * 32 // bits
self.woq_linear = (
ipex.llm.quantization.IPEXWeightOnlyQuantizedLinear.from_weight(
self.qweight,
self.scales,
self.qzeros,
self.infeatures,
self.outfeatures,
bias=self.bias,
group_size=self.groupsize,
g_idx=g_idx,
quant_method=ipex.llm.quantization.QuantMethod.GPTQ_GEMM,
dtype=ipex.llm.quantization.QuantDtype.INT4,
)
)
@classmethod
def new(cls, bits, groupsize, infeatures, outfeatures, bias):
if bits not in [4]:
raise NotImplementedError("Only 4 bits are supported.")
qweight = torch.zeros((infeatures // 32 * bits, outfeatures), dtype=torch.int32)
qzeros = torch.zeros(
(math.ceil(infeatures / groupsize), outfeatures // 32 * bits),
dtype=torch.int32,
)
scales = torch.zeros(
(math.ceil(infeatures / groupsize), outfeatures), dtype=torch.float16
)
g_idx = torch.tensor(
[i // groupsize for i in range(infeatures)], dtype=torch.int32
)
if bias:
bias = torch.zeros((outfeatures), dtype=torch.float16)
else:
bias = None
return cls(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
def pack(self, linear, scales, zeros, g_idx=None):
self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx
scales = scales.t().contiguous()
zeros = zeros.t().contiguous()
scale_zeros = zeros * scales
self.scales = scales.clone().half()
if linear.bias is not None:
self.bias = linear.bias.clone().half()
intweight = []
for idx in range(self.infeatures):
intweight.append(
torch.round(
(linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]])
/ self.scales[self.g_idx[idx]]
).to(torch.int)[:, None]
)
intweight = torch.cat(intweight, dim=1)
intweight = intweight.t().contiguous()
intweight = intweight.numpy().astype(np.uint32)
qweight = np.zeros(
(intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32
)
i = 0
row = 0
while row < qweight.shape[0]:
if self.bits in [4]:
for j in range(i, i + (32 // self.bits)):
qweight[row] |= intweight[j] << (self.bits * (j - i))
i += 32 // self.bits
row += 1
else:
raise NotImplementedError("Only 4 bits are supported.")
qweight = qweight.astype(np.int32)
self.qweight = torch.from_numpy(qweight)
zeros -= 1
zeros = zeros.numpy().astype(np.uint32)
qzeros = np.zeros(
(zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32
)
i = 0
col = 0
while col < qzeros.shape[1]:
if self.bits in [4]:
for j in range(i, i + (32 // self.bits)):
qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
i += 32 // self.bits
col += 1
else:
raise NotImplementedError("Only 4 bits are supported.")
qzeros = qzeros.astype(np.int32)
self.qzeros = torch.from_numpy(qzeros)
def forward(self, x):
out_shape = x.shape[:-1] + (self.outfeatures,)
out = self.woq_linear(x.reshape(-1, x.shape[-1]))
out = out + self.bias if self.bias is not None else out
return out.reshape(out_shape)

View File

@ -7,10 +7,6 @@ from torch.cuda.amp import custom_fwd
import triton import triton
import triton.language as tl import triton.language as tl
from . import custom_autotune from . import custom_autotune
from text_generation_server.utils.import_utils import SYSTEM
if SYSTEM == "ipex":
import intel_extension_for_pytorch as ipex
# code based https://github.com/fpgaminer/GPTQ-triton # code based https://github.com/fpgaminer/GPTQ-triton
@ -268,21 +264,6 @@ class QuantLinear(nn.Module):
self.outfeatures = qweight.shape[1] self.outfeatures = qweight.shape[1]
self.infeatures = qweight.shape[0] * 32 // bits self.infeatures = qweight.shape[0] * 32 // bits
if SYSTEM == "ipex" and bits == 4:
self.woq_linear = (
ipex.llm.quantization.IPEXWeightOnlyQuantizedLinear.from_weight(
self.qweight,
self.scales,
self.qzeros,
self.infeatures,
self.outfeatures,
bias=self.bias,
group_size=self.groupsize,
g_idx=g_idx,
quant_method=ipex.llm.quantization.QuantMethod.GPTQ_GEMM,
dtype=ipex.llm.quantization.QuantDtype.INT4,
)
)
@classmethod @classmethod
def new(cls, bits, groupsize, infeatures, outfeatures, bias): def new(cls, bits, groupsize, infeatures, outfeatures, bias):
@ -365,17 +346,14 @@ class QuantLinear(nn.Module):
def forward(self, x): def forward(self, x):
out_shape = x.shape[:-1] + (self.outfeatures,) out_shape = x.shape[:-1] + (self.outfeatures,)
if SYSTEM == "ipex" and self.bits == 4: out = QuantLinearFunction.apply(
out = self.woq_linear(x.reshape(-1, x.shape[-1])) x.reshape(-1, x.shape[-1]),
else: self.qweight,
out = QuantLinearFunction.apply( self.scales,
x.reshape(-1, x.shape[-1]), self.qzeros,
self.qweight, self.g_idx,
self.scales, self.bits,
self.qzeros, self.maxq,
self.g_idx, )
self.bits,
self.maxq,
)
out = out + self.bias if self.bias is not None else out out = out + self.bias if self.bias is not None else out
return out.reshape(out_shape) return out.reshape(out_shape)

View File

@ -12,7 +12,12 @@ from huggingface_hub import HfApi
from accelerate import init_empty_weights from accelerate import init_empty_weights
from text_generation_server.utils import initialize_torch_distributed, Weights from text_generation_server.utils import initialize_torch_distributed, Weights
from text_generation_server.utils.hub import weight_files from text_generation_server.utils.hub import weight_files
from text_generation_server.layers.gptq.quant_linear import QuantLinear from text_generation_server.utils.import_utils import SYSTEM
if SYSTEM == "ipex":
from text_generation_server.layers.gptq.ipex import QuantLinear
else:
from text_generation_server.layers.gptq.quant_linear import QuantLinear
from loguru import logger from loguru import logger
from typing import Optional from typing import Optional
from text_generation_server.layers.gptq.utils import torch_snr_error from text_generation_server.layers.gptq.utils import torch_snr_error

View File

@ -385,8 +385,11 @@ def get_model(
if dtype is None: if dtype is None:
if quantize in ["awq", "exl2", "gptq", "marlin"]: if quantize in ["awq", "exl2", "gptq", "marlin"]:
# These quantizers only work with float16 params. if SYSTEM == "ipex" and not hasattr(torch, "xpu"):
dtype = torch.float16 dtype = torch.bfloat16
else:
# These quantizers only work with float16 params.
dtype = torch.float16
elif quantize == "fp8": elif quantize == "fp8":
from text_generation_server.layers.fp8 import FBGEMM_DYN_AVAILABLE from text_generation_server.layers.fp8 import FBGEMM_DYN_AVAILABLE

View File

@ -951,13 +951,6 @@ class FlashCausalLM(Model):
else: else:
device = torch.device("cpu") device = torch.device("cpu")
dtype = torch.bfloat16 if dtype is None else dtype dtype = torch.bfloat16 if dtype is None else dtype
if (
quantize in ["awq", "exl2", "gptq", "marlin"]
and dtype == torch.float16
):
# Float16 doesn't exist on target.
dtype = torch.bfloat16
kv_cache_dtype = torch.bfloat16
init_cpu_threads_env(rank_id=rank, world_size=world_size) init_cpu_threads_env(rank_id=rank, world_size=world_size)
else: else:
raise NotImplementedError(f"{model_class} is only available on GPU") raise NotImplementedError(f"{model_class} is only available on GPU")