Preventing the use of exllama when act_order=True

Nicolas Patry 2023-12-21 15:05:05 +00:00
parent 238cc311f1
commit 9f42e5f6fd
3 changed files with 33 additions and 11 deletions
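In short: GPTQ checkpoints quantized with act_order (stored as desc_act in the quantization config) no longer take the exllama kernel path and instead fall back to the regular QuantLinear path, and the warnings this produces are deduplicated through a new log_once helper. A minimal sketch of the gating the diff introduces below (hypothetical function name, not part of the patch):

# Sketch only: mirrors the use_exllama condition added in this commit.
# desc_act=True marks an act_order GPTQ checkpoint.
def should_use_exllama(bits: int, has_exllama: bool, quantize: str, desc_act: bool) -> bool:
    return bits == 4 and has_exllama and quantize == "gptq" and not desc_act

assert not should_use_exllama(4, True, "gptq", desc_act=True)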

View File

@@ -19,6 +19,7 @@ from accelerate import init_empty_weights
 from text_generation_server.utils.gptq.quant_linear import QuantLinear
 from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
+from text_generation_server.utils.log import log_once

 HAS_AWQ = True
 try:
@@ -35,10 +36,11 @@ HAS_EXLLAMA = False
 CAN_EXLLAMA = major >= 8
 V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
 if V2 and int(os.getenv("WORLD_SIZE", "1")) > 1:
-    logger.warning(
+    V2 = False
+    log_once(
+        logger.warning,
         "Disabling exllama v2 and using v1 instead because there are issues when sharding"
     )
-    V2 = False

 if os.getenv("DISABLE_EXLLAMA") == "True":
     HAS_EXLLAMA = False
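The exllama selection in this file is driven entirely by environment variables. A condensed sketch of the gating above (simplified: in the real module, HAS_EXLLAMA is only set once the kernels actually import):

import os

# EXLLAMA_VERSION picks v2 by default; sharded runs (WORLD_SIZE > 1) fall back to v1.
V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
if V2 and int(os.getenv("WORLD_SIZE", "1")) > 1:
    V2 = False  # known issues when sharding

# DISABLE_EXLLAMA=True switches the kernels off entirely.
exllama_disabled = os.getenv("DISABLE_EXLLAMA") == "True"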

View File

@@ -0,0 +1,6 @@
+from functools import lru_cache
+
+
+@lru_cache(10)
+def log_once(log, msg: str):
+    log(msg)
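Usage sketch for the helper above: because log_once is memoized with lru_cache, a given (log, msg) pair reaches the underlying logger only once, and with maxsize=10 only the ten most recent distinct pairs are remembered. Standalone example (redefines the helper rather than importing text_generation_server.utils.log):

from functools import lru_cache

from loguru import logger

@lru_cache(10)
def log_once(log, msg: str):
    log(msg)

for _ in range(1000):
    log_once(logger.warning, "Disabling exllama because desc_act=True")
# The warning is emitted once, not 1000 times.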

View File

@@ -6,6 +6,7 @@ import torch
 from loguru import logger
 from huggingface_hub import hf_hub_download
 import json
+from text_generation_server.utils.log import log_once


 class Weights:
@@ -161,7 +162,7 @@ class Weights:
             else:
                 g_idx = None

-            bits, groupsize = self._get_gptq_params()
+            bits, groupsize, _ = self._get_gptq_params()
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
         else:
             slice_ = self._get_slice(f"{prefix}.weight")
@@ -211,10 +212,10 @@ class Weights:
             else:
                 g_idx = None

-            bits, groupsize = self._get_gptq_params()
+            bits, groupsize, desc_act = self._get_gptq_params()
             from text_generation_server.utils.layers import HAS_EXLLAMA

-            use_exllama = bits == 4 and HAS_EXLLAMA and quantize == "gptq"
+            use_exllama = bits == 4 and HAS_EXLLAMA and quantize == "gptq" and not desc_act
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
         else:
             w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
@@ -240,11 +241,15 @@ class Weights:
     def get_multi_weights_row(self, prefix: str, quantize: str):
         if quantize == "gptq":
             use_exllama = True
-            bits, groupsize = self._get_gptq_params()
+            bits, groupsize, desc_act = self._get_gptq_params()

             if bits != 4:
                 use_exllama = False

+            if desc_act:
+                log_once(logger.warning, "Disabling exllama because desc_act=True")
+                use_exllama = False
+
             if self.process_group.size() > 1:
                 g_idx = self.get_tensor(f"{prefix}.g_idx")
                 if g_idx is not None:
@@ -274,12 +279,16 @@ class Weights:
             if use_exllama:
                 if not HAS_EXLLAMA:
                     if CAN_EXLLAMA:
-                        logger.warning(
+                        log_once(
+                            logger.warning,
                             "Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True"
                         )
                     use_exllama = False
                 else:
-                    logger.info(f"Using exllama kernels v{HAS_EXLLAMA}")
+                    log_once(
+                        logger.info,
+                        f"Using exllama kernels v{HAS_EXLLAMA}"
+                    )

                 g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
@@ -295,7 +304,7 @@ class Weights:
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
         elif quantize == "awq":
-            bits, groupsize = self._get_gptq_params()
+            bits, groupsize, _ = self._get_gptq_params()

             try:
                 qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
@@ -314,18 +323,20 @@ class Weights:
             weight = self.get_sharded(f"{prefix}.weight", dim=1)
         return weight

-    def _get_gptq_params(self) -> Tuple[int, int]:
+    def _get_gptq_params(self) -> Tuple[int, int, int]:
         try:
             bits = self.get_tensor("gptq_bits").item()
             groupsize = self.get_tensor("gptq_groupsize").item()
+            desc_act = False
         except (SafetensorError, RuntimeError) as e:
             try:
                 bits = self.gptq_bits
                 groupsize = self.gptq_groupsize
+                desc_act = getattr(self, "gptq_desc_act", False)
             except Exception:
                 raise e

-        return bits, groupsize
+        return bits, groupsize, desc_act

     def _set_gptq_params(self, model_id, revision):
         filename = "config.json"
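One subtlety in the hunk above: if none of the config files parsed by _set_gptq_params below carries a desc_act key, gptq_desc_act is never assigned, so the getattr fallback quietly returns False and exllama stays eligible. A hypothetical stand-in showing just that default:

# FakeWeights mimics a Weights object whose quantization config
# lacked "desc_act", so gptq_desc_act was never set on it.
class FakeWeights:
    gptq_bits = 4
    gptq_groupsize = 128

desc_act = getattr(FakeWeights(), "gptq_desc_act", False)
print(desc_act)  # False -> the exllama fast path remains allowed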
@@ -340,6 +351,7 @@ class Weights:
                 data = json.load(f)
             self.gptq_bits = data["quantization_config"]["bits"]
             self.gptq_groupsize = data["quantization_config"]["group_size"]
+            self.gptq_desc_act = data["quantization_config"]["desc_act"]
         except Exception:
             filename = "quantize_config.json"
             try:
@@ -353,6 +365,7 @@ class Weights:
                     data = json.load(f)
                 self.gptq_bits = data["bits"]
                 self.gptq_groupsize = data["group_size"]
+                self.gptq_desc_act = data["desc_act"]
             except Exception:
                 filename = "quant_config.json"
                 try:
@@ -366,5 +379,6 @@ class Weights:
                         data = json.load(f)
                     self.gptq_bits = data["w_bit"]
                     self.gptq_groupsize = data["q_group_size"]
+                    self.gptq_desc_act = data["desc_act"]
                 except Exception:
                     pass
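For reference, desc_act comes straight from the checkpoint's quantization config. An illustrative AutoGPTQ-style quantize_config.json payload (example values, not from the patch) read through the same keys as the second fallback above:

import json

# Illustrative config content; desc_act=True means the model was
# quantized with act_order (columns processed in activation order).
raw = '{"bits": 4, "group_size": 128, "desc_act": true}'
data = json.loads(raw)

gptq_bits = data["bits"]
gptq_groupsize = data["group_size"]
gptq_desc_act = data["desc_act"]
print(gptq_bits, gptq_groupsize, gptq_desc_act)  # 4 128 True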