Refactored a bit: move the exllama GPTQ wrapper (Ex4bitLinear and its temp-buffer setup) out of Model.__init__ and quant_linear.py into a new text_generation_server/utils/gptq/exllama.py, have the server create the exllama buffers once the model is loaded, drop the now-unneeded config argument from Model and its subclasses, rename Weights.get_gptq_qparams to _get_gptq_qparams, replace the use_triton_kernel flag with use_exllama, and build custom_kernels.exllama as a CUDAExtension.

Nicolas Patry 2023-07-20 17:38:50 +00:00
parent 6bf7090ecd
commit 0860394489
20 changed files with 173 additions and 173 deletions

View File

@@ -1,5 +1,5 @@
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CppExtension
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup(
name="custom_kernels",
@@ -14,7 +14,7 @@ setup(
sources=["custom_kernels/fused_attention_cuda.cu"],
extra_compile_args=["-arch=compute_80", "-std=c++17"],
),
CppExtension(
CUDAExtension(
name="custom_kernels.exllama",
sources=[
"custom_kernels/exllama/exllama_ext.cpp",

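For context, CppExtension builds with the host C++ compiler only, while CUDAExtension also runs .cu sources through nvcc and links the CUDA runtime, which the exllama kernels need. Below is a hedged sketch of the resulting extension entry; the source list is truncated in the hunk above, so only the binding file is repeated, and the BuildExtension cmdclass is the usual PyTorch pattern rather than something shown in the diff.

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name="custom_kernels",
    ext_modules=[
        CUDAExtension(
            name="custom_kernels.exllama",
            # remaining .cpp/.cu sources omitted; the diff truncates the list here
            sources=["custom_kernels/exllama/exllama_ext.cpp"],
        ),
    ],
    cmdclass={"build_ext": BuildExtension},
)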
View File

@@ -71,11 +71,8 @@ def _load_multi_mqa_gptq(
qzeros = torch.cat([q_tensor, kv_tensor], dim=1)
g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
bits, groupsize = weights.get_gptq_qparams()
bits, groupsize = weights._get_gptq_qparams()
qweight = qweight.to(weights.device)
qzeros = qzeros.to(weights.device)
scales = scales.to(weights.device)
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
if bias:
@@ -90,8 +87,6 @@ def _load_multi_mqa_gptq(
kv_tensor = slice_[-2 * head_size :]
bias = torch.cat([q_tensor, kv_tensor], dim=0)
bias = bias.to(weights.device)
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
else:
raise NotImplementedError("Gptq loading with santacoder is not implemented")

View File

@@ -684,7 +684,6 @@ class FlashCausalLM(Model):
self,
model: torch.nn.Module,
tokenizer: PreTrainedTokenizerBase,
config: PretrainedConfig,
num_layers: int,
num_kv_heads: int,
head_size: int,
@@ -700,7 +699,6 @@ class FlashCausalLM(Model):
super(FlashCausalLM, self).__init__(
model=model,
tokenizer=tokenizer,
config=config,
requires_padding=False,
dtype=dtype,
device=device,

View File

@@ -68,7 +68,6 @@ class FlashLlama(FlashCausalLM):
super(FlashLlama, self).__init__(
model=model,
tokenizer=tokenizer,
config=config,
num_layers=len(model.model.layers),
num_kv_heads=model.model.num_key_value_heads,
head_size=model.model.head_size,

View File

@@ -65,7 +65,6 @@ class FlashRWSharded(FlashCausalLM):
super(FlashRWSharded, self).__init__(
model=model.to(device),
tokenizer=tokenizer,
config=config,
num_layers=len(model.transformer.h),
num_kv_heads=model.transformer.cache_size,
head_size=model.transformer.head_size,

View File

@@ -66,7 +66,6 @@ class FlashSantacoderSharded(FlashCausalLM):
super(FlashSantacoderSharded, self).__init__(
model=model.to(device),
tokenizer=tokenizer,
config=config,
num_layers=len(model.transformer.h),
num_kv_heads=1,
head_size=model.transformer.head_size,

View File

@@ -198,7 +198,6 @@ class GalacticaSharded(CausalLM):
super(CausalLM, self).__init__(
model=model,
tokenizer=tokenizer,
config=config,
requires_padding=True,
dtype=dtype,
device=device,

View File

@@ -63,7 +63,6 @@ class GPTNeoxSharded(CausalLM):
super(CausalLM, self).__init__(
model=model,
tokenizer=tokenizer,
config=config,
requires_padding=True,
dtype=dtype,
device=device,

View File

@@ -8,14 +8,6 @@ from transformers import PreTrainedTokenizerBase, PretrainedConfig
from text_generation_server.models.types import Batch, GeneratedText
from text_generation_server.pb.generate_pb2 import InfoResponse
from text_generation_server.utils.gptq.quant_linear import Ex4bitLinear
from custom_kernels.exllama import prepare_buffers, set_tuning_params
from text_generation_server.utils.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear
)
B = TypeVar("B", bound=Batch)
class Model(ABC):
@@ -23,7 +15,6 @@ class Model(ABC):
self,
model: torch.nn.Module,
tokenizer: PreTrainedTokenizerBase,
config: PretrainedConfig,
requires_padding: bool,
dtype: torch.dtype,
device: torch.device,
@@ -46,47 +37,6 @@ class Model(ABC):
inspect.signature(model.forward).parameters.get("position_ids", None)
is not None
)
self.config = config
if config.quantize == "gptq":
# Buffers need to be persistent to avoid any bug.
self.buffers = {}
use_exllama_act_order = False
max_dq_buffer_size = 1
max_inner_outer_dim = 1
for name, submodule in model.named_modules():
if isinstance(submodule, (TensorParallelColumnLinear, TensorParallelRowLinear)) and isinstance(submodule.linear, Ex4bitLinear):
max_dq_buffer_size = max(max_dq_buffer_size, submodule.linear.qweight.numel() * 8)
if submodule.linear.act_order:
max_inner_outer_dim = max(max_inner_outer_dim, submodule.linear.height, submodule.linear.width)
use_exllama_act_order = True
if use_exllama_act_order:
# TODO: this should be set to rust side `max_total_tokens`, but TGI
# does not offer an API to expose this variable to python, as this variable
# is handled by the client but it appears the model is initialized by the server.
# An alternative could be to initialize the buffers during warmup.
max_total_tokens = 2048
else:
max_total_tokens = 1
# This temp_state buffer is required to reorder X in the act-order case.
self.buffers["temp_state"] = torch.zeros((max_total_tokens, max_inner_outer_dim), dtype=torch.float16, device=device)
# This temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
self.buffers["temp_dq"] = torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=device)
prepare_buffers(device, self.buffers["temp_state"], self.buffers["temp_dq"])
matmul_recons_thd = 8
matmul_fused_remap = False
matmul_no_half2 = False
set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)
torch.cuda.empty_cache()
self.check_initialized()

View File

@@ -86,7 +86,6 @@ class MPTSharded(CausalLM):
super(CausalLM, self).__init__(
model=model,
tokenizer=tokenizer,
config=config,
requires_padding=False,
dtype=dtype,
device=device,

View File

@@ -61,7 +61,6 @@ class OPTSharded(CausalLM):
super(CausalLM, self).__init__(
model=model,
tokenizer=tokenizer,
config=config,
requires_padding=True,
dtype=dtype,
device=device,

View File

@@ -58,7 +58,6 @@ class RW(CausalLM):
super(CausalLM, self).__init__(
model=model,
tokenizer=tokenizer,
config=model.config,
requires_padding=True,
dtype=dtype,
device=device,

View File

@@ -63,7 +63,6 @@ class SantaCoder(CausalLM):
super(CausalLM, self).__init__(
model=model,
tokenizer=tokenizer,
config=model.config,
requires_padding=True,
dtype=dtype,
device=device,

View File

@@ -542,7 +542,6 @@ class Seq2SeqLM(Model):
super(Seq2SeqLM, self).__init__(
model=model,
tokenizer=tokenizer,
config=model.config,
requires_padding=True,
dtype=dtype,
device=device,

View File

@@ -73,7 +73,6 @@ class T5Sharded(Seq2SeqLM):
super(Seq2SeqLM, self).__init__(
model=model,
tokenizer=tokenizer,
config=config,
requires_padding=True,
dtype=dtype,
device=device,

View File

@@ -16,6 +16,7 @@ from text_generation_server.pb import generate_pb2_grpc, generate_pb2
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
def __init__(self, model: Model, cache: Cache, server_urls: List[str]):
self.cache = cache
@@ -140,6 +141,13 @@ def serve(
logger.exception("Error when initializing model")
raise
if quantize == "gptq":
try:
from text_generation_server.utils.gptq.quant_linear import create_exllama_buffers
create_exllama_buffers()
except ImportError:
pass
server = aio.server(
interceptors=[
ExceptionInterceptor(),

View File

@@ -0,0 +1,120 @@
import torch
from custom_kernels.exllama import make_q4, q4_matmul, prepare_buffers, set_tuning_params
# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
none_tensor = torch.empty((1, 1), device = "meta")
def ext_make_q4(qweight, qzeros, scales, g_idx, device):
"""Construct Q4Matrix, return handle"""
return make_q4(qweight,
qzeros,
scales,
g_idx if g_idx is not None else none_tensor,
device)
def ext_q4_matmul(x, q4, q4_width):
"""Matrix multiplication, returns x @ q4"""
outshape = x.shape[:-1] + (q4_width,)
x = x.view(-1, x.shape[-1])
output = torch.empty((x.shape[0], q4_width), dtype = torch.float16, device = x.device)
q4_matmul(x, q4, output)
return output.view(outshape)
MAX_DQ = 1
MAX_INNER = 1
ACT_ORDER = False
DEVICE = None
TEMP_STATE = None
TEMP_DQ = None
def create_exllama_buffers():
global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE, TEMP_STATE, TEMP_DQ
if ACT_ORDER:
# TODO: this should be set to rust side `max_total_tokens`, but TGI
# does not offer an API to expose this variable to python, as this variable
# is handled by the client but it appears the model is initialized by the server.
# An alternative could be to initialize the buffers during warmup.
max_total_tokens = 2048
else:
max_total_tokens = 1
# This temp_state buffer is required to reorder X in the act-order case.
temp_state = torch.zeros((max_total_tokens, MAX_INNER), dtype=torch.float16, device=DEVICE)
temp_dq = torch.zeros((1, MAX_DQ), dtype=torch.float16, device=DEVICE)
# This temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
prepare_buffers(DEVICE, temp_state, temp_dq)
matmul_recons_thd = 8
matmul_fused_remap = False
matmul_no_half2 = False
set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)
TEMP_STATE, TEMP_DQ = temp_state, temp_dq
class Ex4bitLinear:
"""Linear layer implementation with per-group 4-bit quantization of the weights"""
def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE
assert bits == 4
self.device = qweight.device
self.qweight = qweight
self.qzeros = qzeros
self.scales = scales
self.g_idx = g_idx.cpu() if g_idx is not None else None
self.bias = bias if bias is not None else None
if self.g_idx is not None and ((self.g_idx == 0).all() or torch.equal(g_idx.cpu(), torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32))):
self.empty_g_idx = True
self.g_idx = None
assert self.device.type == "cuda"
assert self.device.index is not None
self.q4 = ext_make_q4(
self.qweight,
self.qzeros,
self.scales,
self.g_idx,
self.device.index
)
self.height = qweight.shape[0] * 8
self.width = qweight.shape[1]
# Infer groupsize from height of qzeros
self.groupsize = None
if self.qzeros.shape[0] > 1:
self.groupsize = (self.qweight.shape[0] * 8) // (self.qzeros.shape[0])
if self.groupsize is not None:
assert groupsize == self.groupsize
# Handle act-order matrix
if self.g_idx is not None:
if self.groupsize is None: raise ValueError("Found group index but no groupsize. What do?")
self.act_order = True
else:
self.act_order = False
DEVICE = self.qweight.device
MAX_DQ = max(MAX_DQ, self.qweight.numel() * 8)
if self.act_order:
MAX_INNER = max(MAX_INNER, self.height, self.width)
ACT_ORDER = True
def forward(self, x):
out = ext_q4_matmul(x, self.q4, self.width)
if self.bias is not None:
out.add_(self.bias)
return out

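To make the new module's contract concrete, here is a minimal usage sketch. It is not from the repository: it assumes the compiled custom_kernels.exllama extension and a CUDA device are available, and the sizes (4096 features, groupsize 128) and zero-initialized packed tensors are purely illustrative.

import torch
from text_generation_server.utils.gptq.exllama import Ex4bitLinear, create_exllama_buffers

in_features, out_features, groupsize = 4096, 4096, 128  # illustrative sizes

# GPTQ packs eight 4-bit values per int32 along the input dimension, which is
# why the class above computes height = qweight.shape[0] * 8.
qweight = torch.zeros((in_features // 8, out_features), dtype=torch.int32, device="cuda")
qzeros = torch.zeros((in_features // groupsize, out_features // 8), dtype=torch.int32, device="cuda")
scales = torch.ones((in_features // groupsize, out_features), dtype=torch.float16, device="cuda")

# g_idx=None means no act-order; bits must be 4 for this path.
layer = Ex4bitLinear(qweight, qzeros, scales, None, None, 4, groupsize)

# Constructing layers only records MAX_DQ / MAX_INNER / ACT_ORDER / DEVICE in the
# module globals; the temp_state / temp_dq buffers are allocated once afterwards,
# which is what server.serve() now does when quantize == "gptq".
create_exllama_buffers()

x = torch.randn(2, in_features, dtype=torch.float16, device="cuda")
y = layer.forward(x)  # x @ dequantized W, shape (2, out_features), float16

In the server these two steps are split across modules: get_linear constructs the Ex4bitLinear layers while the weights load, and serve() calls create_exllama_buffers() once the whole model sits on the GPU.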
View File

@@ -8,11 +8,6 @@ import torch
from loguru import logger
try:
from custom_kernels.exllama import make_q4, q4_matmul
except Exception as e:
logger.error(f"The CUDA kernels custom_kernels.exllama not installed, got the error: {e}")
try:
import triton
import triton.language as tl
@@ -368,76 +363,3 @@ class QuantLinear(nn.Module):
out = out + self.bias if self.bias is not None else out
return out.reshape(out_shape)
# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
none_tensor = torch.empty((1, 1), device = "meta")
def ext_make_q4(qweight, qzeros, scales, g_idx, device):
"""Construct Q4Matrix, return handle"""
return make_q4(qweight,
qzeros,
scales,
g_idx if g_idx is not None else none_tensor,
device)
def ext_q4_matmul(x, q4, q4_width):
"""Matrix multiplication, returns x @ q4"""
outshape = x.shape[:-1] + (q4_width,)
x = x.view(-1, x.shape[-1])
output = torch.empty((x.shape[0], q4_width), dtype = torch.float16, device = x.device)
q4_matmul(x, q4, output)
return output.view(outshape)
class Ex4bitLinear:
"""Linear layer implementation with per-group 4-bit quantization of the weights"""
def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
assert bits == 4
self.device = qweight.device
self.qweight = qweight
self.qzeros = qzeros
self.scales = scales
self.g_idx = g_idx.cpu() if g_idx is not None else None
self.bias = bias if bias is not None else None
if self.g_idx is not None and ((self.g_idx == 0).all() or torch.equal(g_idx.cpu(), torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32))):
self.empty_g_idx = True
self.g_idx = None
assert self.device.type == "cuda"
assert self.device.index is not None
self.q4 = ext_make_q4(
self.qweight,
self.qzeros,
self.scales,
self.g_idx,
self.device.index
)
self.height = qweight.shape[0] * 8
self.width = qweight.shape[1]
# Infer groupsize from height of qzeros
self.groupsize = None
if self.qzeros.shape[0] > 1:
self.groupsize = (self.qweight.shape[0] * 8) // (self.qzeros.shape[0])
if self.groupsize is not None:
assert groupsize == self.groupsize
# Handle act-order matrix
if self.g_idx is not None:
if self.groupsize is None: raise ValueError("Found group index but no groupsize. What do?")
self.act_order = True
else:
self.act_order = False
def forward(self, x):
out = ext_q4_matmul(x, self.q4, self.width)
if self.bias is not None:
out.add_(self.bias)
return out

View File

@@ -15,7 +15,12 @@ except ImportError:
from accelerate import init_empty_weights
from text_generation_server.utils.gptq.quant_linear import QuantLinear, Ex4bitLinear
from text_generation_server.utils.gptq.quant_linear import QuantLinear
HAS_EXLLAMA = True
try:
from text_generation_server.utils.gptq.exllama import Ex4bitLinear
except ImportError:
HAS_EXLLAMA = False
from typing import Optional
@@ -145,13 +150,15 @@ def get_linear(weight, bias, quantize):
linear.bias = nn.Parameter(bias)
elif quantize == "gptq":
try:
qweight, qzeros, scales, g_idx, bits, groupsize, use_triton_kernel = weight
qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama = weight
except Exception:
raise NotImplementedError(
f"The passed weight is not `gptq` compatible, loader needs to be updated."
)
if use_triton_kernel or bits != 4:
if use_exllama:
linear = Ex4bitLinear(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
else:
linear = QuantLinear(
qweight,
qzeros,
@@ -161,8 +168,6 @@ def get_linear(weight, bias, quantize):
bits,
groupsize,
)
else:
linear = Ex4bitLinear(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
else:
raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
return linear

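A hedged sketch of what the gptq branch of get_linear boils down to after this hunk (simplified: the non-gptq branches and error handling are left out, and the QuantLinear argument order is assumed to match the call shown above):

from text_generation_server.utils.gptq.quant_linear import QuantLinear
from text_generation_server.utils.gptq.exllama import Ex4bitLinear


def get_gptq_linear(weight, bias):
    # The weight tuple's last element is now `use_exllama`
    # (previously `use_triton_kernel`, with the opposite meaning).
    qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama = weight
    if use_exllama:
        # exllama CUDA kernels: 4-bit only, activations are reordered on the fly
        return Ex4bitLinear(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
    # triton path: scales/zero-points are reordered instead, so each rank keeps
    # the full qzeros/scales (see the weights.py hunk below)
    return QuantLinear(qweight, qzeros, scales, g_idx, bias, bits, groupsize)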
View File

@@ -2,6 +2,7 @@ from pathlib import Path
from typing import List, Dict, Optional, Tuple
from safetensors import safe_open, SafetensorError
import torch
from loguru import logger
class Weights:
@@ -127,7 +128,7 @@ class Weights:
torch.testing.assert_close(w2, w[0])
g_idx = w[0]
bits, groupsize = self.get_gptq_qparams()
bits, groupsize = self._get_gptq_qparams()
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
else:
w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
@ -136,52 +137,64 @@ class Weights:
def get_multi_weights_row(self, prefix: str, quantize: str):
if quantize == "gptq":
use_triton_kernel = False
use_exllama = True
bits, groupsize = self._get_gptq_qparams()
if bits != 4:
use_exllama = False
if self.process_group.size() > 1:
g_idx = self.get_tensor(f"{prefix}.g_idx")
_, groupsize = self.get_gptq_qparams()
if g_idx is not None:
if not torch.equal(g_idx.cpu(), torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32)) and not (g_idx == 0).all():
# Exllama implementation does not support row tensor parallelism with act-order, as
# it would require to reorder input activations that are split unto several GPUs
use_triton_kernel = True
use_exllama = False
try:
qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
except RuntimeError:
raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`")
bits, groupsize = self.get_gptq_qparams()
if use_triton_kernel:
# The triton kernel reorders the scales/zero points instead of the weight/activation.
# Thus, each rank needs the full qzeros/scales.
qzeros = self.get_tensor(f"{prefix}.qzeros")
scales = self.get_tensor(f"{prefix}.scales")
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
else:
if groupsize >= 16:
from text_generation_server.utils.layers import HAS_EXLLAMA
if use_exllama:
if not HAS_EXLLAMA:
logger.warning("Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True")
use_exllama = False
else:
logger.info("Using exllama kernels")
if use_exllama:
if groupsize >= 0:
# Exllama reorders the weights in advance and the activations on the fly, thus
# the scales and zero-points do not need to be reordered.
qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
scales = self.get_sharded(f"{prefix}.scales", dim=0)
else:
qzeros = self.get_tensor(f"{prefix}.qzeros")
scales = self.get_tensor(f"{prefix}.scales")
raise RuntimeError("Using exllama GPTQ kernel with groupsize<1 is not supported")
# qzeros = self.get_tensor(f"{prefix}.qzeros")
# scales = self.get_tensor(f"{prefix}.scales")
# For tp > 1, at this point we know we do not use act-order
if self.process_group.size() == 1:
g_idx = self.get_tensor(f"{prefix}.g_idx")
else:
g_idx = None
else:
# The triton kernel reorders the scales/zero points instead of the weight/activation.
# Thus, each rank needs the full qzeros/scales.
qzeros = self.get_tensor(f"{prefix}.qzeros")
scales = self.get_tensor(f"{prefix}.scales")
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_triton_kernel)
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
else:
weight = self.get_sharded(f"{prefix}.weight", dim=1)
return weight
def get_gptq_qparams(self) -> Tuple[int, int]:
def _get_gptq_qparams(self) -> Tuple[int, int]:
try:
bits = self.get_tensor("gptq_bits").item()
groupsize = self.get_tensor("gptq_groupsize").item()
@@ -194,4 +207,4 @@ class Weights:
except Exception:
raise e
return bits, groupsize
return bits, groupsize
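The act-order detection that both the new exllama.py and get_multi_weights_row rely on compares g_idx against the layout a plain groupwise quantization would produce. A small self-contained sketch of that test (the function name is mine, for illustration only):

import torch

def g_idx_is_trivial(g_idx: torch.Tensor, groupsize: int) -> bool:
    # A g_idx of [i // groupsize for i in range(rows)] (or all zeros) carries no
    # act-order information: exllama can drop it, and row tensor parallelism
    # stays possible because no activation reordering is needed.
    expected = torch.tensor(
        [i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32
    )
    return bool((g_idx == 0).all()) or torch.equal(g_idx.cpu(), expected)

# e.g. a purely sequential grouping is trivial, a shuffled one is not:
g_idx = torch.arange(4096, dtype=torch.int32) // 128
print(g_idx_is_trivial(g_idx, 128))                        # True
print(g_idx_is_trivial(g_idx[torch.randperm(4096)], 128))  # False (almost surely)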