Fixing indirect GPTQ loads.

Loading GPTQ (or any kernel) should always be lazy.
Some platforms don't have GPTQ support and shouldn't fail because of
imports.
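
A minimal sketch of the pattern (simplified from the get_linear change below, not part of the diff itself): the kernel import moves from module scope into the code path that actually needs it, so importing the layers modules never requires the GPTQ/exl2 kernels to be installed.

def get_linear(weight, bias, quantize):
    if quantize == "gptq":
        # Imported lazily: only evaluated when a GPTQ-quantized model is loaded,
        # so platforms without GPTQ support can still import this module.
        from text_generation_server.layers.gptq import GPTQWeight

        if not isinstance(weight, GPTQWeight):
            raise NotImplementedError(
                "The passed weight is not `gptq` compatible, loader needs to be updated."
            )
        ...  # construct the GPTQ linear layer as before
    ...  # other quantization branches unchanged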
Nicolas Patry 2024-06-03 08:20:37 +00:00
parent 799a193b10
commit 9a9b679c33
5 changed files with 14 additions and 6 deletions

View File

@@ -6,6 +6,8 @@ from text_generation_server.utils.import_utils import (
    SYSTEM,
)
+raise RuntimeError("No")
@dataclass
class GPTQWeight:

View File

@@ -2,8 +2,6 @@ from typing import Optional
import torch
from torch.nn import functional as F
from text_generation_server.utils.import_utils import SYSTEM
-from text_generation_server.layers.exl2 import Exl2Weight
-from text_generation_server.layers.gptq import GPTQWeight
if SYSTEM == "rocm":
    try:
@@ -155,6 +153,8 @@ def get_linear(weight, bias, quantize):
            quant_type="nf4",
        )
    elif quantize == "exl2":
+        from text_generation_server.layers.exl2 import Exl2Weight
        if not isinstance(weight, Exl2Weight):
            raise NotImplementedError(
                f"The passed weight is not `exl2` compatible, loader needs to be updated."
@@ -165,6 +165,8 @@ def get_linear(weight, bias, quantize):
        linear = ExllamaQuantLinear(weight, bias)
    elif quantize == "gptq":
+        from text_generation_server.layers.gptq import GPTQWeight
        if not isinstance(weight, GPTQWeight):
            raise NotImplementedError(
                f"The passed weight is not `gptq` compatible, loader needs to be updated."

View File

@@ -21,7 +21,6 @@ from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple, Any
from loguru import logger
-from text_generation_server.layers.gptq import GPTQWeight
from text_generation_server.utils.import_utils import SYSTEM
if SYSTEM != "xpu":
@@ -198,6 +197,8 @@ def _load_gqa(config, prefix: str, weights):
    v_stop = v_offset + (rank + 1) * kv_block_size
    if config.quantize in ["gptq", "awq"]:
+        from text_generation_server.layers.gptq import GPTQWeight
        try:
            qweight_slice = weights._get_slice(f"{prefix}.qweight")
            q_qweight = qweight_slice[:, q_start:q_stop]

View File

@@ -5,7 +5,6 @@ from torch import nn
from transformers.activations import ACT2FN
from typing import Optional, List, Tuple
-from text_generation_server.layers.gptq import GPTQWeight
from text_generation_server.layers.attention import (
    paged_attention,
    attention,
@@ -39,6 +38,8 @@ def load_multi_mqa(
def _load_multi_mqa_gptq(
    config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
):
+    from text_generation_server.layers.gptq import GPTQWeight
    if any("c_attn" in k for k in weights.routing.keys()) and not config.transpose:
        world_size = weights.process_group.size()
        rank = weights.process_group.rank()

View File

@@ -7,8 +7,6 @@ import torch
from loguru import logger
from huggingface_hub import hf_hub_download
import json
-from text_generation_server.layers.exl2 import Exl2Weight
-from text_generation_server.layers.gptq import GPTQWeight
from text_generation_server.utils.log import log_once
@@ -221,6 +219,8 @@ class Weights:
    def get_weights_col(self, prefix: str, quantize: str):
        if quantize == "exl2":
+            from text_generation_server.layers.exl2 import Exl2Weight
            try:
                q_weight = self.get_tensor(f"{prefix}.q_weight")
            except RuntimeError:
@@ -247,6 +247,8 @@ class Weights:
        if quantize == "exl2":
            raise ValueError("get_multi_weights_col is not supported for exl2")
        elif quantize in ["gptq", "awq"]:
+            from text_generation_server.layers.gptq import GPTQWeight
            try:
                qweight = torch.cat(
                    [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1