Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-11 20:34:54 +00:00

Refactor layers.

parent 59b3ffea14
commit c84718c8b6
@@ -1,5 +1,5 @@
 import torch
-from text_generation_server.utils.layers import (
+from text_generation_server.layers import (
     TensorParallelEmbedding,
 )
server/text_generation_server/layers/__init__.py · 10 lines · Normal file
@@ -0,0 +1,10 @@
from text_generation_server.layers.tensor_parallel import (
    TensorParallelColumnLinear,
    TensorParallelRowLinear,
    TensorParallelEmbedding,
)
from text_generation_server.layers.speculative import SpeculativeHead
from text_generation_server.layers.linear import (
    get_linear,
    FastLinear,
)
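For downstream model code, the visible change is only the import path: everything that used to live in `text_generation_server.utils.layers` is now re-exported from `text_generation_server.layers`. A minimal usage sketch of the new-style import (not part of the diff, and it assumes the server package is installed):

import text_generation_server  # requires the TGI server package in the environment
from text_generation_server.layers import (
    TensorParallelColumnLinear,
    TensorParallelRowLinear,
    TensorParallelEmbedding,
    SpeculativeHead,
    get_linear,
    FastLinear,
)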
server/text_generation_server/layers/bnb.py · 106 lines · Normal file
@@ -0,0 +1,106 @@
import torch
from loguru import logger
from functools import lru_cache
import bitsandbytes as bnb
from bitsandbytes.nn import Int8Params, Params4bit


@lru_cache(1)
def warn_deprecate_bnb():
    logger.warning(
        "Bitsandbytes 8bit is deprecated, using `eetq` is a drop-in replacement, and has much better performance"
    )


class Linear8bitLt(torch.nn.Module):
    def __init__(
        self,
        weight,
        bias,
        has_fp16_weights=True,
        memory_efficient_backward=False,
        threshold=0.0,
        index=None,
    ):
        super().__init__()
        assert (
            not memory_efficient_backward
        ), "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
        self.state = bnb.MatmulLtState()
        self.index = index

        # Necessary for stacked layers
        self.state.threshold = threshold
        self.state.has_fp16_weights = has_fp16_weights
        self.state.memory_efficient_backward = memory_efficient_backward
        if threshold > 0.0 and not has_fp16_weights:
            self.state.use_pool = True

        self.weight = Int8Params(
            weight.data,
            has_fp16_weights=has_fp16_weights,
            requires_grad=has_fp16_weights,
        )
        self.weight.cuda(weight.device)
        self.bias = bias

    def init_8bit_state(self):
        self.state.CB = self.weight.CB
        self.state.SCB = self.weight.SCB
        self.weight.CB = None
        self.weight.SCB = None

    def forward(self, x: torch.Tensor):
        self.state.is_training = self.training
        if self.weight.CB is not None:
            self.init_8bit_state()

        # weights are cast automatically as Int8Params, but the bias has to be cast manually
        if self.bias is not None and self.bias.dtype != x.dtype:
            self.bias.data = self.bias.data.to(x.dtype)

        out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)

        if not self.state.has_fp16_weights:
            if self.state.CB is not None and self.state.CxB is not None:
                # we converted 8-bit row major to turing/ampere format in the first inference pass
                # we no longer need the row-major weight
                del self.state.CB
                self.weight.data = self.state.CxB
        return out


class Linear4bit(torch.nn.Module):
    def __init__(self, weight, bias, quant_type):
        super().__init__()
        self.weight = Params4bit(
            weight.data,
            requires_grad=False,
            compress_statistics=True,
            quant_type=quant_type,
        )
        self.compute_dtype = None
        self.weight.cuda(weight.device)
        self.bias = bias

    def forward(self, x: torch.Tensor):
        # weights are cast automatically as Params4bit, but the bias has to be cast manually
        if self.bias is not None and self.bias.dtype != x.dtype:
            self.bias.data = self.bias.data.to(x.dtype)

        if getattr(self.weight, "quant_state", None) is None:
            print(
                "FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first."
            )
        inp_dtype = x.dtype
        if self.compute_dtype is not None:
            x = x.to(self.compute_dtype)

        bias = None if self.bias is None else self.bias.to(self.compute_dtype)
        out = bnb.matmul_4bit(
            x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state
        )

        out = out.to(inp_dtype)

        return out
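To make the 8-bit path more concrete, here is a small plain-PyTorch illustration of absmax per-row int8 weight quantization, which is the core idea these layers build on; this is only a sketch, not the bitsandbytes kernel (LLM.int8() additionally keeps outlier columns in fp16 and runs the matmul in int8). All names below are hypothetical.

import torch

def absmax_quantize_rows(weight: torch.Tensor):
    # Per-row scale so that the largest magnitude in each row maps to 127.
    scale = weight.abs().amax(dim=1, keepdim=True).clamp(min=1e-12) / 127.0
    q = torch.round(weight / scale).to(torch.int8)
    return q, scale

def int8_linear_reference(x, q, scale, bias=None):
    # Dequantize-then-matmul reference; real kernels keep the matmul in int8.
    w = q.to(x.dtype) * scale.to(x.dtype)
    return torch.nn.functional.linear(x, w, bias)

w = torch.randn(16, 32)
q, s = absmax_quantize_rows(w)
x = torch.randn(4, 32)
print((int8_linear_reference(x, q, s) - x @ w.T).abs().max())  # small quantization error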
server/text_generation_server/layers/conv.py · 41 lines · Normal file
@@ -0,0 +1,41 @@
from accelerate import init_empty_weights
import torch


@classmethod
def load_conv2d(cls, prefix, weights, in_channels, out_channels, kernel_size, stride):
    weight = weights.get_tensor(f"{prefix}.weight")
    bias = weights.get_tensor(f"{prefix}.bias")
    with init_empty_weights():
        conv2d = cls(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
        )

    conv2d.weight = torch.nn.Parameter(weight)
    conv2d.bias = torch.nn.Parameter(bias)
    return conv2d


@classmethod
def load_conv2d_no_bias(
    cls, prefix, weights, in_channels, out_channels, kernel_size, stride
):
    weight = weights.get_tensor(f"{prefix}.weight")
    with init_empty_weights():
        conv2d = cls(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
        )

    conv2d.weight = torch.nn.Parameter(weight)
    conv2d.bias = None
    return conv2d


torch.nn.Conv2d.load = load_conv2d
torch.nn.Conv2d.load_no_bias = load_conv2d_no_bias
server/text_generation_server/layers/eetq.py · 25 lines · Normal file
@@ -0,0 +1,25 @@
import torch
from EETQ import quant_weights, w8_a16_gemm


class EETQLinear(torch.nn.Module):
    def __init__(
        self,
        weight,
        bias,
    ) -> None:
        super().__init__()
        device = weight.device
        if weight.dtype != torch.float16:
            weight = weight.to(dtype=torch.float16)
        weight = torch.t(weight).contiguous().cpu()
        weight, scale = quant_weights(weight, torch.int8, False)

        self.weight = weight.cuda(device)
        self.scale = scale.cuda(device)
        self.bias = bias.cuda(device) if bias is not None else None

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        output = w8_a16_gemm(input, self.weight, self.scale)
        output = output + self.bias if self.bias is not None else output
        return output
server/text_generation_server/layers/fp8.py · 43 lines · Normal file
@@ -0,0 +1,43 @@
import torch


def fp8_quantize(weight, qdtype=torch.float8_e4m3fn):
    device = weight.device
    # weight, scale = quant_weights(weight, torch.int8, False)
    finfo = torch.finfo(qdtype)
    # Calculate the scale as dtype max divided by absmax
    scale = finfo.max / weight.abs().max().clamp(min=1e-12)
    # scale and clamp the tensor to bring it to
    # the representative range of float8 data type
    # (as default cast is unsaturated)
    qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max)
    # Return both float8 data and the inverse scale (as float),
    # as both are required as inputs to torch._scaled_mm
    qweight = qweight.to(qdtype)
    scale = scale.float().reciprocal()
    return qweight, scale


class Fp8Linear(torch.nn.Module):
    def __init__(
        self,
        weight,
        bias,
    ) -> None:
        super().__init__()
        self.dtype = weight.dtype
        self.qweight, self.scale = fp8_quantize(weight)

        self.bias = bias if bias is not None else None

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        qinput, scale = fp8_quantize(input)
        output, _ = torch._scaled_mm(
            qinput,
            self.qweight.t(),
            out_dtype=self.dtype,
            scale_a=scale,
            scale_b=self.scale,
            bias=self.bias,
        )
        return output
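A minimal sketch of what `fp8_quantize` does numerically, useful for sanity-checking the scale convention; it assumes a PyTorch build with `float8_e4m3fn` support and uses hypothetical tensor sizes. The dequantized product `q * inverse_scale` should approximately reconstruct the input:

import torch

w = torch.randn(64, 64, dtype=torch.float16)
finfo = torch.finfo(torch.float8_e4m3fn)
scale = finfo.max / w.abs().max().clamp(min=1e-12)        # map absmax onto the fp8 max
q = (w * scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
inv_scale = scale.float().reciprocal()                    # the "scale" torch._scaled_mm expects
w_hat = q.to(torch.float16) * inv_scale.to(torch.float16) # dequantize for a sanity check
print((w - w_hat).abs().max())                            # small, dominated by fp8 rounding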
server/text_generation_server/layers/gptq/__init__.py · 39 lines · Normal file
@@ -0,0 +1,39 @@
import os
import torch
from text_generation_server.utils.import_utils import (
    IS_ROCM_SYSTEM,
)

try:
    major, _minor = torch.cuda.get_device_capability()
except Exception:
    major = 1

HAS_EXLLAMA = False
CAN_EXLLAMA = major >= 8 or IS_ROCM_SYSTEM
V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
if os.getenv("DISABLE_EXLLAMA") == "True":
    HAS_EXLLAMA = False
elif CAN_EXLLAMA:
    try:
        if V2:
            from text_generation_server.layers.gptq.exllamav2 import (
                QuantLinear as ExllamaQuantLinear,
                create_exllama_buffers,
                set_device,
            )

            HAS_EXLLAMA = "2"
        else:
            from text_generation_server.layers.gptq.exllama import (
                Ex4bitLinear as ExllamaQuantLinear,
                create_exllama_buffers,
                set_device,
            )

            HAS_EXLLAMA = "1"

    except ImportError:
        pass

from text_generation_server.layers.gptq.quant_linear import QuantLinear
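Kernel selection here is driven entirely by environment variables and GPU capability, so it can be steered without code changes. A hypothetical illustration of forcing the exllama v1 kernel; the variable names are the ones read above, and whether the import actually succeeds still depends on the compiled kernels and the server environment (torch, triton, text_generation_server) being installed:

import os

# Must be set before text_generation_server.layers.gptq is imported.
os.environ["EXLLAMA_VERSION"] = "1"       # default is "2"
os.environ.pop("DISABLE_EXLLAMA", None)   # "True" would disable exllama entirely

from text_generation_server.layers.gptq import CAN_EXLLAMA, HAS_EXLLAMA
print(CAN_EXLLAMA, HAS_EXLLAMA)  # e.g. True "1" on an Ampere GPU with kernels installed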
server/text_generation_server/layers/gptq/quant_linear.py · 356 lines · Normal file
@@ -0,0 +1,356 @@
import math
import numpy as np
import torch
import torch.nn as nn
from torch.cuda.amp import custom_fwd

import triton
import triton.language as tl
from . import custom_autotune


# code based on https://github.com/fpgaminer/GPTQ-triton
@custom_autotune.autotune(
    configs=[
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8},
            num_stages=2,
            num_warps=8,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8},
            num_stages=3,
            num_warps=8,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8},
            num_stages=2,
            num_warps=4,
        ),
    ],
    key=["M", "N", "K"],
    nearest_power_of_two=True,
    prune_configs_by={
        "early_config_prune": custom_autotune.matmul248_kernel_config_pruner,
        "perf_model": None,
        "top_k": None,
    },
)
@triton.jit
def matmul_248_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    scales_ptr,
    zeros_ptr,
    g_ptr,
    M,
    N,
    K,
    bits,
    maxq,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_scales,
    stride_zeros,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """
    Compute the matrix multiplication C = A x B.
    A is of shape (M, K) float16
    B is of shape (K//8, N) int32
    C is of shape (M, N) float16
    scales is of shape (G, N) float16
    zeros is of shape (G, N) float16
    g_ptr is of shape (K) int32
    """
    infeatures_per_bits = 32 // bits

    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (
        offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
    )  # (BLOCK_SIZE_M, BLOCK_SIZE_K)
    a_mask = offs_am[:, None] < M
    # b_ptrs is set up such that it repeats elements along the K axis 8 times
    b_ptrs = b_ptr + (
        (offs_k[:, None] // infeatures_per_bits) * stride_bk
        + offs_bn[None, :] * stride_bn
    )  # (BLOCK_SIZE_K, BLOCK_SIZE_N)
    g_ptrs = g_ptr + offs_k
    # shifter is used to extract the N bits of each element in the 32-bit word from B
    scales_ptrs = scales_ptr + offs_bn[None, :]
    zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infeatures_per_bits)

    shifter = (offs_k % infeatures_per_bits) * bits
    zeros_shifter = (offs_bn % infeatures_per_bits) * bits
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    for k in range(0, num_pid_k):
        g_idx = tl.load(g_ptrs)

        # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
        scales = tl.load(
            scales_ptrs + g_idx[:, None] * stride_scales
        )  # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
        zeros = tl.load(
            zeros_ptrs + g_idx[:, None] * stride_zeros
        )  # (BLOCK_SIZE_K, BLOCK_SIZE_N,)

        zeros = (zeros >> zeros_shifter[None, :]) & maxq
        zeros = (zeros + 1) & maxq  # eventually avoid overflow

        a = tl.load(a_ptrs, mask=a_mask, other=0.0)  # (BLOCK_SIZE_M, BLOCK_SIZE_K)
        b = tl.load(b_ptrs)  # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated

        # Now we need to unpack b (which is N-bit values) into 32-bit values
        b = (b >> shifter[:, None]) & maxq  # Extract the N-bit values
        b = (b - zeros) * scales  # Scale and shift

        accumulator += tl.dot(a, b)
        a_ptrs += BLOCK_SIZE_K
        b_ptrs += (BLOCK_SIZE_K // infeatures_per_bits) * stride_bk
        g_ptrs += BLOCK_SIZE_K

    c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
    c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
    tl.store(c_ptrs, accumulator, mask=c_mask)


def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
    with torch.cuda.device(input.device):
        output = torch.empty(
            (input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16
        )
        grid = lambda META: (
            triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"])
            * triton.cdiv(qweight.shape[1], META["BLOCK_SIZE_N"]),
        )
        matmul_248_kernel[grid](
            input,
            qweight,
            output,
            scales,
            qzeros,
            g_idx,
            input.shape[0],
            qweight.shape[1],
            input.shape[1],
            bits,
            maxq,
            input.stride(0),
            input.stride(1),
            qweight.stride(0),
            qweight.stride(1),
            output.stride(0),
            output.stride(1),
            scales.stride(0),
            qzeros.stride(0),
        )
        return output


class QuantLinearFunction(torch.autograd.Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)
    def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq):
        output = matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq)
        return output


class QuantLinear(nn.Module):
    def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
        super().__init__()
        self.register_buffer("qweight", qweight)
        self.register_buffer("qzeros", qzeros)
        self.register_buffer("scales", scales)
        self.register_buffer("g_idx", g_idx)
        if bias is not None:
            self.register_buffer("bias", bias)
        else:
            self.bias = None
        if bits not in [2, 4, 8]:
            raise NotImplementedError("Only 2,4,8 bits are supported.")
        self.bits = bits
        self.maxq = 2**self.bits - 1
        self.groupsize = groupsize

        self.outfeatures = qweight.shape[1]
        self.infeatures = qweight.shape[0] * 32 // bits

    @classmethod
    def new(cls, bits, groupsize, infeatures, outfeatures, bias):
        if bits not in [2, 4, 8]:
            raise NotImplementedError("Only 2,4,8 bits are supported.")

        qweight = torch.zeros((infeatures // 32 * bits, outfeatures), dtype=torch.int32)
        qzeros = torch.zeros(
            (math.ceil(infeatures / groupsize), outfeatures // 32 * bits),
            dtype=torch.int32,
        )
        scales = torch.zeros(
            (math.ceil(infeatures / groupsize), outfeatures), dtype=torch.float16
        )
        g_idx = torch.tensor(
            [i // groupsize for i in range(infeatures)], dtype=torch.int32
        )
        if bias:
            bias = torch.zeros((outfeatures), dtype=torch.float16)
        else:
            bias = None
        return cls(qweight, qzeros, scales, g_idx, bias, bits, groupsize)

    def pack(self, linear, scales, zeros, g_idx=None):
        self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx

        scales = scales.t().contiguous()
        zeros = zeros.t().contiguous()
        scale_zeros = zeros * scales
        self.scales = scales.clone().half()
        if linear.bias is not None:
            self.bias = linear.bias.clone().half()

        intweight = []
        for idx in range(self.infeatures):
            intweight.append(
                torch.round(
                    (linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]])
                    / self.scales[self.g_idx[idx]]
                ).to(torch.int)[:, None]
            )
        intweight = torch.cat(intweight, dim=1)
        intweight = intweight.t().contiguous()
        intweight = intweight.numpy().astype(np.uint32)
        qweight = np.zeros(
            (intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32
        )
        i = 0
        row = 0
        while row < qweight.shape[0]:
            if self.bits in [2, 4, 8]:
                for j in range(i, i + (32 // self.bits)):
                    qweight[row] |= intweight[j] << (self.bits * (j - i))
                i += 32 // self.bits
                row += 1
            else:
                raise NotImplementedError("Only 2,4,8 bits are supported.")

        qweight = qweight.astype(np.int32)
        self.qweight = torch.from_numpy(qweight)

        zeros -= 1
        zeros = zeros.numpy().astype(np.uint32)
        qzeros = np.zeros(
            (zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32
        )
        i = 0
        col = 0
        while col < qzeros.shape[1]:
            if self.bits in [2, 4, 8]:
                for j in range(i, i + (32 // self.bits)):
                    qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
                i += 32 // self.bits
                col += 1
            else:
                raise NotImplementedError("Only 2,4,8 bits are supported.")

        qzeros = qzeros.astype(np.int32)
        self.qzeros = torch.from_numpy(qzeros)

    def forward(self, x):
        out_shape = x.shape[:-1] + (self.outfeatures,)
        out = QuantLinearFunction.apply(
            x.reshape(-1, x.shape[-1]),
            self.qweight,
            self.scales,
            self.qzeros,
            self.g_idx,
            self.bits,
            self.maxq,
        )
        out = out + self.bias if self.bias is not None else out
        return out.reshape(out_shape)
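The `pack` method and the Triton kernel both revolve around storing `32 // bits` quantized values per int32 word. A tiny self-contained sketch of that packing and unpacking arithmetic for the 4-bit case, using hypothetical toy values, mirroring the shift/mask loops in `pack` and the `(b >> shifter) & maxq` line in the kernel:

bits = 4
vals_per_word = 32 // bits                  # 8 four-bit values per 32-bit word
maxq = 2**bits - 1
vals = [3, 15, 0, 7, 9, 1, 12, 5]           # toy quantized weights, each in [0, maxq]

word = 0
for j, v in enumerate(vals):                # same OR/shift pattern as the loops in `pack`
    word |= v << (bits * j)

unpacked = [(word >> (bits * j)) & maxq for j in range(vals_per_word)]
assert unpacked == vals
print(hex(word))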
server/text_generation_server/layers/layernorm.py · 178 lines · Normal file
@@ -0,0 +1,178 @@
import torch
from torch import nn
from accelerate import init_empty_weights
from text_generation_server.utils.import_utils import (
    IS_CUDA_SYSTEM,
    IS_ROCM_SYSTEM,
    IS_XPU_SYSTEM,
)


# Monkey patching
@classmethod
def load_layer_norm(cls, prefix, weights, eps):
    weight = weights.get_tensor(f"{prefix}.weight")
    bias = weights.get_tensor(f"{prefix}.bias")
    with init_empty_weights():
        ln = cls(weight.shape, eps=eps)

    ln.weight = torch.nn.Parameter(weight)
    ln.bias = torch.nn.Parameter(bias)
    return ln


@classmethod
def load_layer_norm_no_bias(cls, prefix, weights, eps):
    weight = weights.get_tensor(f"{prefix}.weight")
    with init_empty_weights():
        ln = cls(weight.shape, eps=eps)

    ln.weight = torch.nn.Parameter(weight)
    ln.bias = None
    return ln


torch.nn.LayerNorm.load = load_layer_norm
torch.nn.LayerNorm.load_no_bias = load_layer_norm_no_bias

if IS_CUDA_SYSTEM:
    import dropout_layer_norm
elif IS_ROCM_SYSTEM:
    from vllm import layernorm_ops
elif IS_XPU_SYSTEM:
    import intel_extension_for_pytorch as ipex
else:
    dropout_layer_norm = None


class FastLayerNorm(nn.LayerNorm):
    def forward(self, hidden_states, residual=None):
        if IS_XPU_SYSTEM:
            res_out = hidden_states
            out = ipex.llm.functional.add_layer_norm(
                residual, hidden_states, self.weight, self.bias, self.eps, True
            )
            if residual is not None:
                res_out = residual
            return out, res_out
        elif hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM:
            if residual is not None:
                hidden_states += residual
            residual = hidden_states

            return super(FastLayerNorm, self).forward(hidden_states), residual
        else:
            (
                normed_hidden_states,
                residual,
                *rest,
            ) = dropout_layer_norm.dropout_add_ln_fwd(
                hidden_states,
                residual,
                self.weight,
                self.bias,
                None,
                None,
                None,
                None,
                0.0,
                self.eps,
                1.0,
                0,
                None,
                False,
                False,
            )
            if residual is None:
                residual = hidden_states

            return normed_hidden_states, residual


class FastRMSNorm(nn.Module):
    def __init__(self, weight: torch.Tensor, eps: float):
        super().__init__()

        self.weight = nn.Parameter(weight)
        self.variance_epsilon = eps

    @classmethod
    def load(cls, prefix, weights, eps=1e-6):
        weight = weights.get_tensor(f"{prefix}.weight")
        return cls(weight, eps)

    def forward(self, hidden_states, residual=None):
        if IS_XPU_SYSTEM:
            residual_out = hidden_states
            out = ipex.llm.functional.add_rms_norm(
                residual,
                hidden_states,
                self.weight,
                None,
                self.variance_epsilon,
                True,
            )
            if residual is not None:
                residual_out = residual
            return out, residual_out
        elif hidden_states.shape[-1] > 8192:
            if residual is not None:
                hidden_states += residual
            residual = hidden_states

            hidden_states = hidden_states.to(torch.float32)
            variance = hidden_states.pow(2).mean(-1, keepdim=True)
            hidden_states = hidden_states * torch.rsqrt(
                variance + self.variance_epsilon
            )

            # convert into half-precision if necessary
            if self.weight.dtype in [torch.float16, torch.bfloat16]:
                hidden_states = hidden_states.to(self.weight.dtype)

            return self.weight * hidden_states, residual
        elif IS_CUDA_SYSTEM:
            # faster post attention rms norm
            (
                normed_hidden_states,
                res,
                *rest,
            ) = dropout_layer_norm.dropout_add_ln_fwd(
                hidden_states,
                residual,
                self.weight,
                None,
                None,
                None,
                None,
                None,
                0.0,
                self.variance_epsilon,
                1.0,
                0,
                None,
                False,
                True,  # Activate RMSNorm
            )
            if res is None:
                res = hidden_states

            return normed_hidden_states, res
        elif IS_ROCM_SYSTEM:
            # We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not.
            if residual is not None:
                hidden_states += residual
            residual = hidden_states

            out = torch.empty_like(hidden_states)
            layernorm_ops.rms_norm(
                out,
                hidden_states,
                self.weight.data,
                self.variance_epsilon,
            )
            return out, residual
        else:
            raise ValueError(
                "Your system does not seem to be supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
            )
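The large-hidden-size fallback branch above spells out the RMSNorm math in plain PyTorch. A small self-contained check of that formula with hypothetical shapes, involving no fused kernels:

import torch

def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
    # Same arithmetic as the fallback branch: normalize by the root mean square, then scale.
    x32 = x.to(torch.float32)
    variance = x32.pow(2).mean(-1, keepdim=True)
    normed = x32 * torch.rsqrt(variance + eps)
    return weight * normed.to(weight.dtype)

x = torch.randn(2, 5, 128, dtype=torch.float16)
w = torch.ones(128, dtype=torch.float16)
out = rms_norm_reference(x, w)
print(out.shape, out.dtype)  # torch.Size([2, 5, 128]) torch.float16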
server/text_generation_server/layers/linear.py · 151 lines · Normal file
@@ -0,0 +1,151 @@
import torch
from torch.nn import functional as F

from text_generation_server.utils.import_utils import IS_ROCM_SYSTEM


class FastLinear(torch.nn.Module):
    def __init__(
        self,
        weight,
        bias,
    ) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(weight)
        if bias is not None:
            self.bias = torch.nn.Parameter(bias)
        else:
            self.bias = None

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        weight = weights.get_tensor(f"{prefix}.weight")
        if bias:
            bias = weights.get_tensor(f"{prefix}.bias")
        else:
            bias = None
        return cls(weight, bias)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.linear(input, self.weight, self.bias)


def get_linear(weight, bias, quantize):
    if quantize is None:
        linear = FastLinear(weight, bias)
    elif quantize == "eetq":
        try:
            from text_generation_server.layers.eetq import EETQLinear

            linear = EETQLinear(weight, bias)
        except ImportError:
            raise ImportError(
                "Please install EETQ from https://github.com/NetEase-FuXi/EETQ"
            )
    elif quantize == "fp8":
        from text_generation_server.layers.fp8 import Fp8Linear

        linear = Fp8Linear(weight, bias)
    elif quantize == "bitsandbytes":
        try:
            from text_generation_server.layers.bnb import (
                warn_deprecate_bnb,
                Linear8bitLt,
            )
        except ImportError:
            raise NotImplementedError(
                "Bitsandbytes is missing, install it with `pip install bitsandbytes`."
            )
        warn_deprecate_bnb()
        linear = Linear8bitLt(
            weight,
            bias,
            has_fp16_weights=False,
            threshold=6.0,
        )
        if bias is not None:
            linear.bias = torch.nn.Parameter(bias)
    elif quantize == "bitsandbytes-fp4":
        try:
            from text_generation_server.layers.bnb import Linear4bit
        except ImportError:
            raise NotImplementedError(
                "Bitsandbytes is missing, install it with `pip install bitsandbytes`."
            )
        linear = Linear4bit(
            weight,
            bias,
            quant_type="fp4",
        )
    elif quantize == "bitsandbytes-nf4":
        try:
            from text_generation_server.layers.bnb import Linear4bit
        except ImportError:
            raise NotImplementedError(
                "Bitsandbytes is missing, install it with `pip install bitsandbytes`."
            )
        linear = Linear4bit(
            weight,
            bias,
            quant_type="nf4",
        )
    elif quantize == "gptq":
        try:
            qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama = weight
        except Exception:
            raise NotImplementedError(
                "The passed weight is not `gptq` compatible, loader needs to be updated."
            )

        if use_exllama:
            try:
                from text_generation_server.layers.gptq import ExllamaQuantLinear
            except ImportError:
                raise NotImplementedError(
                    "Exllama gptq kernels are not installed. Install them `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`"
                )

            linear = ExllamaQuantLinear(
                qweight, qzeros, scales, g_idx, bias, bits, groupsize
            )
        else:
            from text_generation_server.layers.gptq.quant_linear import QuantLinear

            linear = QuantLinear(
                qweight,
                qzeros,
                scales,
                g_idx,
                bias,
                bits,
                groupsize,
            )
    elif quantize == "awq":
        try:
            qweight, qzeros, scales, _, bits, groupsize, _ = weight
        except Exception:
            raise NotImplementedError(
                "The passed weight is not `awq` compatible, loader needs to be updated."
            )
        if IS_ROCM_SYSTEM:
            raise NotImplementedError(
                "AWQ GEMM kernel can't be used on ROCm systems, please use `--quantize gptq` instead "
                "to use Exllama/GPTQ kernels for AWQ inference."
            )
        try:
            from text_generation_server.utils.awq.quantize.qmodule import WQLinear

            linear = WQLinear(
                w_bit=bits,
                group_size=groupsize,
                qweight=qweight,
                qzeros=qzeros,
                scales=scales,
                bias=bias is not None,
            )
        except ImportError:
            raise NotImplementedError(
                "You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `--quantize gptq`: a conversion AWQ->GPTQ will happen on the fly"
            )
    else:
        raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
    return linear
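A hypothetical usage sketch of the `get_linear` dispatch for the unquantized path; the other branches require the corresponding optional packages and, for gptq/awq, a tuple of packed tensors as `weight`. This assumes the `text_generation_server` package is importable, and the tensor sizes are made up:

import torch
from text_generation_server.layers.linear import get_linear

weight = torch.randn(1024, 4096)          # (out_features, in_features)
bias = torch.zeros(1024)

layer = get_linear(weight, bias, quantize=None)   # plain FastLinear wrapper
x = torch.randn(2, 4096)
print(layer(x).shape)                             # torch.Size([2, 1024])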
server/text_generation_server/layers/medusa.py · 182 lines · Normal file
@@ -0,0 +1,182 @@
import torch
from torch import nn
from typing import Tuple, Optional
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.layers.linear import FastLinear
from text_generation_server.layers.tensor_parallel import (
    TensorParallelHead,
    TensorParallelColumnLinear,
)


class ResBlock(torch.nn.Module):
    def __init__(self, config, prefix, weights):
        super().__init__()
        self.linear = FastLinear.load(
            config, prefix=f"{prefix}.linear", weights=weights, bias=True
        )
        self.act = torch.nn.SiLU()

    def forward(self, x):
        return x + self.act(self.linear(x))


class MedusaModel(torch.nn.Module):
    def __init__(self, config, medusa_config, weights):
        super().__init__()
        self.heads = torch.nn.ModuleList(
            [
                MedusaHead(config, medusa_config, prefix=f"{i}", weights=weights)
                for i in range(get_speculate())
            ]
        )

    def forward(self, x):
        speculative_logits = torch.stack([head(x) for head in self.heads], dim=1)
        return speculative_logits


class MedusaHead(torch.nn.Module):
    def __init__(self, config, medusa_config, prefix, weights):
        super().__init__()
        self.blocks = torch.nn.ModuleList(
            [
                ResBlock(config, prefix=f"{prefix}.{i}", weights=weights)
                for i in range(medusa_config["medusa_num_layers"])
            ]
        )
        n = len(self.blocks)
        self.out = FastLinear.load(
            config, prefix=f"{prefix}.{n}", weights=weights, bias=False
        )

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        x = self.out(x)
        return x


class MedusaHeadV1(nn.Module):
    def __init__(self, lm_head, medusa):
        super().__init__()
        self.lm_head = lm_head
        self.medusa = medusa

    @staticmethod
    def load(config, prefix: str, weights):
        from pathlib import Path
        from safetensors import safe_open
        import json

        use_medusa = config.use_medusa

        medusa_config = str(Path(use_medusa) / "config.json")
        filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")

        with open(medusa_config, "r") as f:
            medusa_config = json.load(f)
        routing = weights.routing
        with safe_open(filename, framework="pytorch") as f:
            for k in f.keys():
                if k in routing and routing[k] != filename:
                    raise RuntimeError(
                        f"Key {k} was found in multiple files: {filename} and {routing[k]}"
                    )
                routing[k] = filename

        medusa = MedusaModel(config, medusa_config, weights)
        lm_head = TensorParallelHead.load(config, prefix, weights)
        return MedusaHeadV1(lm_head, medusa)

    def forward(
        self, input: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        logits = self.lm_head(input)
        # If we have too many tokens, we skip speculative logits
        if input.shape[0] > 128:
            return logits, None

        speculative_logits = self.medusa(input)
        return logits, speculative_logits


class MedusaHeadV2(nn.Module):
    def __init__(self, config, prefix, weights):
        super().__init__()
        from pathlib import Path
        from safetensors import safe_open
        import json

        use_medusa = config.use_medusa

        medusa_config = str(Path(use_medusa) / "config.json")
        filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")

        with open(medusa_config, "r") as f:
            medusa_config = json.load(f)
        routing = weights.routing
        with safe_open(filename, framework="pytorch") as f:
            for k in f.keys():
                if k in routing and routing[k] != filename:
                    raise RuntimeError(
                        f"Key {k} was found in multiple files: {filename} and {routing[k]}"
                    )
                routing[k] = filename

        self.n_medusa_heads = get_speculate()

        assert medusa_config["medusa_num_layers"] == 1
        self.linear = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{i}.0.linear" for i in range(self.n_medusa_heads)],
            dim=0,
            weights=weights,
            bias=True,
        )
        self.process_group = weights.process_group
        self.world_size = self.process_group.size()
        self.rank = self.process_group.rank()

        self.act = torch.nn.SiLU()

        self.lm_head = TensorParallelHead.load(config, prefix, weights)

    def forward(self, x):
        # If we have too many tokens, we skip speculative logits
        if x.shape[0] > 128:
            logits = self.lm_head(x)
            return logits, None

        size = x.shape[-1]
        block_size = (size + self.world_size - 1) // self.world_size
        start = self.rank * block_size
        stop = (self.rank + 1) * block_size

        x_block = x[:, start:stop]

        # Compute all medusa heads at the same time, then reshape and move the n_medusa_heads dim to dim 1
        medusa_res = self.act(self.linear(x)).reshape(
            *x_block.shape[:-1], self.n_medusa_heads, x_block.shape[-1]
        )

        # Apply all residual medusa heads
        output = x[:, start:stop].unsqueeze(-2) + medusa_res

        # Gather medusa heads
        world_output = [
            torch.empty_like(output) for _ in range(self.process_group.size())
        ]
        torch.distributed.all_gather(world_output, output, group=self.process_group)
        world_output = torch.cat(world_output, dim=-1)

        # Stack x and medusa residual x
        stacked_x = torch.cat([x.unsqueeze(-2), world_output], dim=-2)

        # Compute lm head on x + medusa residual x
        logits = self.lm_head(stacked_x)

        # Finally, split logits from speculative logits
        logits, speculative_logits = torch.split(
            logits, [1, self.n_medusa_heads], dim=-2
        )
        # Squeeze added dimension
        logits = logits.squeeze(-2)

        return logits, speculative_logits
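The Medusa heads are simply residual SiLU blocks whose outputs are pushed through the language-model head to propose extra tokens. A toy, single-process sketch of the shapes involved, with hypothetical dimensions and fresh `nn.Linear` layers standing in for the loaded weights:

import torch

hidden, vocab, n_heads = 64, 100, 3
lm_head = torch.nn.Linear(hidden, vocab, bias=False)
# One residual SiLU block per speculative head, as in ResBlock above.
heads = torch.nn.ModuleList(
    [torch.nn.Linear(hidden, hidden, bias=True) for _ in range(n_heads)]
)
act = torch.nn.SiLU()

x = torch.randn(5, hidden)                            # 5 token positions
logits = lm_head(x)                                   # (5, vocab) for the real next token
spec = torch.stack([lm_head(x + act(h(x))) for h in heads], dim=1)
print(logits.shape, spec.shape)                       # (5, 100) and (5, 3, 100)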
server/text_generation_server/layers/rotary.py · 421 lines · Normal file
@@ -0,0 +1,421 @@
import os
import math

import torch
from torch import nn

from text_generation_server.utils.import_utils import (
    IS_CUDA_SYSTEM,
    IS_ROCM_SYSTEM,
    IS_XPU_SYSTEM,
)

if IS_CUDA_SYSTEM:
    from flash_attn.layers.rotary import RotaryEmbedding
    import rotary_emb
elif IS_ROCM_SYSTEM:
    from vllm import pos_encoding_ops
elif IS_XPU_SYSTEM:
    import intel_extension_for_pytorch as ipex


def _create_inv_freq(dim, base, device):
    inv_freq = 1.0 / (
        base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
    )
    return inv_freq


def _get_rope_config(config):
    if os.getenv("ROPE_SCALING", None) is not None:
        rope_scaling = {
            "type": os.environ["ROPE_SCALING"],
            "factor": float(os.environ["ROPE_FACTOR"]),
        }
        return rope_scaling
    return getattr(config, "rope_scaling", None)


class PositionRotaryEmbedding(nn.Module):
    def __init__(self, inv_freq, scaling_factor):
        super().__init__()
        self.inv_freq = inv_freq
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None
        self.scaling_factor = scaling_factor
        self.dynamic_args = None

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ):
        # Such control flows may add some overhead.
        if IS_CUDA_SYSTEM:
            rotary_dim = cos.shape[-1]
            q1 = query[..., :rotary_dim]
            q2 = query[..., rotary_dim : 2 * rotary_dim]

            rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)

            k1 = key[..., :rotary_dim]
            k2 = key[..., rotary_dim : 2 * rotary_dim]

            rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
        elif IS_ROCM_SYSTEM:
            # NOTE: On RoCm systems, we use a ROPE implementation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems.
            # Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773

            head_size = query.shape[-1]

            # Inplace operation, updating query and key.
            pos_encoding_ops.rotary_embedding(query, key, head_size, cos, sin, True)
        elif IS_XPU_SYSTEM:
            ipex.llm.functional.rotary_embedding(
                query, key, sin, cos, query.size(-1), True
            )
        else:
            raise ValueError(
                "Your system does not seem to be supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
            )

    @classmethod
    def static(cls, config, dim, base, device):
        inv_freq = _create_inv_freq(dim, base, device)
        scaling_factor = None
        rope_scaling = _get_rope_config(config)
        if rope_scaling is not None:
            if rope_scaling["type"] == "linear":
                pass
            elif rope_scaling["type"] == "dynamic":
                scaling_factor = rope_scaling["factor"]
                return DynamicPositionRotaryEmbedding(
                    dim=dim,
                    max_position_embeddings=config.max_position_embeddings,
                    base=base,
                    device=inv_freq.device,
                    scaling_factor=scaling_factor,
                )
            elif rope_scaling["type"] == "yarn":
                scaling_factor = rope_scaling["factor"]
                return YarnPositionRotaryEmbedding(
                    dim=2 * inv_freq.shape[0],
                    max_position_embeddings=rope_scaling[
                        "original_max_position_embeddings"
                    ],
                    base=10000.0,
                    device=inv_freq.device,
                    scaling_factor=scaling_factor,
                    extrapolation_factor=1,
                    attn_factor=1,
                    beta_fast=32,
                    beta_slow=1,
                )
            elif rope_scaling["type"] == "su":
                short_factor = torch.tensor(
                    rope_scaling["short_factor"], dtype=torch.float32, device=device
                )
                short_inv_freq = 1.0 / (
                    short_factor
                    * base
                    ** (
                        torch.arange(0, dim, 2, device=device, dtype=torch.float32)
                        / dim
                    )
                )
                long_factor = torch.tensor(
                    rope_scaling["long_factor"], dtype=torch.float32, device=device
                )
                long_inv_freq = 1.0 / (
                    long_factor
                    * base
                    ** (
                        torch.arange(0, dim, 2, device=device, dtype=torch.float32)
                        / dim
                    )
                )

                original_max_position_embeddings = (
                    config.original_max_position_embeddings
                )
                max_position_embeddings = config.max_position_embeddings
                if max_position_embeddings <= original_max_position_embeddings:
                    scaling_factor = 1.0
                else:
                    scale = max_position_embeddings / original_max_position_embeddings
                    scaling_factor = math.sqrt(
                        1 + math.log(scale) / math.log(original_max_position_embeddings)
                    )

                return SuRotaryEmbedding(
                    short_inv_freq=short_inv_freq,
                    long_inv_freq=long_inv_freq,
                    scaling_factor=scaling_factor,
                    original_max_position_embeddings=original_max_position_embeddings,
                )
            else:
                raise NotImplementedError(
                    f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
                )
        return cls(inv_freq, scaling_factor)

    @classmethod
    def load(cls, config, prefix, weights):
        # XXX: Always load this in float32 !
        dtype = weights.dtype
        weights.dtype = torch.float32
        inv_freq = weights.get_tensor(f"{prefix}.inv_freq")
        weights.dtype = dtype

        scaling_factor = None
        rope_scaling = _get_rope_config(config)
        if rope_scaling is not None:
            scaling_factor = rope_scaling["factor"]
            if rope_scaling["type"] == "linear":
                pass
            elif rope_scaling["type"] == "dynamic":
                return DynamicPositionRotaryEmbedding(
                    dim=2 * inv_freq.shape[0],
                    max_position_embeddings=config.max_position_embeddings,
                    base=10000.0,
                    device=inv_freq.device,
                    scaling_factor=scaling_factor,
                )
            elif rope_scaling["type"] == "yarn":
                return YarnPositionRotaryEmbedding(
                    dim=2 * inv_freq.shape[0],
                    max_position_embeddings=rope_scaling[
                        "original_max_position_embeddings"
                    ],
                    base=10000.0,
                    device=inv_freq.device,
                    scaling_factor=scaling_factor,
                    extrapolation_factor=1,
                    attn_factor=1,
                    beta_fast=32,
                    beta_slow=1,
                )
            else:
                raise NotImplementedError(
                    f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
                )
        return cls(inv_freq, scaling_factor)

    def _update_cos_sin_cache(self, dtype, device, seqlen):
        # Reset the tables if the sequence length has changed,
        # or if we're on a new device (possibly due to tracing for instance)
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
        ):
            self._seq_len_cached = seqlen
            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
            if self.scaling_factor is not None:
                t /= self.scaling_factor
            # Don't do einsum, it converts fp32 to fp16
            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)

            freqs = torch.outer(t, self.inv_freq.to(device=t.device))
            self._cos_cached = torch.cos(freqs).to(dtype)
            self._sin_cached = torch.sin(freqs).to(dtype)

    def get_cos_sin(self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype):
        """
        Return cos and sin for the asked position ids
        """
        if IS_ROCM_SYSTEM:
            # For RoCm, we always use float cos/sin to avoid a cast.
            # For NVIDIA, for some reason, the flash-attn rotary kernel requires cos/sin and query/key to be of same dtype: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary.cpp#L26
            # But later on goes and cast cos/sin to float anyway: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary_cuda.cu#L29, which looks suboptimal.
            dtype = torch.float32

        self._update_cos_sin_cache(dtype, position_ids.device, max_s)

        cos = torch.index_select(self._cos_cached, 0, position_ids)
        sin = torch.index_select(self._sin_cached, 0, position_ids)

        # Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet another control flow.
        return cos.unsqueeze(1), sin.unsqueeze(1)


class SuRotaryEmbedding(PositionRotaryEmbedding):
    def __init__(
        self,
        short_inv_freq,
        long_inv_freq,
        scaling_factor,
        original_max_position_embeddings,
    ):
        super(PositionRotaryEmbedding, self).__init__()
        self.short_inv_freq = short_inv_freq
        self.long_inv_freq = long_inv_freq
        self.scaling_factor = scaling_factor
        self.original_max_position_embeddings = original_max_position_embeddings
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None
        self.dynamic_args = None

    def _update_cos_sin_cache(self, dtype, device, seqlen):
        # Reset the tables if the sequence length has changed,
        # or if we're on a new device (possibly due to tracing for instance)
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
        ):
            self._seq_len_cached = seqlen
            if seqlen > self.original_max_position_embeddings:
                inv_freq = self.long_inv_freq
            else:
                inv_freq = self.short_inv_freq
            t = torch.arange(seqlen, device=device, dtype=inv_freq.dtype)
            if self.scaling_factor is not None:
                t /= self.scaling_factor
            # Don't do einsum, it converts fp32 to fp16
            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)

            freqs = torch.outer(t, inv_freq.to(device=t.device))
            self._cos_cached = torch.cos(freqs).to(dtype)
            self._sin_cached = torch.sin(freqs).to(dtype)


class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
    def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
        inv_freq = _create_inv_freq(dim, base, device)
        super().__init__(inv_freq, scaling_factor)
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base

    def _update_cos_sin_cache(self, dtype, device, seqlen):
        # Reset the tables if the sequence length has changed,
        # or if we're on a new device (possibly due to tracing for instance)
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
        ):
            if seqlen > self.max_position_embeddings:
                newbase = self.base * (
                    (self.scaling_factor * seqlen / self.max_position_embeddings)
                    - (self.scaling_factor - 1)
                ) ** (self.dim / (self.dim - 2))
                self.inv_freq = _create_inv_freq(
                    self.dim, newbase, self.inv_freq.device
                )
            self._seq_len_cached = seqlen
            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
            # Don't do einsum, it converts fp32 to fp16
            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)

            freqs = torch.outer(t, self.inv_freq.to(device=t.device))
            self._cos_cached = torch.cos(freqs).to(dtype)
            self._sin_cached = torch.sin(freqs).to(dtype)


# Inverse dim formula to find dim based on number of rotations
def find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
    return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
        2 * math.log(base)
    )


# Find dim range bounds based on rotations
def find_correction_range(
    low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
):
    low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings))
    high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
    return max(low, 0), min(high, dim - 1)  # Clamp values just in case


def linear_ramp_mask(min, max, dim):
    if min == max:
        max += 0.001  # Prevent singularity

    linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
    ramp_func = torch.clamp(linear_func, 0, 1)
    return ramp_func


def get_mscale(scale=1):
    if scale <= 1:
        return 1.0
    return 0.1 * math.log(scale) + 1.0


class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
    def __init__(
        self,
        dim,
        max_position_embeddings,
        base,
        device,
        scaling_factor,
        *,
        extrapolation_factor,
        attn_factor,
        beta_fast,
        beta_slow,
    ):
        inv_freq = _create_inv_freq(dim, base, device)
        super().__init__(inv_freq, scaling_factor)
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.extrapolation_factor = extrapolation_factor
        self.attn_factor = attn_factor
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        self.mscale = float(
            get_mscale(self.scaling_factor) * self.attn_factor
        )  # Get n-d magnitude scaling corrected for interpolation

    def _update_cos_sin_cache(self, dtype, device, seqlen):
        # Reset the tables if the sequence length has changed,
        # or if we're on a new device (possibly due to tracing for instance)
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
        ):
            if seqlen > self.max_position_embeddings:
                inv_freq_extrapolation = _create_inv_freq(
                    self.dim, self.base, self.inv_freq.device
                )
                freqs = 1.0 / inv_freq_extrapolation
                inv_freq_interpolation = 1.0 / (self.scaling_factor * freqs)
                low, high = find_correction_range(
                    self.beta_fast,
                    self.beta_slow,
                    self.dim,
                    self.base,
                    self.max_position_embeddings,
                )
                inv_freq_mask = (
                    1 - linear_ramp_mask(low, high, self.dim // 2).float().to(device)
                ) * self.extrapolation_factor  # Get n-d rotational scaling corrected for extrapolation
                inv_freq = (
                    inv_freq_interpolation * (1 - inv_freq_mask)
                    + inv_freq_extrapolation * inv_freq_mask
                )

                self.inv_freq = inv_freq
                self.mscale = float(
                    get_mscale(self.scaling_factor) * self.attn_factor
                )  # Get n-d magnitude scaling corrected for interpolation

            self._seq_len_cached = seqlen
            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
            # Don't do einsum, it converts fp32 to fp16
            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)

            freqs = torch.outer(t, self.inv_freq.to(device=t.device))
            self._cos_cached = (torch.cos(freqs) * self.mscale).to(dtype)
            self._sin_cached = (torch.sin(freqs) * self.mscale).to(dtype)
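For readers without the CUDA/ROCm/XPU kernels, here is a plain-PyTorch sketch of what the fused `apply_rotary` call does to the first `rotary_dim` pairs of a query tensor. Shapes are hypothetical, and the real kernels work in place on query and key together; this is a reference of the math, not the kernel:

import torch

def apply_rotary_reference(q, cos, sin):
    # q: (num_tokens, num_heads, head_size); cos/sin: (num_tokens, 1, rotary_dim)
    rotary_dim = cos.shape[-1]
    q1 = q[..., :rotary_dim]
    q2 = q[..., rotary_dim : 2 * rotary_dim]
    out = q.clone()
    out[..., :rotary_dim] = q1 * cos - q2 * sin
    out[..., rotary_dim : 2 * rotary_dim] = q2 * cos + q1 * sin
    return out

dim, base, seqlen = 64, 10000.0, 8
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
t = torch.arange(seqlen, dtype=torch.float32)
freqs = torch.outer(t, inv_freq)                  # same cache as _update_cos_sin_cache
cos, sin = torch.cos(freqs).unsqueeze(1), torch.sin(freqs).unsqueeze(1)

q = torch.randn(seqlen, 4, dim)                   # 4 attention heads
print(apply_rotary_reference(q, cos, sin).shape)  # torch.Size([8, 4, 64])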
server/text_generation_server/layers/speculative.py · 35 lines · Normal file
@@ -0,0 +1,35 @@
import torch
from typing import Tuple, Optional
from text_generation_server.layers.medusa import MedusaHeadV1, MedusaHeadV2
from text_generation_server.layers.tensor_parallel import TensorParallelHead


class SpeculativeHead(torch.nn.Module):
    def __init__(self, lm_head, medusa):
        super().__init__()
        self.head = lm_head
        self.medusa = medusa

    @staticmethod
    def load(config, prefix: str, weights):
        use_medusa = config.use_medusa
        if use_medusa:
            lm_head = None
            try:
                medusa = MedusaHeadV1.load(config, prefix, weights)
            except Exception:
                medusa = MedusaHeadV2(config, prefix, weights)
        else:
            lm_head = TensorParallelHead.load(config, prefix, weights)
            medusa = None
        return SpeculativeHead(lm_head, medusa)

    def forward(
        self, input: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        if self.medusa is not None:
            return self.medusa(input)

        assert self.head is not None
        logits = self.head(input)
        return logits, None
188
server/text_generation_server/layers/tensor_parallel.py
Normal file
188
server/text_generation_server/layers/tensor_parallel.py
Normal file
@ -0,0 +1,188 @@
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
from typing import List
|
||||
from text_generation_server.layers.linear import get_linear
|
||||
|
||||
|
||||
class SuperLayer(torch.nn.Module):
    def __init__(self, linear):
        super().__init__()
        self.linear = linear

    def forward(self, x):
        return self.linear.forward(x)


class TensorParallelHead(SuperLayer):
    def __init__(self, linear, process_group, should_gather: bool):
        super().__init__(linear)
        self.process_group = process_group
        self.should_gather = should_gather

    @staticmethod
    def load(config, prefix: str, weights):
        if weights.process_group.size() > 1:
            try:
                weight = weights.get_sharded(f"{prefix}.weight", dim=0)
                should_gather = True
            except AssertionError:
                # If the vocab size is not divisible by number of shards
                # just load the entire thing.
                weight = weights.get_tensor(f"{prefix}.weight")
                should_gather = False
        else:
            weight = weights.get_tensor(f"{prefix}.weight")
            should_gather = False

        # GPTQ,AWQ,EETQ don't quantize heads (nor embeddings)
        if config.quantize in ["gptq", "awq", "eetq"]:
            quantize = None
        else:
            quantize = config.quantize
        return TensorParallelHead(
            get_linear(weight, bias=None, quantize=quantize),
            process_group=weights.process_group,
            should_gather=should_gather,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if not self.should_gather:
            return super().forward(input)

        world_size = self.process_group.size()
        if len(input.shape) == 2 and isinstance(self.linear, FastLinear):
            out_dim = self.linear.weight.shape[0]

            if input.shape[0] == 1:
                world_out = input.new_empty(1, out_dim * world_size)
                local_out = input.new_empty(1, out_dim)
                gather_input = local_out
            else:
                world_out = input.new_empty(out_dim * world_size, input.shape[0])
                gather_input = input.new_empty(out_dim, input.shape[0])
                local_out = gather_input.T

            torch.mm(input, self.linear.weight.T, out=local_out)

            torch.distributed.all_gather_into_tensor(
                world_out, gather_input, group=self.process_group
            )

            if input.shape[0] == 1:
                return world_out
            return world_out.T

        output = super().forward(input)
        world_output = [
            torch.empty_like(output) for _ in range(self.process_group.size())
        ]
        torch.distributed.all_gather(world_output, output, group=self.process_group)
        world_output = torch.cat(world_output, dim=-1)
        return world_output


class TensorParallelColumnLinear(SuperLayer):
    @classmethod
    def load_gate_up(cls, config, prefix: str, weights, bias: bool):
        """Specific method when the gate and up projections were joined after the fact"""
        weight = weights.get_weights_col_packed_gate_up(
            prefix, quantize=config.quantize
        )
        if bias:
            raise NotImplementedError("packed_gate_up only implemented without bias")
        else:
            bias = None
        linear = get_linear(weight, bias, config.quantize)
        return cls(linear)

    @classmethod
    def load_qkv(cls, config, prefix: str, weights, bias: bool):
        """Specific method when the QKV was joined after the fact"""
        weight = weights.get_weights_col_packed_qkv(prefix, quantize=config.quantize)
        if bias:
            raise NotImplementedError("packed_qkv only implemented for baichuan")
        else:
            bias = None
        linear = get_linear(weight, bias, config.quantize)
        return cls(linear)

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        return cls.load_multi(config, [prefix], weights, bias, dim=0)

    @classmethod
    def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
        weight = weights.get_multi_weights_col(
            prefixes, quantize=config.quantize, dim=dim
        )

        if bias:
            b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes]
            bias = torch.cat(b, dim=dim)
        else:
            bias = None
        linear = get_linear(weight, bias, config.quantize)
        return cls(linear)


class TensorParallelRowLinear(SuperLayer):
    def __init__(self, linear, process_group):
        super().__init__(linear)
        self.process_group = process_group

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)

        if bias and weights.process_group.rank() == 0:
            # The bias is only materialized on the first rank, so the
            # all_reduce in forward adds it exactly once.
            bias = weights.get_tensor(f"{prefix}.bias")
        else:
            bias = None
        return cls(
            get_linear(weight, bias, config.quantize),
            process_group=weights.process_group,
        )

    def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor:
        out = super().forward(input)
        if self.process_group.size() > 1 and reduce:
            torch.distributed.all_reduce(out, group=self.process_group)
        return out


class TensorParallelEmbedding(torch.nn.Module):
    def __init__(self, prefix: str, weights, reduce=True):
        super().__init__()
        weight = weights.get_partial_sharded(f"{prefix}.weight", dim=0)
        num_embeddings = weights.get_shape(f"{prefix}.weight")[0]

        process_group = weights.process_group

        world_size = process_group.size()
        rank = process_group.rank()

        block_size = (num_embeddings + world_size - 1) // world_size
        self.min_id = rank * block_size
        self.max_id = min(num_embeddings, (rank + 1) * block_size)
        self.null_idx = weight.shape[
            0
        ]  # Usually block_size, might be less in non even vocab_size.
        self.process_group = weights.process_group
        self.reduce = reduce

        """Additional 0 entry used for masking"""
        self.weight = torch.nn.Parameter(F.pad(weight, (0, 0, 0, 1)))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # default all out of bounds values to `self.null_idx` that will then be mapped to 0
        # translate for [0, self.max_id - self.min_id[
        input = torch.where(
            (self.min_id > input) | (input >= self.max_id),
            self.null_idx,
            input - self.min_id,
        )
        out = torch.nn.functional.embedding(input, self.weight)
        if self.reduce and self.process_group.size() > 1:
            torch.distributed.all_reduce(out, group=self.process_group)
        return out
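As a reading aid for TensorParallelEmbedding above, here is a standalone sketch (no process group required) of the vocabulary partitioning: each rank owns a contiguous block of ids, all other ids are remapped to a padded null row, and an all_reduce across ranks would restore the full table. The sizes below are illustrative assumptions.

import torch
import torch.nn.functional as F

num_embeddings, world_size, dim = 10, 3, 4
block_size = (num_embeddings + world_size - 1) // world_size  # 4 ids per rank

for rank in range(world_size):
    min_id = rank * block_size
    max_id = min(num_embeddings, (rank + 1) * block_size)
    shard = torch.randn(max_id - min_id, dim)      # this rank's slice of the table
    null_idx = shard.shape[0]                      # index of the extra masking row
    weight = F.pad(shard, (0, 0, 0, 1))            # append one all-zero row
    input_ids = torch.arange(num_embeddings)
    local = torch.where(
        (input_ids < min_id) | (input_ids >= max_id), null_idx, input_ids - min_id
    )
    out = F.embedding(local, weight)               # ids outside the shard come back as zeros
    assert torch.all(out[:min_id] == 0)
    assert torch.all(out[max_id:] == 0)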
@@ -32,7 +32,7 @@ from transformers.modeling_outputs import (
)
from transformers import BloomConfig, PreTrainedModel

from text_generation_server.utils.layers import (
from text_generation_server.layers import (
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
    TensorParallelRowLinear,

@@ -15,7 +15,7 @@ from transformers.modeling_outputs import (
)
from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig

from text_generation_server.utils.layers import (
from text_generation_server.layers import (
    TensorParallelEmbedding,
    TensorParallelColumnLinear,
    TensorParallelRowLinear,

@@ -27,7 +27,7 @@ from typing import Optional, List, Tuple

from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.import_utils import IS_ROCM_SYSTEM, IS_CUDA_SYSTEM
from text_generation_server.utils.layers import (
from text_generation_server.layers import (
    TensorParallelRowLinear,
    TensorParallelColumnLinear,
    TensorParallelEmbedding,

@@ -26,7 +26,7 @@ from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
if not IS_XPU_SYSTEM:
    from vllm.model_executor.layers.fused_moe import fused_moe
from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.layers import (
from text_generation_server.layers import (
    FastLinear,
    FastLayerNorm,
    TensorParallelRowLinear,

@@ -216,7 +216,7 @@ def _load_gqa(config, prefix: str, weights):

    bits, groupsize, desc_act, quant_method = weights._get_gptq_params()

    from text_generation_server.utils.layers import HAS_EXLLAMA
    from text_generation_server.layers import HAS_EXLLAMA

    use_exllama = (
        bits == 4 and HAS_EXLLAMA and config.quantize == "gptq" and not desc_act

@@ -236,7 +236,7 @@ def _load_gqa(config, prefix: str, weights):
        log_once(
            logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
        )
        from text_generation_server.utils.awq.conversion_utils import (
        from text_generation_server.layers.awq import (
            fast_awq_to_gptq,
        )

@@ -27,7 +27,7 @@ from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple

from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.layers import (
from text_generation_server.layers import (
    TensorParallelRowLinear,
    TensorParallelColumnLinear,
    TensorParallelEmbedding,

@@ -27,7 +27,7 @@ from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple

from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.layers import (
from text_generation_server.layers import (
    TensorParallelRowLinear,
    TensorParallelColumnLinear,
    TensorParallelEmbedding,

@@ -27,13 +27,15 @@ from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple

from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.layers import (
from text_generation_server.layers import (
    TensorParallelRowLinear,
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
    PositionRotaryEmbedding,
    SpeculativeHead,
    get_linear,
)
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
    FastRMSNorm,
)

@@ -34,7 +34,7 @@ from typing import Optional, List, Tuple
from loguru import logger

from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.layers import (
from text_generation_server.layers import (
    FastLinear,
    FastRMSNorm,
    TensorParallelRowLinear,

@@ -29,7 +29,7 @@ from typing import Optional, List, Tuple

from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.flash_attn import attention
from text_generation_server.utils.layers import (
from text_generation_server.layers import (
    TensorParallelRowLinear,
    TensorParallelColumnLinear,
    TensorParallelEmbedding,

@@ -7,7 +7,7 @@ from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple

from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.layers import (
from text_generation_server.layers import (
    TensorParallelRowLinear,
    TensorParallelColumnLinear,
    TensorParallelEmbedding,

@@ -6,7 +6,7 @@ from transformers.activations import ACT2FN
from typing import Optional, List, Tuple

from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.layers import (
from text_generation_server.layers import (
    TensorParallelRowLinear,
    TensorParallelColumnLinear,
    TensorParallelEmbedding,

@@ -8,7 +8,7 @@ from typing import Optional, List, Tuple

from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.flash_attn import attention
from text_generation_server.utils.layers import (
from text_generation_server.layers import (
    TensorParallelRowLinear,
    TensorParallelColumnLinear,
    TensorParallelEmbedding,

@@ -6,7 +6,7 @@ from transformers.activations import ACT2FN
from typing import Optional, List, Tuple

from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.layers import (
from text_generation_server.layers import (
    TensorParallelRowLinear,
    TensorParallelColumnLinear,
    SpeculativeHead,

@@ -80,13 +80,13 @@ def _load_multi_mqa_gptq(
        g_idx = g_idx.to(device=weights.device)
    elif quant_method == "awq":
        g_idx = None
        from text_generation_server.utils.awq.conversion_utils import (
        from text_generation_server.layers.awq import (
            fast_awq_to_gptq,
        )

        qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)

    from text_generation_server.utils.layers import HAS_EXLLAMA
    from text_generation_server.layers.gptq import HAS_EXLLAMA

    use_exllama = HAS_EXLLAMA
    weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)

@@ -27,7 +27,7 @@ from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple

from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.layers import (
from text_generation_server.layers import (
    TensorParallelRowLinear,
    TensorParallelColumnLinear,
    TensorParallelEmbedding,

@@ -29,7 +29,7 @@ from text_generation_server.models.custom_modeling.vlm import (
)
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask

from text_generation_server.utils.layers import (
from text_generation_server.layers import (
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
    TensorParallelRowLinear,

@@ -47,7 +47,7 @@ from text_generation_server.models.custom_modeling.idefics_vision import (
from text_generation_server.models.custom_modeling.idefics_perceiver import (
    IdeficsPerceiverResampler,
)
from text_generation_server.utils.layers import (
from text_generation_server.layers import (
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
    TensorParallelRowLinear,

@@ -41,7 +41,7 @@ from typing import Optional, Tuple
import torch
import torch.nn as nn

from text_generation_server.utils.layers import (
from text_generation_server.layers import (
    TensorParallelColumnLinear,
    TensorParallelRowLinear,
)

@@ -28,7 +28,7 @@ from transformers.utils import (
    ModelOutput,
    logging,
)
from text_generation_server.utils.layers import (
from text_generation_server.layers import (
    TensorParallelColumnLinear,
    TensorParallelRowLinear,
    TensorParallelEmbedding,

@@ -27,7 +27,7 @@ from text_generation_server.models.custom_modeling.vlm import (
    load_text_model,
    load_vision_model,
)
from text_generation_server.utils.layers import (
from text_generation_server.layers import (
    TensorParallelColumnLinear,
    TensorParallelRowLinear,
)

@@ -8,7 +8,7 @@ from typing import Optional, Tuple, Any
from transformers.configuration_utils import PretrainedConfig
import torch.nn.functional as F

from text_generation_server.utils.layers import (
from text_generation_server.layers import (
    SpeculativeHead,
    TensorParallelEmbedding,
    FastRMSNorm,

@@ -17,7 +17,7 @@ from transformers.modeling_outputs import (
)
from einops import rearrange
from packaging import version
from text_generation_server.utils.layers import (
from text_generation_server.layers import (
    TensorParallelEmbedding,
    TensorParallelColumnLinear,
    TensorParallelRowLinear,

@@ -40,7 +40,7 @@ from transformers.modeling_outputs import (
from transformers.modeling_utils import PreTrainedModel
from transformers import GPTNeoXConfig
from loguru import logger
from text_generation_server.utils.layers import (
from text_generation_server.layers import (
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
    TensorParallelRowLinear,

@@ -27,7 +27,7 @@ from transformers.modeling_outputs import (
)
from transformers.modeling_utils import PreTrainedModel
from transformers import OPTConfig
from text_generation_server.utils.layers import (
from text_generation_server.layers import (
    FastLinear,
    TensorParallelColumnLinear,
    TensorParallelEmbedding,

@@ -9,7 +9,7 @@ from typing import Optional, List, Tuple, Any
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import CausalLMOutputWithPast

from text_generation_server.utils.layers import (
from text_generation_server.layers import (
    TensorParallelRowLinear,
    TensorParallelColumnLinear,
    TensorParallelEmbedding,

@@ -38,7 +38,7 @@ from transformers.utils import (
    is_torch_fx_proxy,
)
from transformers import T5Config
from text_generation_server.utils.layers import (
from text_generation_server.layers import (
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
    TensorParallelRowLinear,

@@ -85,7 +85,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
            # When using GPTQ, Exllama kernels need some global kernels
            # For which we have the finale shapes only after the model has loaded
            # This will allocate those buffers.
            from text_generation_server.utils.layers import (
            from text_generation_server.layers import (
                create_exllama_buffers,
                set_device,
            )
@@ -1,97 +0,0 @@
import torch
from typing import List


AWQ_PACK_ORDER = [0, 2, 4, 6, 1, 3, 5, 7]
REVERSE_AWQ_PACK_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]


def pack(imatrix: torch.Tensor, direction: str = "column"):
    """
    Packs a 4-bit integer matrix into a packed 32-bit integer matrix.
    Args:
        imatrix (torch.Tensor): matrix of integers
        direction (str): direction of packing, either "column" or "row"
    Returns:
        qmatrix (torch.Tensor): packed matrix of integers
    """
    shifts = torch.arange(0, 32, 4, dtype=torch.int32, device=imatrix.device)

    imatrix = imatrix.to(torch.int8) & 0x0F  # eventually correct overflow

    if direction == "column":
        imatrix = imatrix.view(-1, imatrix.shape[1] // (32 // 4), (32 // 4))
        qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, None, :]).sum(dim=-1)

    elif direction == "row":
        imatrix = imatrix.view(imatrix.shape[0] // (32 // 4), (32 // 4), -1)
        qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, :, None]).sum(dim=1)

    qmatrix = qmatrix.to(torch.int32)

    return qmatrix


def unpack(qmatrix: torch.Tensor, direction: str = "column"):
    """
    Unpacks a 32-bit packed integer matrix into a 4-bit integer matrix.
    Args:
        qmatrix (torch.Tensor): matrix of packed integers
        direction (str): direction of unpacking, either "column" or "row"
    Returns:
        imatrix (torch.Tensor): matrix of integers
    """
    shifts = torch.arange(0, 32, 4, device=qmatrix.device)

    if direction == "column":
        imatrix = torch.bitwise_right_shift(
            qmatrix[:, :, None], shifts[None, None, :]
        ).view(qmatrix.shape[0], -1)

    elif direction == "row":
        imatrix = torch.bitwise_right_shift(
            qmatrix[:, None, :], shifts[None, :, None]
        ).view(-1, qmatrix.shape[-1])

    imatrix = imatrix.to(torch.int8) & 0x0F  # eventually correct overflow

    return imatrix


def apply_order(
    imatrix: torch.Tensor,
    direction: str = "column",
    order: List[int] = AWQ_PACK_ORDER,
):
    """
    Applies the order to a 4-bit integer matrix.
    Args:
        imatrix (torch.Tensor): matrix of integers
        direction (str): direction of applying order, either "column" or "row"
        order (List[int]): order to apply, default is AWQ_PACK_ORDER
    Returns:
        imatrix (torch.Tensor): matrix of integers
    """
    if direction == "column":
        imatrix = imatrix.view(-1, (32 // 4))[:, order].view(imatrix.shape)
    elif direction == "row":
        imatrix = imatrix.view((32 // 4), -1)[order, :].view(imatrix.shape)

    return imatrix


def fast_awq_to_gptq(qweight, qzeros):
    # awq uses column packing for both weights and zeros
    izeros = unpack(qzeros, direction="column")
    iweights = unpack(qweight, direction="column")

    # Reverse the order of the iweight and izeros tensors
    izeros = apply_order(izeros, direction="column", order=REVERSE_AWQ_PACK_ORDER)
    iweights = apply_order(iweights, direction="column", order=REVERSE_AWQ_PACK_ORDER)
    # Subtract 1 from the izeros tensor (gptq adds 1 to the zeros)
    izeros = izeros - 1
    # exllama uses row packing for weights and column packing for zeros
    qzeros = pack(izeros, direction="column")
    qweight = pack(iweights, direction="row")

    return qweight, qzeros
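A hedged round-trip sketch for the pack/unpack helpers deleted above: after this refactor they are imported from text_generation_server.layers.awq (the new home used by the fast_awq_to_gptq imports elsewhere in this diff); whether both helpers are re-exported from that module is an assumption. The 4x16 shape is illustrative; the packed dimension just needs to be a multiple of 8, since eight 4-bit values share one int32.

import torch
from text_generation_server.layers.awq import pack, unpack  # assumed re-exports after the move

vals = torch.randint(0, 16, (4, 16), dtype=torch.int32)   # 4-bit values in [0, 15]
packed = pack(vals, direction="column")                   # (4, 2) int32, 8 nibbles per word
restored = unpack(packed, direction="column")             # (4, 16) int8 values in [0, 15]
assert torch.equal(restored.to(torch.int32), vals)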
@@ -1,50 +0,0 @@
# Copied logic from https://github.com/mit-han-lab/llm-awq/blob/f084f40bd996f3cf3a0633c1ad7d9d476c318aaa/awq/quantize/qmodule.py

import math
import torch
import torch.nn as nn
import awq_inference_engine  # with CUDA kernels


# class ScaledActivation(nn.Module):
#     def __init__(self, module, scales):
#         super().__init__()
#         self.act = module
#         self.scales = nn.Parameter(scales.data)
#
#     def forward(self, x):
#         return self.act(x) / self.scales.view(1, 1, -1).to(x.device)


class WQLinear(nn.Module):
    def __init__(self, w_bit, group_size, qweight, qzeros, scales, bias):
        super().__init__()

        if w_bit not in [4]:
            raise NotImplementedError("Only 4-bit are supported for now.")

        self.in_features = qweight.shape[0]
        self.out_features = qweight.shape[1] * 32 // w_bit

        self.w_bit = w_bit
        self.group_size = group_size if group_size != -1 else self.in_features
        # quick sanity check (make sure aligment)
        assert self.in_features % self.group_size == 0
        assert self.out_features % (32 // self.w_bit) == 0

        self.qweight = qweight
        self.qzeros = qzeros
        self.scales = scales
        if bias:
            self.bias = bias
        else:
            self.bias = None

    @torch.no_grad()
    def forward(self, x):
        out_shape = x.shape[:-1] + (self.out_features,)
        out = awq_inference_engine.gemm_forward_cuda(
            x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8
        )
        out = out + self.bias if self.bias is not None else out
        return out.reshape(out_shape)
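A small arithmetic sketch of the shape bookkeeping in the (CUDA-only) WQLinear wrapper deleted above: with 4-bit weights, every int32 column of qweight carries eight output features, and the assertions mirror the alignment checks in __init__. The concrete sizes are illustrative assumptions.

w_bit = 4
group_size = 128
qweight_shape = (4096, 512)                        # AWQ-packed weight, illustrative

in_features = qweight_shape[0]                     # 4096
out_features = qweight_shape[1] * 32 // w_bit      # 512 * 8 = 4096

assert in_features % group_size == 0               # groups tile the input dimension
assert out_features % (32 // w_bit) == 0           # packing alignment, as in __init__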
@@ -1,359 +0,0 @@
import math
import numpy as np
import torch
import torch.nn as nn
from torch.cuda.amp import custom_bwd, custom_fwd

try:
    import triton
    import triton.language as tl
    from . import custom_autotune

    # code based https://github.com/fpgaminer/GPTQ-triton
    @custom_autotune.autotune(
        configs=[
            triton.Config(
                {
                    "BLOCK_SIZE_M": 64,
                    "BLOCK_SIZE_N": 256,
                    "BLOCK_SIZE_K": 32,
                    "GROUP_SIZE_M": 8,
                },
                num_stages=4,
                num_warps=4,
            ),
            triton.Config(
                {
                    "BLOCK_SIZE_M": 128,
                    "BLOCK_SIZE_N": 128,
                    "BLOCK_SIZE_K": 32,
                    "GROUP_SIZE_M": 8,
                },
                num_stages=4,
                num_warps=4,
            ),
            triton.Config(
                {
                    "BLOCK_SIZE_M": 64,
                    "BLOCK_SIZE_N": 128,
                    "BLOCK_SIZE_K": 32,
                    "GROUP_SIZE_M": 8,
                },
                num_stages=4,
                num_warps=4,
            ),
            triton.Config(
                {
                    "BLOCK_SIZE_M": 128,
                    "BLOCK_SIZE_N": 32,
                    "BLOCK_SIZE_K": 32,
                    "GROUP_SIZE_M": 8,
                },
                num_stages=4,
                num_warps=4,
            ),
            triton.Config(
                {
                    "BLOCK_SIZE_M": 64,
                    "BLOCK_SIZE_N": 64,
                    "BLOCK_SIZE_K": 32,
                    "GROUP_SIZE_M": 8,
                },
                num_stages=4,
                num_warps=4,
            ),
            triton.Config(
                {
                    "BLOCK_SIZE_M": 64,
                    "BLOCK_SIZE_N": 128,
                    "BLOCK_SIZE_K": 32,
                    "GROUP_SIZE_M": 8,
                },
                num_stages=2,
                num_warps=8,
            ),
            triton.Config(
                {
                    "BLOCK_SIZE_M": 64,
                    "BLOCK_SIZE_N": 64,
                    "BLOCK_SIZE_K": 64,
                    "GROUP_SIZE_M": 8,
                },
                num_stages=3,
                num_warps=8,
            ),
            triton.Config(
                {
                    "BLOCK_SIZE_M": 32,
                    "BLOCK_SIZE_N": 32,
                    "BLOCK_SIZE_K": 128,
                    "GROUP_SIZE_M": 8,
                },
                num_stages=2,
                num_warps=4,
            ),
        ],
        key=["M", "N", "K"],
        nearest_power_of_two=True,
        prune_configs_by={
            "early_config_prune": custom_autotune.matmul248_kernel_config_pruner,
            "perf_model": None,
            "top_k": None,
        },
    )
    @triton.jit
    def matmul_248_kernel(
        a_ptr,
        b_ptr,
        c_ptr,
        scales_ptr,
        zeros_ptr,
        g_ptr,
        M,
        N,
        K,
        bits,
        maxq,
        stride_am,
        stride_ak,
        stride_bk,
        stride_bn,
        stride_cm,
        stride_cn,
        stride_scales,
        stride_zeros,
        BLOCK_SIZE_M: tl.constexpr,
        BLOCK_SIZE_N: tl.constexpr,
        BLOCK_SIZE_K: tl.constexpr,
        GROUP_SIZE_M: tl.constexpr,
    ):
        """
        Compute the matrix multiplication C = A x B.
        A is of shape (M, K) float16
        B is of shape (K//8, N) int32
        C is of shape (M, N) float16
        scales is of shape (G, N) float16
        zeros is of shape (G, N) float16
        g_ptr is of shape (K) int32
        """
        infearure_per_bits = 32 // bits

        pid = tl.program_id(axis=0)
        num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
        num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
        num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
        num_pid_in_group = GROUP_SIZE_M * num_pid_n
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
        pid_m = first_pid_m + (pid % group_size_m)
        pid_n = (pid % num_pid_in_group) // group_size_m

        offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        offs_k = tl.arange(0, BLOCK_SIZE_K)
        a_ptrs = a_ptr + (
            offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
        )  # (BLOCK_SIZE_M, BLOCK_SIZE_K)
        a_mask = offs_am[:, None] < M
        # b_ptrs is set up such that it repeats elements along the K axis 8 times
        b_ptrs = b_ptr + (
            (offs_k[:, None] // infearure_per_bits) * stride_bk
            + offs_bn[None, :] * stride_bn
        )  # (BLOCK_SIZE_K, BLOCK_SIZE_N)
        g_ptrs = g_ptr + offs_k
        # shifter is used to extract the N bits of each element in the 32-bit word from B
        scales_ptrs = scales_ptr + offs_bn[None, :]
        zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits)

        shifter = (offs_k % infearure_per_bits) * bits
        zeros_shifter = (offs_bn % infearure_per_bits) * bits
        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

        for k in range(0, num_pid_k):
            g_idx = tl.load(g_ptrs)

            # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
            scales = tl.load(
                scales_ptrs + g_idx[:, None] * stride_scales
            )  # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
            zeros = tl.load(
                zeros_ptrs + g_idx[:, None] * stride_zeros
            )  # (BLOCK_SIZE_K, BLOCK_SIZE_N,)

            zeros = (zeros >> zeros_shifter[None, :]) & maxq
            zeros = (zeros + 1) & maxq  # eventually avoid overflow

            a = tl.load(a_ptrs, mask=a_mask, other=0.0)  # (BLOCK_SIZE_M, BLOCK_SIZE_K)
            b = tl.load(b_ptrs)  # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated

            # Now we need to unpack b (which is N-bit values) into 32-bit values
            b = (b >> shifter[:, None]) & maxq  # Extract the N-bit values
            b = (b - zeros) * scales  # Scale and shift

            accumulator += tl.dot(a, b)
            a_ptrs += BLOCK_SIZE_K
            b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk
            g_ptrs += BLOCK_SIZE_K

        c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
        c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
        tl.store(c_ptrs, accumulator, mask=c_mask)

except:
    print("triton not installed.")


def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
    with torch.cuda.device(input.device):
        output = torch.empty(
            (input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16
        )
        grid = lambda META: (
            triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"])
            * triton.cdiv(qweight.shape[1], META["BLOCK_SIZE_N"]),
        )
        matmul_248_kernel[grid](
            input,
            qweight,
            output,
            scales,
            qzeros,
            g_idx,
            input.shape[0],
            qweight.shape[1],
            input.shape[1],
            bits,
            maxq,
            input.stride(0),
            input.stride(1),
            qweight.stride(0),
            qweight.stride(1),
            output.stride(0),
            output.stride(1),
            scales.stride(0),
            qzeros.stride(0),
        )
        return output


class QuantLinearFunction(torch.autograd.Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)
    def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq):
        output = matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq)
        return output


class QuantLinear(nn.Module):
    def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
        super().__init__()
        self.register_buffer("qweight", qweight)
        self.register_buffer("qzeros", qzeros)
        self.register_buffer("scales", scales)
        self.register_buffer("g_idx", g_idx)
        if bias is not None:
            self.register_buffer("bias", bias)
        else:
            self.bias = None
        if bits not in [2, 4, 8]:
            raise NotImplementedError("Only 2,4,8 bits are supported.")
        self.bits = bits
        self.maxq = 2**self.bits - 1
        self.groupsize = groupsize

        self.outfeatures = qweight.shape[1]
        self.infeatures = qweight.shape[0] * 32 // bits

    @classmethod
    def new(cls, bits, groupsize, infeatures, outfeatures, bias):
        if bits not in [2, 4, 8]:
            raise NotImplementedError("Only 2,4,8 bits are supported.")

        qweight = torch.zeros((infeatures // 32 * bits, outfeatures), dtype=torch.int32)
        qzeros = torch.zeros(
            (math.ceil(infeatures / groupsize), outfeatures // 32 * bits),
            dtype=torch.int32,
        )
        scales = torch.zeros(
            (math.ceil(infeatures / groupsize), outfeatures), dtype=torch.float16
        )
        g_idx = torch.tensor(
            [i // groupsize for i in range(infeatures)], dtype=torch.int32
        )
        if bias:
            bias = torch.zeros((outfeatures), dtype=torch.float16)
        else:
            bias = None
        return cls(qweight, qzeros, scales, g_idx, bias, bits, groupsize)

    def pack(self, linear, scales, zeros, g_idx=None):
        self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx

        scales = scales.t().contiguous()
        zeros = zeros.t().contiguous()
        scale_zeros = zeros * scales
        self.scales = scales.clone().half()
        if linear.bias is not None:
            self.bias = linear.bias.clone().half()

        intweight = []
        for idx in range(self.infeatures):
            intweight.append(
                torch.round(
                    (linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]])
                    / self.scales[self.g_idx[idx]]
                ).to(torch.int)[:, None]
            )
        intweight = torch.cat(intweight, dim=1)
        intweight = intweight.t().contiguous()
        intweight = intweight.numpy().astype(np.uint32)
        qweight = np.zeros(
            (intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32
        )
        i = 0
        row = 0
        while row < qweight.shape[0]:
            if self.bits in [2, 4, 8]:
                for j in range(i, i + (32 // self.bits)):
                    qweight[row] |= intweight[j] << (self.bits * (j - i))
                i += 32 // self.bits
                row += 1
            else:
                raise NotImplementedError("Only 2,4,8 bits are supported.")

        qweight = qweight.astype(np.int32)
        self.qweight = torch.from_numpy(qweight)

        zeros -= 1
        zeros = zeros.numpy().astype(np.uint32)
        qzeros = np.zeros(
            (zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32
        )
        i = 0
        col = 0
        while col < qzeros.shape[1]:
            if self.bits in [2, 4, 8]:
                for j in range(i, i + (32 // self.bits)):
                    qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
                i += 32 // self.bits
                col += 1
            else:
                raise NotImplementedError("Only 2,4,8 bits are supported.")

        qzeros = qzeros.astype(np.int32)
        self.qzeros = torch.from_numpy(qzeros)

    def forward(self, x):
        out_shape = x.shape[:-1] + (self.outfeatures,)
        out = QuantLinearFunction.apply(
            x.reshape(-1, x.shape[-1]),
            self.qweight,
            self.scales,
            self.qzeros,
            self.g_idx,
            self.bits,
            self.maxq,
        )
        out = out + self.bias if self.bias is not None else out
        return out.reshape(out_shape)
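For orientation, a short sketch of the buffer shapes that QuantLinear.new (deleted above) allocates for a 4-bit, group-size-128 layer; the in/out feature sizes are illustrative assumptions.

import math

bits, groupsize = 4, 128
infeatures, outfeatures = 4096, 11008

qweight_shape = (infeatures // 32 * bits, outfeatures)                        # (512, 11008)
qzeros_shape = (math.ceil(infeatures / groupsize), outfeatures // 32 * bits)  # (32, 1376)
scales_shape = (math.ceil(infeatures / groupsize), outfeatures)               # (32, 11008)
print(qweight_shape, qzeros_shape, scales_shape)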
File diff suppressed because it is too large
@@ -171,7 +171,7 @@ class Weights:
                log_once(
                    logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
                )
                from text_generation_server.utils.awq.conversion_utils import (
                from text_generation_server.layers.awq import (
                    fast_awq_to_gptq,
                )

@@ -227,7 +227,7 @@ class Weights:

            bits, groupsize, desc_act, quant_method = self._get_gptq_params()

            from text_generation_server.utils.layers import HAS_EXLLAMA
            from text_generation_server.layers import HAS_EXLLAMA

            use_exllama = (
                bits == 4 and HAS_EXLLAMA and quantize == "gptq" and not desc_act

@@ -242,7 +242,7 @@ class Weights:
                log_once(
                    logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
                )
                from text_generation_server.utils.awq.conversion_utils import (
                from text_generation_server.layers.awq import (
                    fast_awq_to_gptq,
                )

@@ -321,7 +321,7 @@ class Weights:
                # it would require to reorder input activations that are split unto several GPUs
                use_exllama = False

            from text_generation_server.utils.layers import HAS_EXLLAMA, CAN_EXLLAMA
            from text_generation_server.layers import HAS_EXLLAMA, CAN_EXLLAMA

            if use_exllama:
                if not HAS_EXLLAMA:

@@ -348,7 +348,7 @@ class Weights:
                log_once(
                    logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
                )
                from text_generation_server.utils.awq.conversion_utils import (
                from text_generation_server.layers.awq import (
                    fast_awq_to_gptq,
                )
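A minimal sketch of the gating that the weights.py hunks above keep intact while only changing the import path: exllama kernels are used solely for 4-bit GPTQ checkpoints without act-order. The flag values below are illustrative assumptions.

HAS_EXLLAMA = True                 # would now come from text_generation_server.layers
bits, desc_act, quantize = 4, False, "gptq"

use_exllama = bits == 4 and HAS_EXLLAMA and quantize == "gptq" and not desc_act
assert use_exllama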