2024-09-30 09:14:32 +00:00
|
|
|
from dataclasses import dataclass
|
|
|
|
from typing import List, Optional
|
|
|
|
|
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
|
|
|
|
|
|
|
from text_generation_server.utils.import_utils import SYSTEM
|
|
|
|
from text_generation_server.utils.weights import Weights
|
|
|
|
from text_generation_server.layers.marlin.gptq import (
|
|
|
|
GPTQMarlinWeight,
|
|
|
|
GPTQMarlinWeightsLoader,
|
|
|
|
)
|
|
|
|
|
|
|
|
# The fused Marlin MoE kernel is only imported on CUDA systems. On every other
# system the name is bound to None so that code in this module can check for
# its absence instead of failing at import time.
if SYSTEM != "cuda":
    fused_marlin_moe = None
else:
    from moe_kernels.fused_marlin_moe import fused_marlin_moe
|
|
|
|
|
|
|
|
|
|
|
|
# Probe the compute capability of the current CUDA device once, at import
# time. `has_sm_8_0` is True when the major capability is at least 8. Any
# failure of the probe is deliberately treated as "capability not available"
# rather than as an import error.
try:
    major, _minor = torch.cuda.get_device_capability()
except Exception:
    has_sm_8_0 = False
else:
    has_sm_8_0 = major >= 8
|
|
|
|
|
|
|
|
|
|
|
|
def can_use_marlin_moe_gemm(
    *,
    quant_method: str,
    quantize: str,
    sym: bool,
):
    """
    Tell whether the fused Marlin MoE kernel can be used for a quantization.

    The result is falsy unless all of the following hold:

    - the system is CUDA, the fused kernel was imported and the device probe
      reported a major compute capability of at least 8;
    - both `quantize` and `quant_method` are either "awq" or "gptq";
    - `sym` is truthy, or `quant_method` is "awq".
    """
    # Environment requirements: CUDA system, kernel available, new enough GPU.
    if SYSTEM != "cuda" or fused_marlin_moe is None or not has_sm_8_0:
        return False

    supported = {"awq", "gptq"}
    if quantize not in supported:
        return False
    if quant_method not in supported:
        return False

    # We only support asymmetric quantization for AWQ.
    return sym or quant_method == "awq"
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class GPTQMarlinMoEWeight:
    """
    GPTQ-Marlin weight data for all experts of one MoE projection.

    `_pack_weight` in this module allocates every tensor field with a leading
    dimension of size `n_experts`, followed by the shape of the corresponding
    tensor of a single expert's `GPTQMarlinWeight`, and fills one slot per
    expert.
    """

    # Stacked per-expert `qweight` tensors; handed to the fused kernel as the
    # `w1`/`w2` argument.
    qweight: torch.Tensor

    # Stacked per-expert `qzeros` tensors; the layer only forwards them to the
    # kernel when they contain at least one element.
    qzeros: torch.Tensor

    # Stacked per-expert `scales` tensors (kernel arguments `w1_scale`/`w2_scale`).
    scales: torch.Tensor

    # Stacked per-expert `g_idx` tensors (kernel arguments `g_idx1`/`g_idx2`).
    g_idx: torch.Tensor

    # Stacked per-expert `perm` tensors (kernel arguments
    # `sort_indices1`/`sort_indices2`).
    perm: torch.Tensor

    # Copied from the `is_full_k` attribute of the first expert that was packed.
    is_full_k: bool
|
|
|
|
|
|
|
|
|
|
|
|
class GPTQMarlinSparseMoELayer(nn.Module):
    """
    MoE layer that uses a fused GPTQ-Marlin kernel.

    The weights of all experts are loaded at construction time and packed
    into one `GPTQMarlinMoEWeight` for the fused gate/up projection and one
    for the down projection. `forward` hands them to `fused_marlin_moe`.
    """

    def __init__(
        self,
        *,
        n_expert_group: Optional[int],
        n_experts: int,
        prefix: str,
        renormalize: bool,
        topk: int,
        topk_group: Optional[int],
        weights: Weights,
        gate_proj_name: str = "gate_proj",
        up_proj_name: str = "up_proj",
        down_proj_name: str = "down_proj",
        scoring_func: Optional[str] = None,
        e_score_correction_bias: Optional[float] = None,
    ):
        """
        Load and pack the expert weights found under `prefix`.

        `n_expert_group` and `topk_group` must either both be None or both be
        set. `scoring_func` may be left at None or set to "softmax"; any
        other value, and any `e_score_correction_bias`, is rejected.

        Raises:
            AssertionError: for an unhandled `scoring_func`, a non-None
                `e_score_correction_bias`, or inconsistent group arguments.
            ValueError: when `weights.loader` is not a
                `GPTQMarlinWeightsLoader` usable with the fused kernel.
        """
        # The parameter defaults to None, so None has to be accepted here;
        # otherwise every caller that omits `scoring_func` would fail. Softmax
        # is the only scoring function this layer handles, so None simply
        # means "use the default".
        assert scoring_func in (
            None,
            "softmax",
        ), f"scoring func {scoring_func} is not handled"
        assert e_score_correction_bias is None, "scoring correction bias is not handled"
        super().__init__()

        # The loader attributes are only read after the isinstance check
        # succeeded (short-circuit `and`).
        if not (
            isinstance(weights.loader, GPTQMarlinWeightsLoader)
            and can_use_marlin_moe_gemm(
                quant_method=weights.loader.quant_method,
                quantize=weights.loader.quantize,
                sym=weights.loader.sym,
            )
        ):
            raise ValueError(
                f"Unsupported weights loader: {type(weights.loader)}, only GPTQMarlinWeightsLoader with AWQ and symmetric GPTQ quantization is supported"
            )

        assert (n_expert_group is None) == (
            topk_group is None
        ), "n_expert_group and topk_group must both be None or have some value"

        self.n_expert_group = n_expert_group
        self.topk = topk
        self.topk_group = topk_group
        self.renormalize = renormalize

        # Gate and up projections are loaded together as one multi-weight.
        self.gate_up_proj = _load_expert_multi_weights_col(
            prefix=prefix,
            n_experts=n_experts,
            names=[gate_proj_name, up_proj_name],
            weights=weights,
        )

        self.down_proj = _load_expert_weights_row(
            prefix=prefix, n_experts=n_experts, name=down_proj_name, weights=weights
        )

        self.bits = weights.loader.bits

    def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
        """
        Run the fused Marlin MoE kernel on `x` and return its result.

        `gating_output` is passed to the kernel unchanged, together with the
        top-k and grouping settings stored at construction time.
        """
        gate_up = self.gate_up_proj
        down = self.down_proj

        return fused_marlin_moe(
            hidden_states=x,
            w1=gate_up.qweight,
            w2=down.qweight,
            w1_scale=gate_up.scales,
            w2_scale=down.scales,
            # Empty zero-point tensors are passed to the kernel as None.
            w1_zeros=gate_up.qzeros if gate_up.qzeros.numel() > 0 else None,
            w2_zeros=down.qzeros if down.qzeros.numel() > 0 else None,
            g_idx1=gate_up.g_idx,
            g_idx2=down.g_idx,
            sort_indices1=gate_up.perm,
            sort_indices2=down.perm,
            is_k_full=gate_up.is_full_k or down.is_full_k,
            gating_output=gating_output,
            topk=self.topk,
            renormalize=self.renormalize,
            # Grouped top-k is used exactly when an expert group count is set.
            use_grouped_topk=self.n_expert_group is not None,
            num_expert_group=self.n_expert_group,
            topk_group=self.topk_group,
            num_bits=self.bits,
        )
|
|
|
|
|
|
|
|
|
|
|
|
def _load_expert_multi_weights_col(
    *,
    prefix: str,
    n_experts: int,
    names: List[str],
    weights: Weights,
) -> GPTQMarlinMoEWeight:
    """
    Load the multi-weight named by `names` for every expert and pack them.

    For expert `i` the weights are looked up under `{prefix}.{i}.{name}` for
    each entry of `names`. Every loaded weight must be a `GPTQMarlinWeight`.
    At least one expert is required, otherwise the final assertion fails.
    """
    packed = None
    for expert_idx in range(n_experts):
        expert_prefixes = [f"{prefix}.{expert_idx}.{name}" for name in names]
        # NOTE(review): the second positional argument is presumably the
        # dimension along which the weights are joined; confirm against
        # `Weights.get_multi_weights_col`.
        expert_weight = weights.get_multi_weights_col(expert_prefixes, 0)
        assert isinstance(expert_weight, GPTQMarlinWeight)
        packed = _pack_weight(
            n_experts=n_experts,
            expert=expert_idx,
            weight=expert_weight,
            moe_weight=packed,
        )

    assert packed is not None
    return packed
|
|
|
|
|
|
|
|
|
|
|
|
def _load_expert_weights_row(
    *,
    prefix: str,
    n_experts: int,
    name: str,
    weights: Weights,
) -> GPTQMarlinMoEWeight:
    """
    Load the row weight called `name` for every expert and pack them.

    For expert `i` the weight is looked up under `{prefix}.{i}.{name}`.
    Every loaded weight must be a `GPTQMarlinWeight`. At least one expert is
    required, otherwise the final assertion fails.
    """
    packed = None
    for expert_idx in range(n_experts):
        expert_prefix = f"{prefix}.{expert_idx}.{name}"
        expert_weight = weights.get_weights_row(expert_prefix)
        assert isinstance(expert_weight, GPTQMarlinWeight)
        packed = _pack_weight(
            n_experts=n_experts,
            expert=expert_idx,
            weight=expert_weight,
            moe_weight=packed,
        )

    assert packed is not None
    return packed
|
|
|
|
|
|
|
|
|
|
|
|
def _pack_weight(
    *,
    n_experts: int,
    expert: int,
    moe_weight: Optional[GPTQMarlinMoEWeight],
    weight: GPTQMarlinWeight,
) -> GPTQMarlinMoEWeight:
    """
    Store the tensors of one expert in the packed MoE weight.

    When `moe_weight` is None, a new `GPTQMarlinMoEWeight` is created first:
    each tensor gets a leading dimension of size `n_experts` followed by the
    shape of the matching tensor of `weight`, with the same dtype and device,
    and `is_full_k` is copied from `weight`. The tensors of `weight` are then
    written into slot `expert`, and the packed weight is returned.
    """
    tensor_fields = ("qweight", "qzeros", "scales", "g_idx", "perm")

    if moe_weight is None:

        def _empty_for_all_experts(like: torch.Tensor) -> torch.Tensor:
            # Uninitialized storage; every expert slot is expected to be
            # filled by a later call.
            return torch.empty(
                (n_experts,) + like.shape,
                dtype=like.dtype,
                device=like.device,
            )

        moe_weight = GPTQMarlinMoEWeight(
            **{
                field: _empty_for_all_experts(getattr(weight, field))
                for field in tensor_fields
            },
            is_full_k=weight.is_full_k,
        )

    for field in tensor_fields:
        getattr(moe_weight, field)[expert] = getattr(weight, field)

    return moe_weight
|