From 67d687609ba0a32e04f2474f05e6255a7c42ad4b Mon Sep 17 00:00:00 2001
From: Felix Marty <9808326+fxmarty@users.noreply.github.com>
Date: Wed, 12 Jul 2023 16:16:58 +0000
Subject: [PATCH] cleanup

---
 .../custom_modeling/flash_santacoder_modeling.py | 15 +++------------
 .../utils/gptq/quant_linear.py                   |  2 --
 server/text_generation_server/utils/layers.py    |  2 +-
 3 files changed, 4 insertions(+), 15 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
index 3e534012..d49254e1 100644
--- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -21,17 +21,13 @@ from text_generation_server.utils.layers import (
     get_linear,
 )
 
-from text_generation_server.utils.gptq.quant_linear import Ex4bitLinear
-from custom_kernels.exllama import prepare_buffers, set_tuning_params
-
 def load_multi_mqa(
     config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
 ):
     if config.quantize == "gptq":
-        layer = _load_multi_mqa_gptq(
+        return _load_multi_mqa_gptq(
             config, prefix, weights, bias, head_size, num_heads, hidden_size
         )
-        return layer
     else:
         return _load_multi_mqa(
             config, prefix, weights, bias, head_size, num_heads, hidden_size
@@ -190,21 +186,18 @@ def load_col(config, prefix: str, weights, bias: bool):
 
 
 def load_row(config, prefix: str, weights, bias: bool):
-    quantize = config.quantize
-
     if config.transpose:
         weight = weights.get_sharded(f"{prefix}.weight", dim=0).T
     else:
-        weight = weights.get_multi_weights_row(prefix, quantize=quantize)
+        weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
 
     if bias and weights.process_group.rank() == 0:
         # Rank is only on the first rank process
         bias = weights.get_tensor(f"{prefix}.bias")
     else:
         bias = None
-
     return TensorParallelRowLinear(
-        get_linear(weight, bias, quantize), process_group=weights.process_group
+        get_linear(weight, bias, config.quantize), process_group=weights.process_group
     )
 
 
@@ -272,8 +265,6 @@ class FlashMQAttention(torch.nn.Module):
         # Expand from 1 to num_heads
         key_value = key_value.expand(-1, 2, self.num_heads, self.head_size)
 
-        # output
-        attn_output = torch.empty_like(query)
         # flash attention
         flash_attn_cuda.fwd(
             query,
diff --git a/server/text_generation_server/utils/gptq/quant_linear.py b/server/text_generation_server/utils/gptq/quant_linear.py
index aa831ea2..6b6d5cb4 100644
--- a/server/text_generation_server/utils/gptq/quant_linear.py
+++ b/server/text_generation_server/utils/gptq/quant_linear.py
@@ -6,8 +6,6 @@ from torch.cuda.amp import custom_bwd, custom_fwd
 
 import torch
 
-from loguru import logger
-
 try:
     from custom_kernels.exllama import make_q4, q4_matmul
 except Exception as e:
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 122fb884..63b9a406 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -18,7 +18,7 @@ from accelerate import init_empty_weights
 from text_generation_server.utils.gptq.quant_linear import QuantLinear, Ex4bitLinear
 
 from typing import Optional
-from loguru import logger
+
 # Monkey patching
 @classmethod
 def load_layer_norm(cls, prefix, weights, eps):