mirror of https://github.com/huggingface/text-generation-inference.git
Prevent using exllama when act_order=True
parent 238cc311f1
commit 9f42e5f6fd
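In short, the diff below threads the act_order flag (desc_act) read from the quantization config through _get_gptq_params and into the exllama gating, and adds a log_once helper so the resulting warnings are emitted only once. A minimal sketch of the new gating rule follows; the function name is mine, for illustration only (in the diff it is the inline use_exllama expression):

def should_use_exllama(bits: int, quantize: str, has_exllama: bool, desc_act: bool) -> bool:
    # exllama kernels are only used for 4-bit GPTQ weights, and only when the
    # checkpoint was quantized without act_order (desc_act=False).
    return bits == 4 and has_exllama and quantize == "gptq" and not desc_act

print(should_use_exllama(4, "gptq", True, desc_act=False))  # True
print(should_use_exllama(4, "gptq", True, desc_act=True))   # False: act_order checkpoints skip exllama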
@@ -19,6 +19,7 @@ from accelerate import init_empty_weights
 
 from text_generation_server.utils.gptq.quant_linear import QuantLinear
 from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
+from text_generation_server.utils.log import log_once
 
 HAS_AWQ = True
 try:
@@ -35,10 +36,11 @@ HAS_EXLLAMA = False
 CAN_EXLLAMA = major >= 8
 V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
 if V2 and int(os.getenv("WORLD_SIZE", "1")) > 1:
-    logger.warning(
+    V2 = False
+    log_once(
+        logger.warning,
         "Disabling exllama v2 and using v1 instead because there are issues when sharding"
     )
-    V2 = False
 
 if os.getenv("DISABLE_EXLLAMA") == "True":
     HAS_EXLLAMA = False

server/text_generation_server/utils/log.py (new file, 6 lines)
@@ -0,0 +1,6 @@
+from functools import lru_cache
+
+
+@lru_cache(10)
+def log_once(log, msg: str):
+    log(msg)
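The new helper leans on functools.lru_cache: the first call with a given (log, msg) pair actually logs, while repeated identical calls hit the cache and do nothing, for up to 10 distinct entries. A small self-contained demonstration; fake_warning and calls are stand-ins invented here for illustration, in the server the first argument is logger.warning or logger.info:

from functools import lru_cache

@lru_cache(10)
def log_once(log, msg: str):
    log(msg)

calls = []

def fake_warning(msg):  # stand-in for logger.warning
    calls.append(msg)

log_once(fake_warning, "Disabling exllama because desc_act=True")
log_once(fake_warning, "Disabling exllama because desc_act=True")  # cache hit, not logged again
log_once(fake_warning, "Disabling exllama because desc_act=True")  # still cached
print(calls)  # ['Disabling exllama because desc_act=True']

This matters because the weight-loading helpers below run for every quantized layer, so without the cache the same warning would be repeated many times per model load.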
@@ -6,6 +6,7 @@ import torch
 from loguru import logger
 from huggingface_hub import hf_hub_download
 import json
+from text_generation_server.utils.log import log_once
 
 
 class Weights:
@@ -161,7 +162,7 @@ class Weights:
             else:
                 g_idx = None
 
-            bits, groupsize = self._get_gptq_params()
+            bits, groupsize, _ = self._get_gptq_params()
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
         else:
             slice_ = self._get_slice(f"{prefix}.weight")
@@ -211,10 +212,10 @@ class Weights:
             else:
                 g_idx = None
 
-            bits, groupsize = self._get_gptq_params()
+            bits, groupsize, desc_act = self._get_gptq_params()
             from text_generation_server.utils.layers import HAS_EXLLAMA
 
-            use_exllama = bits == 4 and HAS_EXLLAMA and quantize == "gptq"
+            use_exllama = bits == 4 and HAS_EXLLAMA and quantize == "gptq" and not desc_act
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
         else:
             w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
@@ -240,11 +241,15 @@ class Weights:
     def get_multi_weights_row(self, prefix: str, quantize: str):
         if quantize == "gptq":
             use_exllama = True
-            bits, groupsize = self._get_gptq_params()
+            bits, groupsize, desc_act = self._get_gptq_params()
 
             if bits != 4:
                 use_exllama = False
 
+            if desc_act:
+                log_once(logger.warning, "Disabling exllama because desc_act=True")
+                use_exllama = False
+
             if self.process_group.size() > 1:
                 g_idx = self.get_tensor(f"{prefix}.g_idx")
                 if g_idx is not None:
@@ -274,12 +279,16 @@ class Weights:
             if use_exllama:
                 if not HAS_EXLLAMA:
                     if CAN_EXLLAMA:
-                        logger.warning(
+                        log_once(
+                            logger.warning,
                             "Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True"
                         )
                     use_exllama = False
                 else:
-                    logger.info(f"Using exllama kernels v{HAS_EXLLAMA}")
+                    log_once(
+                        logger.info,
+                        f"Using exllama kernels v{HAS_EXLLAMA}"
+                    )
 
                 g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
 
@@ -295,7 +304,7 @@ class Weights:
 
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
         elif quantize == "awq":
-            bits, groupsize = self._get_gptq_params()
+            bits, groupsize, _ = self._get_gptq_params()
 
             try:
                 qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
@@ -314,18 +323,20 @@ class Weights:
             weight = self.get_sharded(f"{prefix}.weight", dim=1)
         return weight
 
-    def _get_gptq_params(self) -> Tuple[int, int]:
+    def _get_gptq_params(self) -> Tuple[int, int, int]:
         try:
             bits = self.get_tensor("gptq_bits").item()
             groupsize = self.get_tensor("gptq_groupsize").item()
+            desc_act = False
         except (SafetensorError, RuntimeError) as e:
             try:
                 bits = self.gptq_bits
                 groupsize = self.gptq_groupsize
+                desc_act = getattr(self, "gptq_desc_act", False)
             except Exception:
                 raise e
 
-        return bits, groupsize
+        return bits, groupsize, desc_act
 
     def _set_gptq_params(self, model_id, revision):
         filename = "config.json"
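For reference, the fallback chain in _get_gptq_params can be read as follows; this is a self-contained sketch with plain dictionaries standing in for the tensor and attribute lookups on the Weights object (names and the bool return type are adapted for illustration, the real annotation stays Tuple[int, int, int]):

from typing import Tuple

def get_gptq_params(tensors: dict, config_attrs: dict) -> Tuple[int, int, bool]:
    try:
        # Preferred source: gptq_bits / gptq_groupsize metadata stored with the
        # weights; desc_act is hard-coded to False in this branch, as in the hunk above.
        bits = tensors["gptq_bits"]
        groupsize = tensors["gptq_groupsize"]
        desc_act = False
    except KeyError as e:
        try:
            # Fallback: attributes filled in by _set_gptq_params from the model
            # config; gptq_desc_act may be missing for older configs.
            bits = config_attrs["gptq_bits"]
            groupsize = config_attrs["gptq_groupsize"]
            desc_act = config_attrs.get("gptq_desc_act", False)
        except Exception:
            raise e
    return bits, groupsize, desc_act

print(get_gptq_params({}, {"gptq_bits": 4, "gptq_groupsize": 128, "gptq_desc_act": True}))
# (4, 128, True)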
@@ -340,6 +351,7 @@ class Weights:
                 data = json.load(f)
             self.gptq_bits = data["quantization_config"]["bits"]
             self.gptq_groupsize = data["quantization_config"]["group_size"]
+            self.gptq_desc_act = data["quantization_config"]["desc_act"]
         except Exception:
             filename = "quantize_config.json"
             try:
@@ -353,6 +365,7 @@ class Weights:
                     data = json.load(f)
                 self.gptq_bits = data["bits"]
                 self.gptq_groupsize = data["group_size"]
+                self.gptq_desc_act = data["desc_act"]
             except Exception:
                 filename = "quant_config.json"
                 try:
@@ -366,5 +379,6 @@ class Weights:
                         data = json.load(f)
                     self.gptq_bits = data["w_bit"]
                     self.gptq_groupsize = data["q_group_size"]
+                    self.gptq_desc_act = data["desc_act"]
                 except Exception:
                     pass
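Taken together, the three fallbacks in _set_gptq_params now read desc_act from whichever quantization config file is present. A short summary as a data structure (reconstructed from the hunks above; the style labels in the comments are my interpretation, not part of the diff):

# (bits key, group size key, desc_act key) per config file, tried in this order.
GPTQ_CONFIG_KEYS = {
    "config.json": (            # Transformers-style quantization_config block
        "quantization_config.bits",
        "quantization_config.group_size",
        "quantization_config.desc_act",
    ),
    "quantize_config.json": (   # AutoGPTQ-style config
        "bits",
        "group_size",
        "desc_act",
    ),
    "quant_config.json": (      # AWQ-style config
        "w_bit",
        "q_group_size",
        "desc_act",
    ),
}

If none of the files can be read, the final except swallows the error and gptq_desc_act is never set, in which case _get_gptq_params falls back to False via getattr(self, "gptq_desc_act", False).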