diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index 08d672f3..d5adbd32 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -63,27 +63,20 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
 
     async def Warmup(self, request, context):
-        if self.quantize in ["gptq", "awq"]:
-            has_exllama_layers = False
-            for _, module in self.model.model.named_modules():
-                if hasattr(module, "QUANT_TYPE"):
-                    has_exllama_layers = True
-                    break
+        if self.quantize == "gptq":
+            try:
+                # When using GPTQ, Exllama kernels need some global kernels
+                # For which we have the finale shapes only after the model has loaded
+                # This will allocate those buffers.
+                from text_generation_server.utils.layers import (
+                    create_exllama_buffers,
+                    set_device,
+                )
 
-            if has_exllama_layers:
-                try:
-                    # When using GPTQ or AWQ, Exllama kernels need some global kernels
-                    # For which we have the finale shapes only after the model has loaded
-                    # This will allocate those buffers.
-                    from text_generation_server.utils.layers import (
-                        create_exllama_buffers,
-                        set_device,
-                    )
-
-                    set_device(self.model.device)
-                    create_exllama_buffers(request.max_prefill_tokens)
-                except ImportError:
-                    pass
+                set_device(self.model.device)
+                create_exllama_buffers(request.max_prefill_tokens)
+            except ImportError:
+                pass
 
         if (
             self.model.batch_type == IdeficsCausalLMBatch
diff --git a/server/text_generation_server/utils/awq/pack_utils.py b/server/text_generation_server/utils/awq/pack_utils.py
deleted file mode 100644
index d144b3cd..00000000
--- a/server/text_generation_server/utils/awq/pack_utils.py
+++ /dev/null
@@ -1,98 +0,0 @@
-import torch
-from typing import List
-
-
-AWQ_PACK_ORDER = [0, 2, 4, 6, 1, 3, 5, 7]
-REVERSE_AWQ_PACK_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
-
-
-def pack(imatrix: torch.Tensor, direction: str = "column"):
-    """
-    Packs a 4-bit integer matrix into a packed 32-bit integer matrix.
-    Args:
-        imatrix (torch.Tensor): matrix of integers
-        direction (str): direction of packing, either "column" or "row"
-    Returns:
-        qmatrix (torch.Tensor): packed matrix of integers
-    """
-    imatrix = imatrix.to(torch.int8)
-    imatrix = torch.bitwise_and(imatrix, 0x0F)  # eventually correct overflow
-
-    shifts = torch.arange(0, 32, 4, device=imatrix.device)
-
-    if direction == "column":
-        imatrix = imatrix.view(-1, imatrix.shape[1] // (32 // 4), (32 // 4))
-        qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, None, :]).sum(dim=-1)
-
-    elif direction == "row":
-        imatrix = imatrix.view(imatrix.shape[0] // (32 // 4), (32 // 4), -1)
-        qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, :, None]).sum(dim=1)
-
-    qmatrix = qmatrix.to(torch.int32)
-
-    return qmatrix
-
-
-def unpack(qmatrix: torch.Tensor, direction: str = "column"):
-    """
-    Unpacks a 32-bit packed integer matrix into a 4-bit integer matrix.
-    Args:
-        qmatrix (torch.Tensor): matrix of packed integers
-        direction (str): direction of unpacking, either "column" or "row"
-    Returns:
-        imatrix (torch.Tensor): matrix of integers
-    """
-    shifts = torch.arange(0, 32, 4, device=qmatrix.device)
-
-    if direction == "column":
-        imatrix = torch.bitwise_right_shift(
-            qmatrix[:, :, None], shifts[None, None, :]
-        ).view(qmatrix.shape[0], -1)
-
-    elif direction == "row":
-        imatrix = torch.bitwise_right_shift(
-            qmatrix[:, None, :], shifts[None, :, None]
-        ).view(-1, qmatrix.shape[-1])
-
-    imatrix = imatrix.to(torch.int8) & 0x0F  # eventually correct overflow
-
-    return imatrix
-
-
-def apply_order(
-    imatrix: torch.Tensor,
-    direction: str = "column",
-    order: List[int] = AWQ_PACK_ORDER,
-):
-    """
-    Applies the order to a 4-bit integer matrix.
-    Args:
-        imatrix (torch.Tensor): matrix of integers
-        direction (str): direction of applying order, either "column" or "row"
-        order (List[int]): order to apply, default is AWQ_PACK_ORDER
-    Returns:
-        imatrix (torch.Tensor): matrix of integers
-    """
-    if direction == "column":
-        imatrix = imatrix.view(-1, (32 // 4))[:, order].view(imatrix.shape)
-    elif direction == "row":
-        imatrix = imatrix.view((32 // 4), -1)[order, :].view(imatrix.shape)
-
-    return imatrix
-
-
-def fast_awq_to_gptq(qweight, qzeros):
-    # awq uses column packing for both weights and zeros
-    izeros = unpack(qzeros, direction="column")
-    iweights = unpack(qweight, direction="column")
-
-    # Reverse the order of the iweight and izeros tensors
-    izeros = apply_order(izeros, direction="column", order=REVERSE_AWQ_PACK_ORDER)
-    iweights = apply_order(iweights, direction="column", order=REVERSE_AWQ_PACK_ORDER)
-    # Subtract 1 from the izeros tensor (gptq adds 1 to the zeros)
-    izeros = izeros - 1
-    # exllama uses row packing for weights and column packing for zeros
-    qzeros = pack(izeros, direction="column")
-    qweight = pack(iweights, direction="row")
-
-    return qweight, qzeros
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 7f20081d..b9b1dfac 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -25,8 +25,6 @@ HAS_AWQ = True
 try:
     from text_generation_server.utils.awq.quantize.qmodule import WQLinear
 except ImportError:
-    from text_generation_server.utils.awq.pack_utils import fast_awq_to_gptq
-
     HAS_AWQ = False
 
 try:
@@ -351,23 +349,19 @@ def get_linear(weight, bias, quantize):
             raise NotImplementedError(
                 f"The passed weight is not `awq` compatible, loader needs to be updated."
             )
-        if HAS_AWQ:
-            linear = WQLinear(
-                w_bit=bits,
-                group_size=groupsize,
-                qweight=qweight,
-                qzeros=qzeros,
-                scales=scales,
-                bias=bias is not None,
+        if IS_ROCM_SYSTEM:
+            raise NotImplementedError(
+                "AWQ GEMM kernel can't be used on ROCm systems, please use `--quantize gptq` instead "
+                "to use Exllama/GPTQ kernels for AWQ inference."
             )
-        elif HAS_EXLLAMA:
-            qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
-            linear = ExllamaQuantLinear(
-                qweight, qzeros, scales, None, bias, bits, groupsize
-            )
-        else:
-            qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
-            linear = QuantLinear(qweight, qzeros, scales, None, bias, bits, groupsize)
+        linear = WQLinear(
+            w_bit=bits,
+            group_size=groupsize,
+            qweight=qweight,
+            qzeros=qzeros,
+            scales=scales,
+            bias=bias is not None,
+        )
     else:
         raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
     return linear
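For reference, the deleted pack_utils.py helpers implement a nibble-level repacking: AWQ stores eight 4-bit values per int32 in column order, and fast_awq_to_gptq reorders and repacks them into the layout the GPTQ/exllama kernels expect. Below is a minimal, standalone round-trip sketch of just the column-packing scheme; the pack_columns/unpack_columns names are hypothetical and not part of the repository, and the real helpers additionally apply AWQ_PACK_ORDER and subtract 1 from the zero points.

import torch


def pack_columns(imatrix: torch.Tensor) -> torch.Tensor:
    # Pack eight 4-bit values per int32 along columns, least-significant nibble first.
    imatrix = imatrix.to(torch.int8) & 0x0F
    shifts = torch.arange(0, 32, 4, device=imatrix.device)
    grouped = imatrix.view(-1, imatrix.shape[1] // 8, 8)
    return torch.bitwise_left_shift(grouped, shifts[None, None, :]).sum(dim=-1).to(torch.int32)


def unpack_columns(qmatrix: torch.Tensor) -> torch.Tensor:
    # Inverse of pack_columns: expand every int32 back into eight nibbles.
    shifts = torch.arange(0, 32, 4, device=qmatrix.device)
    expanded = torch.bitwise_right_shift(qmatrix[:, :, None], shifts[None, None, :])
    return expanded.view(qmatrix.shape[0], -1).to(torch.int8) & 0x0F


# Round trip on a random 4-bit matrix whose column count is a multiple of 8.
x = torch.randint(0, 16, (4, 16), dtype=torch.int32)
assert torch.equal(unpack_columns(pack_columns(x)).to(torch.int32), x)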