mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
Refactored a bit.
This commit is contained in:
parent
6bf7090ecd
commit
0860394489
@ -1,5 +1,5 @@
|
||||
from setuptools import setup
|
||||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CppExtension
|
||||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
||||
|
||||
setup(
|
||||
name="custom_kernels",
|
||||
@ -14,7 +14,7 @@ setup(
|
||||
sources=["custom_kernels/fused_attention_cuda.cu"],
|
||||
extra_compile_args=["-arch=compute_80", "-std=c++17"],
|
||||
),
|
||||
CppExtension(
|
||||
CUDAExtension(
|
||||
name="custom_kernels.exllama",
|
||||
sources=[
|
||||
"custom_kernels/exllama/exllama_ext.cpp",
|
||||
|
@ -71,11 +71,8 @@ def _load_multi_mqa_gptq(
|
||||
qzeros = torch.cat([q_tensor, kv_tensor], dim=1)
|
||||
|
||||
g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
|
||||
bits, groupsize = weights.get_gptq_qparams()
|
||||
bits, groupsize = weights._get_gptq_qparams()
|
||||
|
||||
qweight = qweight.to(weights.device)
|
||||
qzeros = qzeros.to(weights.device)
|
||||
scales = scales.to(weights.device)
|
||||
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
|
||||
|
||||
if bias:
|
||||
@ -90,8 +87,6 @@ def _load_multi_mqa_gptq(
|
||||
kv_tensor = slice_[-2 * head_size :]
|
||||
bias = torch.cat([q_tensor, kv_tensor], dim=0)
|
||||
|
||||
bias = bias.to(weights.device)
|
||||
|
||||
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
|
||||
else:
|
||||
raise NotImplementedError("Gptq loading with santacoder is not implemented")
|
||||
|
@ -684,7 +684,6 @@ class FlashCausalLM(Model):
|
||||
self,
|
||||
model: torch.nn.Module,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
config: PretrainedConfig,
|
||||
num_layers: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
@ -700,7 +699,6 @@ class FlashCausalLM(Model):
|
||||
super(FlashCausalLM, self).__init__(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
config=config,
|
||||
requires_padding=False,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
|
@ -68,7 +68,6 @@ class FlashLlama(FlashCausalLM):
|
||||
super(FlashLlama, self).__init__(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
config=config,
|
||||
num_layers=len(model.model.layers),
|
||||
num_kv_heads=model.model.num_key_value_heads,
|
||||
head_size=model.model.head_size,
|
||||
|
@ -65,7 +65,6 @@ class FlashRWSharded(FlashCausalLM):
|
||||
super(FlashRWSharded, self).__init__(
|
||||
model=model.to(device),
|
||||
tokenizer=tokenizer,
|
||||
config=config,
|
||||
num_layers=len(model.transformer.h),
|
||||
num_kv_heads=model.transformer.cache_size,
|
||||
head_size=model.transformer.head_size,
|
||||
|
@ -66,7 +66,6 @@ class FlashSantacoderSharded(FlashCausalLM):
|
||||
super(FlashSantacoderSharded, self).__init__(
|
||||
model=model.to(device),
|
||||
tokenizer=tokenizer,
|
||||
config=config,
|
||||
num_layers=len(model.transformer.h),
|
||||
num_kv_heads=1,
|
||||
head_size=model.transformer.head_size,
|
||||
|
@ -198,7 +198,6 @@ class GalacticaSharded(CausalLM):
|
||||
super(CausalLM, self).__init__(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
config=config,
|
||||
requires_padding=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
|
@ -63,7 +63,6 @@ class GPTNeoxSharded(CausalLM):
|
||||
super(CausalLM, self).__init__(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
config=config,
|
||||
requires_padding=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
|
@ -8,14 +8,6 @@ from transformers import PreTrainedTokenizerBase, PretrainedConfig
|
||||
from text_generation_server.models.types import Batch, GeneratedText
|
||||
from text_generation_server.pb.generate_pb2 import InfoResponse
|
||||
|
||||
from text_generation_server.utils.gptq.quant_linear import Ex4bitLinear
|
||||
from custom_kernels.exllama import prepare_buffers, set_tuning_params
|
||||
|
||||
from text_generation_server.utils.layers import (
|
||||
TensorParallelRowLinear,
|
||||
TensorParallelColumnLinear
|
||||
)
|
||||
|
||||
B = TypeVar("B", bound=Batch)
|
||||
|
||||
class Model(ABC):
|
||||
@ -23,7 +15,6 @@ class Model(ABC):
|
||||
self,
|
||||
model: torch.nn.Module,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
config: PretrainedConfig,
|
||||
requires_padding: bool,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
@ -46,47 +37,6 @@ class Model(ABC):
|
||||
inspect.signature(model.forward).parameters.get("position_ids", None)
|
||||
is not None
|
||||
)
|
||||
self.config = config
|
||||
|
||||
if config.quantize == "gptq":
|
||||
# Buffers need to be persistent to avoid any bug.
|
||||
self.buffers = {}
|
||||
use_exllama_act_order = False
|
||||
max_dq_buffer_size = 1
|
||||
max_inner_outer_dim = 1
|
||||
for name, submodule in model.named_modules():
|
||||
if isinstance(submodule, (TensorParallelColumnLinear, TensorParallelRowLinear)) and isinstance(submodule.linear, Ex4bitLinear):
|
||||
|
||||
max_dq_buffer_size = max(max_dq_buffer_size, submodule.linear.qweight.numel() * 8)
|
||||
|
||||
if submodule.linear.act_order:
|
||||
max_inner_outer_dim = max(max_inner_outer_dim, submodule.linear.height, submodule.linear.width)
|
||||
|
||||
use_exllama_act_order = True
|
||||
|
||||
if use_exllama_act_order:
|
||||
# TODO: this should be set to rust side `max_total_tokens`, but TGI
|
||||
# does not offer an API to expose this variable to python, as this variable
|
||||
# is handled by the client but it appears the model is initialized by the server.
|
||||
# An alternative could be to initialize the buffers during warmup.
|
||||
max_total_tokens = 2048
|
||||
else:
|
||||
max_total_tokens = 1
|
||||
|
||||
# This temp_state buffer is required to reorder X in the act-order case.
|
||||
self.buffers["temp_state"] = torch.zeros((max_total_tokens, max_inner_outer_dim), dtype=torch.float16, device=device)
|
||||
|
||||
# This temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
|
||||
self.buffers["temp_dq"] = torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=device)
|
||||
|
||||
prepare_buffers(device, self.buffers["temp_state"], self.buffers["temp_dq"])
|
||||
|
||||
matmul_recons_thd = 8
|
||||
matmul_fused_remap = False
|
||||
matmul_no_half2 = False
|
||||
set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
self.check_initialized()
|
||||
|
||||
|
@ -86,7 +86,6 @@ class MPTSharded(CausalLM):
|
||||
super(CausalLM, self).__init__(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
config=config,
|
||||
requires_padding=False,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
|
@ -61,7 +61,6 @@ class OPTSharded(CausalLM):
|
||||
super(CausalLM, self).__init__(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
config=config,
|
||||
requires_padding=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
|
@ -58,7 +58,6 @@ class RW(CausalLM):
|
||||
super(CausalLM, self).__init__(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
config=model.config,
|
||||
requires_padding=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
|
@ -63,7 +63,6 @@ class SantaCoder(CausalLM):
|
||||
super(CausalLM, self).__init__(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
config=model.config,
|
||||
requires_padding=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
|
@ -542,7 +542,6 @@ class Seq2SeqLM(Model):
|
||||
super(Seq2SeqLM, self).__init__(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
config=model.config,
|
||||
requires_padding=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
|
@ -73,7 +73,6 @@ class T5Sharded(Seq2SeqLM):
|
||||
super(Seq2SeqLM, self).__init__(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
config=config,
|
||||
requires_padding=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
|
@ -16,6 +16,7 @@ from text_generation_server.pb import generate_pb2_grpc, generate_pb2
|
||||
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
|
||||
|
||||
|
||||
|
||||
class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
||||
def __init__(self, model: Model, cache: Cache, server_urls: List[str]):
|
||||
self.cache = cache
|
||||
@ -140,6 +141,13 @@ def serve(
|
||||
logger.exception("Error when initializing model")
|
||||
raise
|
||||
|
||||
if quantize == "gptq":
|
||||
try:
|
||||
from text_generation_server.utils.gptq.quant_linear import create_exllama_buffers
|
||||
create_exllama_buffers()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
server = aio.server(
|
||||
interceptors=[
|
||||
ExceptionInterceptor(),
|
||||
|
120
server/text_generation_server/utils/gptq/exllama.py
Normal file
120
server/text_generation_server/utils/gptq/exllama.py
Normal file
@ -0,0 +1,120 @@
|
||||
|
||||
import torch
|
||||
from custom_kernels.exllama import make_q4, q4_matmul, prepare_buffers, set_tuning_params
|
||||
|
||||
# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
|
||||
none_tensor = torch.empty((1, 1), device = "meta")
|
||||
|
||||
def ext_make_q4(qweight, qzeros, scales, g_idx, device):
|
||||
"""Construct Q4Matrix, return handle"""
|
||||
return make_q4(qweight,
|
||||
qzeros,
|
||||
scales,
|
||||
g_idx if g_idx is not None else none_tensor,
|
||||
device)
|
||||
|
||||
def ext_q4_matmul(x, q4, q4_width):
|
||||
"""Matrix multiplication, returns x @ q4"""
|
||||
outshape = x.shape[:-1] + (q4_width,)
|
||||
x = x.view(-1, x.shape[-1])
|
||||
output = torch.empty((x.shape[0], q4_width), dtype = torch.float16, device = x.device)
|
||||
|
||||
q4_matmul(x, q4, output)
|
||||
|
||||
return output.view(outshape)
|
||||
|
||||
MAX_DQ = 1
|
||||
MAX_INNER = 1
|
||||
ACT_ORDER = False
|
||||
DEVICE = None
|
||||
|
||||
TEMP_STATE = None
|
||||
TEMP_DQ = None
|
||||
|
||||
def create_exllama_buffers():
|
||||
global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE, TEMP_STATE, TEMP_DQ
|
||||
|
||||
if ACT_ORDER:
|
||||
# TODO: this should be set to rust side `max_total_tokens`, but TGI
|
||||
# does not offer an API to expose this variable to python, as this variable
|
||||
# is handled by the client but it appears the model is initialized by the server.
|
||||
# An alternative could be to initialize the buffers during warmup.
|
||||
max_total_tokens = 2048
|
||||
else:
|
||||
max_total_tokens = 1
|
||||
|
||||
# This temp_state buffer is required to reorder X in the act-order case.
|
||||
temp_state = torch.zeros((max_total_tokens, MAX_INNER), dtype=torch.float16, device=DEVICE)
|
||||
temp_dq = torch.zeros((1, MAX_DQ), dtype=torch.float16, device=DEVICE)
|
||||
|
||||
# This temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
|
||||
prepare_buffers(DEVICE, temp_state, temp_dq)
|
||||
|
||||
matmul_recons_thd = 8
|
||||
matmul_fused_remap = False
|
||||
matmul_no_half2 = False
|
||||
set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)
|
||||
|
||||
TEMP_STATE, TEMP_DQ = temp_state, temp_dq
|
||||
|
||||
class Ex4bitLinear:
|
||||
"""Linear layer implementation with per-group 4-bit quantization of the weights"""
|
||||
def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
|
||||
global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE
|
||||
assert bits == 4
|
||||
|
||||
self.device = qweight.device
|
||||
self.qweight = qweight
|
||||
self.qzeros = qzeros
|
||||
self.scales = scales
|
||||
self.g_idx = g_idx.cpu() if g_idx is not None else None
|
||||
self.bias = bias if bias is not None else None
|
||||
|
||||
if self.g_idx is not None and ((self.g_idx == 0).all() or torch.equal(g_idx.cpu(), torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32))):
|
||||
self.empty_g_idx = True
|
||||
self.g_idx = None
|
||||
|
||||
assert self.device.type == "cuda"
|
||||
assert self.device.index is not None
|
||||
|
||||
self.q4 = ext_make_q4(
|
||||
self.qweight,
|
||||
self.qzeros,
|
||||
self.scales,
|
||||
self.g_idx,
|
||||
self.device.index
|
||||
)
|
||||
|
||||
self.height = qweight.shape[0] * 8
|
||||
self.width = qweight.shape[1]
|
||||
|
||||
# Infer groupsize from height of qzeros
|
||||
self.groupsize = None
|
||||
if self.qzeros.shape[0] > 1:
|
||||
self.groupsize = (self.qweight.shape[0] * 8) // (self.qzeros.shape[0])
|
||||
|
||||
if self.groupsize is not None:
|
||||
assert groupsize == self.groupsize
|
||||
|
||||
# Handle act-order matrix
|
||||
if self.g_idx is not None:
|
||||
if self.groupsize is None: raise ValueError("Found group index but no groupsize. What do?")
|
||||
self.act_order = True
|
||||
else:
|
||||
self.act_order = False
|
||||
|
||||
DEVICE = self.qweight.device
|
||||
|
||||
MAX_DQ = max(MAX_DQ, self.qweight.numel() * 8)
|
||||
|
||||
if self.act_order:
|
||||
MAX_INNER = max(MAX_INNER, self.height, self.width)
|
||||
|
||||
ACT_ORDER = True
|
||||
|
||||
def forward(self, x):
|
||||
out = ext_q4_matmul(x, self.q4, self.width)
|
||||
|
||||
if self.bias is not None:
|
||||
out.add_(self.bias)
|
||||
return out
|
@ -8,11 +8,6 @@ import torch
|
||||
|
||||
from loguru import logger
|
||||
|
||||
try:
|
||||
from custom_kernels.exllama import make_q4, q4_matmul
|
||||
except Exception as e:
|
||||
logger.error(f"The CUDA kernels custom_kernels.exllama not installed, got the error: {e}")
|
||||
|
||||
try:
|
||||
import triton
|
||||
import triton.language as tl
|
||||
@ -368,76 +363,3 @@ class QuantLinear(nn.Module):
|
||||
out = out + self.bias if self.bias is not None else out
|
||||
return out.reshape(out_shape)
|
||||
|
||||
# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
|
||||
none_tensor = torch.empty((1, 1), device = "meta")
|
||||
|
||||
def ext_make_q4(qweight, qzeros, scales, g_idx, device):
|
||||
"""Construct Q4Matrix, return handle"""
|
||||
return make_q4(qweight,
|
||||
qzeros,
|
||||
scales,
|
||||
g_idx if g_idx is not None else none_tensor,
|
||||
device)
|
||||
|
||||
def ext_q4_matmul(x, q4, q4_width):
|
||||
"""Matrix multiplication, returns x @ q4"""
|
||||
outshape = x.shape[:-1] + (q4_width,)
|
||||
x = x.view(-1, x.shape[-1])
|
||||
output = torch.empty((x.shape[0], q4_width), dtype = torch.float16, device = x.device)
|
||||
|
||||
q4_matmul(x, q4, output)
|
||||
|
||||
return output.view(outshape)
|
||||
|
||||
|
||||
class Ex4bitLinear:
|
||||
"""Linear layer implementation with per-group 4-bit quantization of the weights"""
|
||||
def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
|
||||
assert bits == 4
|
||||
|
||||
self.device = qweight.device
|
||||
self.qweight = qweight
|
||||
self.qzeros = qzeros
|
||||
self.scales = scales
|
||||
self.g_idx = g_idx.cpu() if g_idx is not None else None
|
||||
self.bias = bias if bias is not None else None
|
||||
|
||||
if self.g_idx is not None and ((self.g_idx == 0).all() or torch.equal(g_idx.cpu(), torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32))):
|
||||
self.empty_g_idx = True
|
||||
self.g_idx = None
|
||||
|
||||
assert self.device.type == "cuda"
|
||||
assert self.device.index is not None
|
||||
|
||||
self.q4 = ext_make_q4(
|
||||
self.qweight,
|
||||
self.qzeros,
|
||||
self.scales,
|
||||
self.g_idx,
|
||||
self.device.index
|
||||
)
|
||||
|
||||
self.height = qweight.shape[0] * 8
|
||||
self.width = qweight.shape[1]
|
||||
|
||||
# Infer groupsize from height of qzeros
|
||||
self.groupsize = None
|
||||
if self.qzeros.shape[0] > 1:
|
||||
self.groupsize = (self.qweight.shape[0] * 8) // (self.qzeros.shape[0])
|
||||
|
||||
if self.groupsize is not None:
|
||||
assert groupsize == self.groupsize
|
||||
|
||||
# Handle act-order matrix
|
||||
if self.g_idx is not None:
|
||||
if self.groupsize is None: raise ValueError("Found group index but no groupsize. What do?")
|
||||
self.act_order = True
|
||||
else:
|
||||
self.act_order = False
|
||||
|
||||
def forward(self, x):
|
||||
out = ext_q4_matmul(x, self.q4, self.width)
|
||||
|
||||
if self.bias is not None:
|
||||
out.add_(self.bias)
|
||||
return out
|
||||
|
@ -15,7 +15,12 @@ except ImportError:
|
||||
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
from text_generation_server.utils.gptq.quant_linear import QuantLinear, Ex4bitLinear
|
||||
from text_generation_server.utils.gptq.quant_linear import QuantLinear
|
||||
HAS_EXLLAMA = True
|
||||
try:
|
||||
from text_generation_server.utils.gptq.exllama import Ex4bitLinear
|
||||
except ImportError:
|
||||
HAS_EXLLAMA = False
|
||||
|
||||
from typing import Optional
|
||||
|
||||
@ -145,13 +150,15 @@ def get_linear(weight, bias, quantize):
|
||||
linear.bias = nn.Parameter(bias)
|
||||
elif quantize == "gptq":
|
||||
try:
|
||||
qweight, qzeros, scales, g_idx, bits, groupsize, use_triton_kernel = weight
|
||||
qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama = weight
|
||||
except Exception:
|
||||
raise NotImplementedError(
|
||||
f"The passed weight is not `gptq` compatible, loader needs to be updated."
|
||||
)
|
||||
|
||||
if use_triton_kernel or bits != 4:
|
||||
if use_exllama:
|
||||
linear = Ex4bitLinear(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
|
||||
else:
|
||||
linear = QuantLinear(
|
||||
qweight,
|
||||
qzeros,
|
||||
@ -161,8 +168,6 @@ def get_linear(weight, bias, quantize):
|
||||
bits,
|
||||
groupsize,
|
||||
)
|
||||
else:
|
||||
linear = Ex4bitLinear(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
|
||||
else:
|
||||
raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
|
||||
return linear
|
||||
|
@ -2,6 +2,7 @@ from pathlib import Path
|
||||
from typing import List, Dict, Optional, Tuple
|
||||
from safetensors import safe_open, SafetensorError
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class Weights:
|
||||
@ -127,7 +128,7 @@ class Weights:
|
||||
torch.testing.assert_close(w2, w[0])
|
||||
g_idx = w[0]
|
||||
|
||||
bits, groupsize = self.get_gptq_qparams()
|
||||
bits, groupsize = self._get_gptq_qparams()
|
||||
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
|
||||
else:
|
||||
w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
|
||||
@ -136,52 +137,64 @@ class Weights:
|
||||
|
||||
def get_multi_weights_row(self, prefix: str, quantize: str):
|
||||
if quantize == "gptq":
|
||||
use_triton_kernel = False
|
||||
use_exllama = True
|
||||
bits, groupsize = self._get_gptq_qparams()
|
||||
|
||||
if bits != 4:
|
||||
use_exllama = False
|
||||
|
||||
if self.process_group.size() > 1:
|
||||
g_idx = self.get_tensor(f"{prefix}.g_idx")
|
||||
_, groupsize = self.get_gptq_qparams()
|
||||
|
||||
if g_idx is not None:
|
||||
if not torch.equal(g_idx.cpu(), torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32)) and not (g_idx == 0).all():
|
||||
# Exllama implementation does not support row tensor parallelism with act-order, as
|
||||
# it would require to reorder input activations that are split unto several GPUs
|
||||
use_triton_kernel = True
|
||||
use_exllama = False
|
||||
|
||||
try:
|
||||
qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
|
||||
except RuntimeError:
|
||||
raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`")
|
||||
|
||||
bits, groupsize = self.get_gptq_qparams()
|
||||
|
||||
if use_triton_kernel:
|
||||
# The triton kernel reorders the scales/zero points instead of the weight/activation.
|
||||
# Thus, each rank needs the full qzeros/scales.
|
||||
qzeros = self.get_tensor(f"{prefix}.qzeros")
|
||||
scales = self.get_tensor(f"{prefix}.scales")
|
||||
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
|
||||
else:
|
||||
if groupsize >= 16:
|
||||
from text_generation_server.utils.layers import HAS_EXLLAMA
|
||||
if use_exllama:
|
||||
if not HAS_EXLLAMA:
|
||||
logger.warning("Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True")
|
||||
use_exllama = False
|
||||
else:
|
||||
logger.info("Using exllama kernels")
|
||||
|
||||
|
||||
if use_exllama:
|
||||
if groupsize >= 0:
|
||||
# Exllama reorders the weights in advance and the activations on the fly, thus
|
||||
# the scales and zero-points do not need to be reordered.
|
||||
qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
|
||||
scales = self.get_sharded(f"{prefix}.scales", dim=0)
|
||||
else:
|
||||
qzeros = self.get_tensor(f"{prefix}.qzeros")
|
||||
scales = self.get_tensor(f"{prefix}.scales")
|
||||
raise RuntimeError("Using exllama GPTQ kernel with groupsize<1 is not supported")
|
||||
# qzeros = self.get_tensor(f"{prefix}.qzeros")
|
||||
# scales = self.get_tensor(f"{prefix}.scales")
|
||||
|
||||
# For tp > 1, at this point we know we do not use act-order
|
||||
if self.process_group.size() == 1:
|
||||
g_idx = self.get_tensor(f"{prefix}.g_idx")
|
||||
else:
|
||||
g_idx = None
|
||||
else:
|
||||
# The triton kernel reorders the scales/zero points instead of the weight/activation.
|
||||
# Thus, each rank needs the full qzeros/scales.
|
||||
qzeros = self.get_tensor(f"{prefix}.qzeros")
|
||||
scales = self.get_tensor(f"{prefix}.scales")
|
||||
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
|
||||
|
||||
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_triton_kernel)
|
||||
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
|
||||
else:
|
||||
weight = self.get_sharded(f"{prefix}.weight", dim=1)
|
||||
return weight
|
||||
|
||||
def get_gptq_qparams(self) -> Tuple[int, int]:
|
||||
def _get_gptq_qparams(self) -> Tuple[int, int]:
|
||||
try:
|
||||
bits = self.get_tensor("gptq_bits").item()
|
||||
groupsize = self.get_tensor("gptq_groupsize").item()
|
||||
@ -194,4 +207,4 @@ class Weights:
|
||||
except Exception:
|
||||
raise e
|
||||
|
||||
return bits, groupsize
|
||||
return bits, groupsize
|
||||
|
Loading…
Reference in New Issue
Block a user