From 9f42e5f6fdb33e493e03481c688277cedf21029e Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Thu, 21 Dec 2023 15:05:05 +0000
Subject: [PATCH] Preventing using exllama when act_order=True

---
 server/text_generation_server/utils/layers.py |  6 ++--
 server/text_generation_server/utils/log.py    |  6 ++++
 .../text_generation_server/utils/weights.py   | 32 +++++++++++++------
 3 files changed, 33 insertions(+), 11 deletions(-)
 create mode 100644 server/text_generation_server/utils/log.py

diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 011a9382..6648b55a 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -19,6 +19,7 @@ from accelerate import init_empty_weights
 
 from text_generation_server.utils.gptq.quant_linear import QuantLinear
 from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
+from text_generation_server.utils.log import log_once
 
 HAS_AWQ = True
 try:
@@ -35,10 +36,11 @@ HAS_EXLLAMA = False
 CAN_EXLLAMA = major >= 8
 V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
 if V2 and int(os.getenv("WORLD_SIZE", "1")) > 1:
-    logger.warning(
+    V2 = False
+    log_once(
+        logger.warning,
         "Disabling exllama v2 and using v1 instead because there are issues when sharding"
     )
-    V2 = False
 
 if os.getenv("DISABLE_EXLLAMA") == "True":
     HAS_EXLLAMA = False
diff --git a/server/text_generation_server/utils/log.py b/server/text_generation_server/utils/log.py
new file mode 100644
index 00000000..d831fa76
--- /dev/null
+++ b/server/text_generation_server/utils/log.py
@@ -0,0 +1,6 @@
+from functools import lru_cache
+
+
+@lru_cache(10)
+def log_once(log, msg:str):
+    log(msg)
diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index dbdab0f5..ee1899ab 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -6,6 +6,7 @@ import torch
 from loguru import logger
 from huggingface_hub import hf_hub_download
 import json
+from text_generation_server.utils.log import log_once
 
 
 class Weights:
@@ -161,7 +162,7 @@
             else:
                 g_idx = None
 
-            bits, groupsize = self._get_gptq_params()
+            bits, groupsize, _ = self._get_gptq_params()
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
         else:
             slice_ = self._get_slice(f"{prefix}.weight")
@@ -211,10 +212,10 @@
             else:
                 g_idx = None
 
-            bits, groupsize = self._get_gptq_params()
+            bits, groupsize, desc_act = self._get_gptq_params()
             from text_generation_server.utils.layers import HAS_EXLLAMA
 
-            use_exllama = bits == 4 and HAS_EXLLAMA and quantize == "gptq"
+            use_exllama = bits == 4 and HAS_EXLLAMA and quantize == "gptq" and not desc_act
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
         else:
             w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
@@ -240,11 +241,15 @@
     def get_multi_weights_row(self, prefix: str, quantize: str):
         if quantize == "gptq":
             use_exllama = True
-            bits, groupsize = self._get_gptq_params()
+            bits, groupsize, desc_act = self._get_gptq_params()
 
             if bits != 4:
                 use_exllama = False
 
+            if desc_act:
+                log_once(logger.warning, "Disabling exllama because desc_act=True")
+                use_exllama = False
+
             if self.process_group.size() > 1:
                 g_idx = self.get_tensor(f"{prefix}.g_idx")
                 if g_idx is not None:
@@ -274,12 +279,16 @@
             if use_exllama:
                 if not HAS_EXLLAMA:
                     if CAN_EXLLAMA:
-                        logger.warning(
+                        log_once(
+                            logger.warning,
                             "Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True"
                         )
                     use_exllama = False
                 else:
-                    logger.info(f"Using exllama kernels v{HAS_EXLLAMA}")
+                    log_once(
+                        logger.info,
+                        f"Using exllama kernels v{HAS_EXLLAMA}"
+                    )
 
             g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
 
@@ -295,7 +304,7 @@
 
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
         elif quantize == "awq":
-            bits, groupsize = self._get_gptq_params()
+            bits, groupsize, _ = self._get_gptq_params()
 
             try:
                 qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
@@ -314,18 +323,20 @@
             weight = self.get_sharded(f"{prefix}.weight", dim=1)
         return weight
 
-    def _get_gptq_params(self) -> Tuple[int, int]:
+    def _get_gptq_params(self) -> Tuple[int, int, int]:
         try:
             bits = self.get_tensor("gptq_bits").item()
             groupsize = self.get_tensor("gptq_groupsize").item()
+            desc_act = False
         except (SafetensorError, RuntimeError) as e:
             try:
                 bits = self.gptq_bits
                 groupsize = self.gptq_groupsize
+                desc_act = getattr(self, "gptq_desc_act", False)
             except Exception:
                 raise e
 
-        return bits, groupsize
+        return bits, groupsize, desc_act
 
     def _set_gptq_params(self, model_id, revision):
         filename = "config.json"
@@ -340,6 +351,7 @@
                 data = json.load(f)
             self.gptq_bits = data["quantization_config"]["bits"]
             self.gptq_groupsize = data["quantization_config"]["group_size"]
+            self.gptq_desc_act = data["quantization_config"]["desc_act"]
         except Exception:
             filename = "quantize_config.json"
             try:
@@ -353,6 +365,7 @@
                     data = json.load(f)
                 self.gptq_bits = data["bits"]
                 self.gptq_groupsize = data["group_size"]
+                self.gptq_desc_act = data["desc_act"]
             except Exception:
                 filename = "quant_config.json"
                 try:
@@ -366,5 +379,6 @@
                         data = json.load(f)
                     self.gptq_bits = data["w_bit"]
                     self.gptq_groupsize = data["q_group_size"]
+                    self.gptq_desc_act = data["desc_act"]
                 except Exception:
                     pass
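
For readers skimming the patch, the behavioural core is small: log_once deduplicates messages through functools.lru_cache, and get_multi_weights_row now receives desc_act from _get_gptq_params() and refuses the exllama kernels when it is set. Below is a minimal, standalone sketch of that behaviour, not the server code itself: it uses print in place of logger.warning and made-up values for bits, groupsize and desc_act.

    from functools import lru_cache


    @lru_cache(10)
    def log_once(log, msg: str):
        # Cached on the (log, msg) pair, so a repeated message is emitted only once.
        log(msg)


    # Made-up stand-ins for what _get_gptq_params() would return for an
    # act_order (desc_act=True) GPTQ checkpoint.
    bits, groupsize, desc_act = 4, 128, True

    use_exllama = True
    if bits != 4:
        use_exllama = False

    if desc_act:
        # In the server this runs once per quantized layer; the cache collapses
        # the repeats into a single warning line.
        log_once(print, "Disabling exllama because desc_act=True")
        use_exllama = False

    assert use_exllama is False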