From ddc0dd57f79fdc68a726030aa8eeb3ade4717742 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Tue, 7 May 2024 10:08:50 +0000
Subject: [PATCH] Fixes.

---
 .../text_generation_server/layers/__init__.py | 7 +-
 .../layers/awq/conversion_utils.py | 97 +++++++++++++++++++
 .../layers/awq/quantize/qmodule.py | 50 ++++++++++
 .../layers/gptq/__init__.py | 2 +-
 .../layers/gptq/exllamav2.py | 2 +
 .../layers/layernorm.py | 46 +--------
 server/text_generation_server/layers/linear.py | 7 +-
 server/text_generation_server/layers/medusa.py | 4 +
 server/text_generation_server/layers/rotary.py | 1 +
 .../text_generation_server/models/__init__.py | 27 ++++++
 .../custom_modeling/flash_dbrx_modeling.py | 2 +-
 .../custom_modeling/flash_gemma_modeling.py | 4 +-
 .../custom_modeling/flash_llama_modeling.py | 4 +-
 .../custom_modeling/flash_neox_modeling.py | 8 +-
 .../custom_modeling/flash_phi_modeling.py | 6 +-
 .../custom_modeling/flash_qwen2_modeling.py | 4 +-
 .../custom_modeling/flash_rw_modeling.py | 8 +-
 .../flash_santacoder_modeling.py | 6 +-
 .../flash_starcoder2_modeling.py | 8 +-
 .../custom_modeling/idefics_modeling.py | 2 +-
 server/text_generation_server/server.py | 2 +-
 .../text_generation_server/utils/weights.py | 10 +-
 22 files changed, 234 insertions(+), 73 deletions(-)
 create mode 100644 server/text_generation_server/layers/awq/conversion_utils.py
 create mode 100644 server/text_generation_server/layers/awq/quantize/qmodule.py

diff --git a/server/text_generation_server/layers/__init__.py b/server/text_generation_server/layers/__init__.py
index c727c1c7..4906aa2a 100644
--- a/server/text_generation_server/layers/__init__.py
+++ b/server/text_generation_server/layers/__init__.py
@@ -8,7 +8,6 @@ from text_generation_server.layers.linear import (
     get_linear,
     FastLinear,
 )
-from text_generation_server.layers.layernorm import (
-    get_linear,
-    FastLinear,
-)
+
+# Just to add the `load` methods.
+from text_generation_server.layers.layernorm import load_layer_norm
diff --git a/server/text_generation_server/layers/awq/conversion_utils.py b/server/text_generation_server/layers/awq/conversion_utils.py
new file mode 100644
index 00000000..b19eafbb
--- /dev/null
+++ b/server/text_generation_server/layers/awq/conversion_utils.py
@@ -0,0 +1,97 @@
+import torch
+from typing import List
+
+
+AWQ_PACK_ORDER = [0, 2, 4, 6, 1, 3, 5, 7]
+REVERSE_AWQ_PACK_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
+
+
+def pack(imatrix: torch.Tensor, direction: str = "column"):
+    """
+    Packs a 4-bit integer matrix into a packed 32-bit integer matrix.
+    Args:
+        imatrix (torch.Tensor): matrix of integers
+        direction (str): direction of packing, either "column" or "row"
+    Returns:
+        qmatrix (torch.Tensor): packed matrix of integers
+    """
+    shifts = torch.arange(0, 32, 4, dtype=torch.int32, device=imatrix.device)
+
+    imatrix = imatrix.to(torch.int8) & 0x0F  # eventually correct overflow
+
+    if direction == "column":
+        imatrix = imatrix.view(-1, imatrix.shape[1] // (32 // 4), (32 // 4))
+        qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, None, :]).sum(dim=-1)
+
+    elif direction == "row":
+        imatrix = imatrix.view(imatrix.shape[0] // (32 // 4), (32 // 4), -1)
+        qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, :, None]).sum(dim=1)
+
+    qmatrix = qmatrix.to(torch.int32)
+
+    return qmatrix
+
+
+def unpack(qmatrix: torch.Tensor, direction: str = "column"):
+    """
+    Unpacks a 32-bit packed integer matrix into a 4-bit integer matrix.
+    Args:
+        qmatrix (torch.Tensor): matrix of packed integers
+        direction (str): direction of unpacking, either "column" or "row"
+    Returns:
+        imatrix (torch.Tensor): matrix of integers
+    """
+    shifts = torch.arange(0, 32, 4, device=qmatrix.device)
+
+    if direction == "column":
+        imatrix = torch.bitwise_right_shift(
+            qmatrix[:, :, None], shifts[None, None, :]
+        ).view(qmatrix.shape[0], -1)
+
+    elif direction == "row":
+        imatrix = torch.bitwise_right_shift(
+            qmatrix[:, None, :], shifts[None, :, None]
+        ).view(-1, qmatrix.shape[-1])
+
+    imatrix = imatrix.to(torch.int8) & 0x0F  # eventually correct overflow
+
+    return imatrix
+
+
+def apply_order(
+    imatrix: torch.Tensor,
+    direction: str = "column",
+    order: List[int] = AWQ_PACK_ORDER,
+):
+    """
+    Applies the order to a 4-bit integer matrix.
+    Args:
+        imatrix (torch.Tensor): matrix of integers
+        direction (str): direction of applying order, either "column" or "row"
+        order (List[int]): order to apply, default is AWQ_PACK_ORDER
+    Returns:
+        imatrix (torch.Tensor): matrix of integers
+    """
+    if direction == "column":
+        imatrix = imatrix.view(-1, (32 // 4))[:, order].view(imatrix.shape)
+    elif direction == "row":
+        imatrix = imatrix.view((32 // 4), -1)[order, :].view(imatrix.shape)
+
+    return imatrix
+
+
+def fast_awq_to_gptq(qweight, qzeros):
+    # awq uses column packing for both weights and zeros
+    izeros = unpack(qzeros, direction="column")
+    iweights = unpack(qweight, direction="column")
+
+    # Reverse the order of the iweight and izeros tensors
+    izeros = apply_order(izeros, direction="column", order=REVERSE_AWQ_PACK_ORDER)
+    iweights = apply_order(iweights, direction="column", order=REVERSE_AWQ_PACK_ORDER)
+    # Subtract 1 from the izeros tensor (gptq adds 1 to the zeros)
+    izeros = izeros - 1
+    # exllama uses row packing for weights and column packing for zeros
+    qzeros = pack(izeros, direction="column")
+    qweight = pack(iweights, direction="row")
+
+    return qweight, qzeros
diff --git a/server/text_generation_server/layers/awq/quantize/qmodule.py b/server/text_generation_server/layers/awq/quantize/qmodule.py
new file mode 100644
index 00000000..ca8caf50
--- /dev/null
+++ b/server/text_generation_server/layers/awq/quantize/qmodule.py
@@ -0,0 +1,50 @@
+# Copied logic from https://github.com/mit-han-lab/llm-awq/blob/f084f40bd996f3cf3a0633c1ad7d9d476c318aaa/awq/quantize/qmodule.py
+
+import math
+import torch
+import torch.nn as nn
+import awq_inference_engine  # with CUDA kernels
+
+
+# class ScaledActivation(nn.Module):
+#     def __init__(self, module, scales):
+#         super().__init__()
+#         self.act = module
+#         self.scales = nn.Parameter(scales.data)
+#
+#     def forward(self, x):
+#         return self.act(x) / self.scales.view(1, 1, -1).to(x.device)
+
+
+class WQLinear(nn.Module):
+    def __init__(self, w_bit, group_size, qweight, qzeros, scales, bias):
+        super().__init__()
+
+        if w_bit not in [4]:
+            raise NotImplementedError("Only 4-bit are supported for now.")
+
+        self.in_features = qweight.shape[0]
+        self.out_features = qweight.shape[1] * 32 // w_bit
+
+        self.w_bit = w_bit
+        self.group_size = group_size if group_size != -1 else self.in_features
+        # quick sanity check (make sure alignment)
+        assert self.in_features % self.group_size == 0
+        assert self.out_features % (32 // self.w_bit) == 0
+
+        self.qweight = qweight
+        self.qzeros = qzeros
+        self.scales = scales
+        if bias:
+            self.bias = bias
+        else:
+            self.bias = None
+
+    @torch.no_grad()
+    def forward(self, x):
+        out_shape = x.shape[:-1] + (self.out_features,)
+        out = awq_inference_engine.gemm_forward_cuda(
+            x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8
+        )
+        out = out + self.bias if self.bias is not None else out
+        return out.reshape(out_shape)
diff --git a/server/text_generation_server/layers/gptq/__init__.py b/server/text_generation_server/layers/gptq/__init__.py
index d3b10665..1c46f493 100644
--- a/server/text_generation_server/layers/gptq/__init__.py
+++ b/server/text_generation_server/layers/gptq/__init__.py
@@ -36,4 +36,4 @@ elif CAN_EXLLAMA:
     except ImportError:
         pass
 
-from text_generation_server.layers.gptq.gptq.quant_linear import QuantLinear
+from text_generation_server.layers.gptq.quant_linear import QuantLinear
diff --git a/server/text_generation_server/layers/gptq/exllamav2.py b/server/text_generation_server/layers/gptq/exllamav2.py
index 80836a95..321ced97 100644
--- a/server/text_generation_server/layers/gptq/exllamav2.py
+++ b/server/text_generation_server/layers/gptq/exllamav2.py
@@ -119,6 +119,8 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
             none_tensor,
             temp_dq,
         )
+    else:
+        raise RuntimeError("Cannot create handle")
 
 
 DEVICE = None
diff --git a/server/text_generation_server/layers/layernorm.py b/server/text_generation_server/layers/layernorm.py
index 8865be25..15d24e80 100644
--- a/server/text_generation_server/layers/layernorm.py
+++ b/server/text_generation_server/layers/layernorm.py
@@ -80,7 +80,7 @@ elif SYSTEM == "rocm":
                 hidden_states += residual
             residual = hidden_states
 
-            return super(FastLayerNorm, self).forward(hidden_states), residual
+            return super().forward(hidden_states), residual
 
 elif SYSTEM == "xpu":
     import intel_extension_for_pytorch as ipex
@@ -96,50 +96,6 @@ elif SYSTEM == "xpu":
             return out, res_out
 
 
-class FastLayerNorm(nn.LayerNorm):
-    def forward(self, hidden_states, residual=None):
-        if SYSTEM == "xpu":
-            res_out = hidden_states
-            out = ipex.llm.functional.add_layer_norm(
-                residual, hidden_states, self.weight, self.bias, self.eps, True
-            )
-            if residual is not None:
-                res_out = residual
-            return out, res_out
-        elif hidden_states.shape[-1] > 8192 or SYSTEM == "rocm":
-            if residual is not None:
-                hidden_states += residual
-            residual = hidden_states
-
-            return super(FastLayerNorm, self).forward(hidden_states), residual
-        else:
-            (
-                normed_hidden_states,
-                residual,
-                *rest,
-            ) = dropout_layer_norm.dropout_add_ln_fwd(
-                hidden_states,
-                residual,
-                self.weight,
-                self.bias,
-                None,
-                None,
-                None,
-                None,
-                0.0,
-                self.eps,
-                1.0,
-                0,
-                None,
-                False,
-                False,
-            )
-            if residual is None:
-                residual = hidden_states
-
-            return normed_hidden_states, residual
-
-
 class FastRMSNorm(nn.Module):
     def __init__(self, weight: torch.Tensor, eps: float):
         super().__init__()
diff --git a/server/text_generation_server/layers/linear.py b/server/text_generation_server/layers/linear.py
index 822d43f3..d137a500 100644
--- a/server/text_generation_server/layers/linear.py
+++ b/server/text_generation_server/layers/linear.py
@@ -1,4 +1,5 @@
 import torch
+from torch.nn import functional as F
 
 from text_generation_server.utils.import_utils import SYSTEM
 
@@ -97,7 +98,7 @@ def get_linear(weight, bias, quantize):
 
         if use_exllama:
             try:
-                from text_generation_server.utils.gptq.quant_linear import (
+                from text_generation_server.layers.gptq import (
                     ExllamaQuantLinear,
                 )
             except ImportError:
@@ -109,7 +110,7 @@
                 qweight, qzeros, scales, g_idx, bias, bits, groupsize
             )
         else:
-            from text_generation_server.utils.gptq.quant_linear import QuantLinear
+            from text_generation_server.layers.gptq.quant_linear import QuantLinear
             linear = QuantLinear(
                 qweight,
                 qzeros,
@@ -133,7 +134,7 @@
                 "to use Exllama/GPTQ kernels for AWQ inference."
             )
         try:
-            from text_generation_server.utils.awq.quantize.qmodule import WQLinear
+            from text_generation_server.layers.awq.quantize.qmodule import WQLinear
 
             linear = WQLinear(
                 w_bit=bits,
diff --git a/server/text_generation_server/layers/medusa.py b/server/text_generation_server/layers/medusa.py
index 3b0e6b57..4ac86978 100644
--- a/server/text_generation_server/layers/medusa.py
+++ b/server/text_generation_server/layers/medusa.py
@@ -3,6 +3,10 @@ from torch import nn
 from typing import Tuple, Optional
 from text_generation_server.utils.speculate import get_speculate
 from text_generation_server.layers.linear import FastLinear
+from text_generation_server.layers.tensor_parallel import (
+    TensorParallelHead,
+    TensorParallelColumnLinear,
+)
 
 
 class ResBlock(torch.nn.Module):
diff --git a/server/text_generation_server/layers/rotary.py b/server/text_generation_server/layers/rotary.py
index 5fc8d87c..503dd554 100644
--- a/server/text_generation_server/layers/rotary.py
+++ b/server/text_generation_server/layers/rotary.py
@@ -1,3 +1,4 @@
+import os
 import torch
 from torch import nn
 
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index b52765d7..3e6c88e8 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -48,6 +48,33 @@ FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
 
 FLASH_ATTENTION = True
 
+from text_generation_server.models.flash_rw import FlashRWSharded
+from text_generation_server.models.flash_neox import FlashNeoXSharded
+from text_generation_server.models.flash_llama import (
+    FlashLlama,
+)
+from text_generation_server.models.flash_qwen2 import (
+    FlashQwen2,
+)
+from text_generation_server.models.flash_cohere import (
+    FlashCohere,
+)
+from text_generation_server.models.flash_gemma import (
+    FlashGemma,
+)
+from text_generation_server.models.flash_santacoder import (
+    FlashSantacoderSharded,
+)
+from text_generation_server.models.idefics import IDEFICSSharded
+from text_generation_server.models.llava_next import LlavaNext
+from text_generation_server.models.idefics2 import Idefics2
+from text_generation_server.models.flash_mistral import FlashMistral
+from text_generation_server.models.flash_mixtral import FlashMixtral
+from text_generation_server.models.flash_phi import FlashPhi
+from text_generation_server.models.flash_starcoder2 import FlashStarcoder2
+from text_generation_server.models.flash_dbrx import FlashDbrx
+from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA
+
 try:
     from text_generation_server.models.flash_rw import FlashRWSharded
     from text_generation_server.models.flash_neox import FlashNeoXSharded
diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
index 29e83876..9d652b67 100644
--- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
@@ -241,7 +241,7 @@ def _load_gqa(config, prefix: str, weights):
             log_once(
                 logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
             )
-            from text_generation_server.layers.awq import (
+            from text_generation_server.layers.awq.conversion_utils import (
                 fast_awq_to_gptq,
             )
 
diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
index ba0ee621..43b90bdd 100644
--- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
@@ -31,9 +31,11 @@ from text_generation_server.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
-    PositionRotaryEmbedding,
     SpeculativeHead,
     get_linear,
+)
+from text_generation_server.layers.rotary import PositionRotaryEmbedding
+from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
 
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index ca0deba8..a7969494 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -31,9 +31,11 @@ from text_generation_server.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
-    PositionRotaryEmbedding,
     SpeculativeHead,
     get_linear,
+)
+from text_generation_server.layers.rotary import PositionRotaryEmbedding
+from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
 
diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
index 745315a8..d45cab2e 100644
--- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
@@ -34,10 +34,14 @@ from text_generation_server.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     SpeculativeHead,
-    FastLayerNorm,
-    PositionRotaryEmbedding,
     get_linear,
 )
+from text_generation_server.layers.layernorm import (
+    FastLayerNorm,
+)
+from text_generation_server.layers.rotary import (
+    PositionRotaryEmbedding,
+)
 
 
 def load_row(config, prefix: str, weights, bias: bool):
diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
index 6bc30b09..f2efb538 100644
--- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
@@ -11,11 +11,15 @@ from text_generation_server.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
-    PositionRotaryEmbedding,
     SpeculativeHead,
     get_linear,
+)
+from text_generation_server.layers.layernorm import (
     FastLayerNorm,
 )
+from text_generation_server.layers.rotary import (
+    PositionRotaryEmbedding,
+)
 
 
 class PhiConfig(PretrainedConfig):
diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
index 7f2998a4..3a6d2db5 100644
--- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
@@ -10,9 +10,11 @@ from text_generation_server.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
-    PositionRotaryEmbedding,
     SpeculativeHead,
     get_linear,
+)
+from text_generation_server.layers.rotary import PositionRotaryEmbedding
+from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
 
diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
index ed68d6c6..52ea3ae1 100644
--- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -13,10 +13,14 @@ from text_generation_server.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     SpeculativeHead,
-    FastLayerNorm,
-    PositionRotaryEmbedding,
     get_linear,
 )
+from text_generation_server.layers.layernorm import (
+    FastLayerNorm,
+)
+from text_generation_server.layers.rotary import (
+    PositionRotaryEmbedding,
+)
 
 
 def load_row(config, prefix: str, weights, bias: bool):
diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
index c9cb3950..d2f6d9af 100644
--- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -11,9 +11,11 @@ from text_generation_server.layers import (
     TensorParallelColumnLinear,
     SpeculativeHead,
     TensorParallelEmbedding,
-    FastLayerNorm,
     get_linear,
 )
+from text_generation_server.layers.layernorm import (
+    FastLayerNorm,
+)
 
 
 def load_multi_mqa(
@@ -80,7 +82,7 @@ def _load_multi_mqa_gptq(
         g_idx = g_idx.to(device=weights.device)
     elif quant_method == "awq":
         g_idx = None
-        from text_generation_server.layers.awq import (
+        from text_generation_server.layers.awq.conversion_utils import (
            fast_awq_to_gptq,
         )
 
diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
index daebb23d..3e2ce4f9 100644
--- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
@@ -31,11 +31,15 @@ from text_generation_server.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
-    PositionRotaryEmbedding,
     SpeculativeHead,
     get_linear,
-    FastRMSNorm,
+)
+from text_generation_server.layers.layernorm import (
     FastLayerNorm,
+    FastRMSNorm,
+)
+from text_generation_server.layers.rotary import (
+    PositionRotaryEmbedding,
 )
 
 
diff --git a/server/text_generation_server/models/custom_modeling/idefics_modeling.py b/server/text_generation_server/models/custom_modeling/idefics_modeling.py
index b0a7f04f..ec3f900b 100644
--- a/server/text_generation_server/models/custom_modeling/idefics_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/idefics_modeling.py
@@ -52,9 +52,9 @@ from text_generation_server.layers import (
     TensorParallelEmbedding,
     TensorParallelRowLinear,
     SpeculativeHead,
-    PositionRotaryEmbedding,
     FastLinear,
 )
+from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.utils.import_utils import SYSTEM
 
 if SYSTEM == "cuda":
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index 566703c8..9d0571a6 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -85,7 +85,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             # When using GPTQ, Exllama kernels need some global kernels
             # For which we have the finale shapes only after the model has loaded
             # This will allocate those buffers.
-            from text_generation_server.layers import (
+            from text_generation_server.layers.gptq import (
                 create_exllama_buffers,
                 set_device,
             )
diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index 51845a73..6af7d3fb 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -171,7 +171,7 @@ class Weights:
                 log_once(
                     logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
                 )
-                from text_generation_server.layers.awq import (
+                from text_generation_server.layers.awq.conversion_utils import (
                    fast_awq_to_gptq,
                 )
 
@@ -227,7 +227,7 @@ class Weights:
 
         bits, groupsize, desc_act, quant_method = self._get_gptq_params()
 
-        from text_generation_server.layers import HAS_EXLLAMA
+        from text_generation_server.layers.gptq import HAS_EXLLAMA
 
         use_exllama = (
             bits == 4 and HAS_EXLLAMA and quantize == "gptq" and not desc_act
@@ -242,7 +242,7 @@ class Weights:
                 log_once(
                     logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
                 )
-                from text_generation_server.layers.awq import (
+                from text_generation_server.layers.awq.conversion_utils import (
                    fast_awq_to_gptq,
                 )
 
@@ -321,7 +321,7 @@ class Weights:
                 # it would require to reorder input activations that are split unto several GPUs
                 use_exllama = False
 
-        from text_generation_server.layers import HAS_EXLLAMA, CAN_EXLLAMA
+        from text_generation_server.layers.gptq import HAS_EXLLAMA, CAN_EXLLAMA
 
         if use_exllama:
             if not HAS_EXLLAMA:
@@ -348,7 +348,7 @@ class Weights:
                 log_once(
                     logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
                 )
-                from text_generation_server.layers.awq import (
+                from text_generation_server.layers.awq.conversion_utils import (
                    fast_awq_to_gptq,
                 )
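
A quick CPU-only sanity check of the packing helpers added in conversion_utils.py above. This is a minimal sketch with illustrative tensor shapes (128 input features, 64 output features), not an excerpt from the patch; only torch is needed, no AWQ CUDA kernels.

import torch
from text_generation_server.layers.awq.conversion_utils import (
    fast_awq_to_gptq,
    pack,
    unpack,
)

# Illustrative 4-bit values in [0, 15]; 8 nibbles fit in each packed int32.
ints = torch.randint(0, 16, (128, 64), dtype=torch.int32)
packed = pack(ints, direction="column")  # -> shape (128, 8), dtype int32
# unpack is the inverse of pack along the same direction.
assert torch.equal(unpack(packed, direction="column").to(torch.int32), ints)

# AWQ column-packed qweight/qzeros converted to the GPTQ/exllama layout:
# qweight becomes row-packed, qzeros stays column-packed, zero points shift by -1.
qweight = pack(torch.randint(0, 16, (128, 64), dtype=torch.int32), direction="column")
qzeros = pack(torch.randint(1, 16, (16, 64), dtype=torch.int32), direction="column")
gptq_qweight, gptq_qzeros = fast_awq_to_gptq(qweight, qzeros)
print(gptq_qweight.shape, gptq_qzeros.shape)  # torch.Size([16, 64]) torch.Size([16, 8])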