from dataclasses import dataclass
import os
from typing import Optional

import torch

from text_generation_server.utils.import_utils import (
    SYSTEM,
)
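

# Quantization hyperparameters as stored in a GPTQ checkpoint's config.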
@dataclass
class GPTQParams:
    bits: int
    checkpoint_format: Optional[str]
    groupsize: int
    desc_act: bool
    quant_method: str
    sym: bool
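

# The packed tensors that make up one GPTQ-quantized linear layer.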
@dataclass
class GPTQWeight:
    qweight: torch.Tensor
    qzeros: torch.Tensor
    scales: torch.Tensor
    g_idx: Optional[torch.Tensor]
    bits: int
    groupsize: int
    use_exllama: bool

    def __post_init__(self):
        # Some checkpoints store scales in fp32; normalize them to fp16.
        if self.scales.dtype == torch.float:
            self.scales = self.scales.half()

    @property
    def device(self) -> torch.device:
        return self.qweight.device
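

# Probe the GPU's compute capability; fall back to 1 when no CUDA device is
# available, so the capability check below simply disables exllama.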
try:
    major, _minor = torch.cuda.get_device_capability()
except Exception:
    major = 1
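
# Exllama kernels require compute capability >= 8.0 (Ampere) on CUDA, or a
# ROCm system. EXLLAMA_VERSION picks the kernel generation, defaulting to v2.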
HAS_EXLLAMA = False
CAN_EXLLAMA = major >= 8 or SYSTEM == "rocm"
V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
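
# Load the requested exllama kernels unless explicitly disabled; HAS_EXLLAMA
# records which generation loaded ("1" or "2") and stays False otherwise.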
if os.getenv("DISABLE_EXLLAMA") == "True":
    HAS_EXLLAMA = False
elif CAN_EXLLAMA:
    try:
        if V2:
            from text_generation_server.layers.gptq.exllamav2 import (
                QuantLinear as ExllamaQuantLinear,
                create_exllama_buffers,
                set_device,
            )

            HAS_EXLLAMA = "2"
        else:
            from text_generation_server.layers.gptq.exllama import (
                Ex4bitLinear as ExllamaQuantLinear,
                create_exllama_buffers,
                set_device,
            )

            HAS_EXLLAMA = "1"
    except ImportError:
        pass
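
# Generic QuantLinear implementation, importable regardless of which (if any)
# exllama kernels are present.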
from text_generation_server.layers.gptq.quant_linear import QuantLinear
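
# A minimal selection sketch (illustrative only; the real server wires these
# classes up through its weight-loading machinery, and ExllamaQuantLinear is
# only defined when the kernel import above succeeded). `qweight`, `qzeros`,
# `scales`, and `g_idx` stand in for tensors loaded from a checkpoint:
#
#     weight = GPTQWeight(
#         qweight=qweight, qzeros=qzeros, scales=scales, g_idx=g_idx,
#         bits=4, groupsize=128, use_exllama=bool(HAS_EXLLAMA),
#     )
#     linear_cls = ExllamaQuantLinear if weight.use_exllama else QuantLinear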