Use hub kernels for MoE/GPTQ-Marlin MoE

Daniël de Kok 2025-01-28 12:51:45 +00:00
parent aab6141b92
commit 758ff3c598
7 changed files with 3463 additions and 17 deletions
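For context, the core change across the diffs below is to stop installing pinned moe-kernels wheels and instead load the MoE kernels from the Hugging Face Hub at runtime through hf-kernels. A minimal sketch of the new pattern, assuming the kernels-community/moe package exposes the same entry points the call sites below use:

from hf_kernels import load_kernel

# Previously: from moe_kernels.fused_moe import fused_moe, fused_topk, grouped_topk
# Now the same entry points are resolved from a Hub-hosted kernel package.
moe_kernels = load_kernel("kernels-community/moe")
fused_topk = moe_kernels.fused_topk      # top-k routing
grouped_topk = moe_kernels.grouped_topk  # grouped routing (e.g. DeepSeek-V2)
fused_moe = moe_kernels.fused_moe        # fused expert matmul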

View File

@@ -97,7 +97,6 @@ buildPythonPackage {
hf-transfer
loguru
mamba-ssm
moe-kernels
opentelemetry-api
opentelemetry-exporter-otlp
opentelemetry-instrumentation-grpc

File diff suppressed because it is too large

View File

@@ -38,6 +38,7 @@ requires = ["hf-kernels", "setuptools"]
build-backend = "setuptools.build_meta"
[tool.kernels.dependencies]
"kernels-community/moe" = ">=0.0.3"
"kernels-community/quantization" = ">=0.0.2"
[project.scripts]
@@ -67,7 +68,6 @@ quantize = [
"texttable>=1.6.7,<2",
"datasets>=2.21,<3",
]
moe = [ "moe-kernels" ]
attention = [ "attention-kernels" ]
gen = [
"grpcio-tools>=1.69.0",
@@ -81,12 +81,6 @@ attention-kernels = [
{ url = "https://github.com/danieldk/attention-kernels/releases/download/v0.1.1/attention_kernels-0.1.1+cu123torch2.5-cp311-cp311-linux_x86_64.whl", marker = "python_version == '3.11'" },
{ url = "https://github.com/danieldk/attention-kernels/releases/download/v0.1.1/attention_kernels-0.1.1+cu123torch2.5-cp312-cp312-linux_x86_64.whl", marker = "python_version == '3.12'" },
]
moe-kernels = [
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.7.0/moe_kernels-0.7.0+cu123torch2.5-cp39-cp39-linux_x86_64.whl", marker = "python_version == '3.9'" },
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.7.0/moe_kernels-0.7.0+cu123torch2.5-cp310-cp310-linux_x86_64.whl", marker = "python_version == '3.10'" },
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.7.0/moe_kernels-0.7.0+cu123torch2.5-cp311-cp311-linux_x86_64.whl", marker = "python_version == '3.11'" },
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.7.0/moe_kernels-0.7.0+cu123torch2.5-cp312-cp312-linux_x86_64.whl", marker = "python_version == '3.12'" },
]
[tool.pytest.ini_options]
markers = ["private: marks tests as requiring an admin hf token (deselect with '-m \"not private\"')"]

View File

@@ -1,5 +1,6 @@
from typing import Optional, Protocol, runtime_checkable
from hf_kernels import load_kernel
import torch
import torch.nn as nn
from loguru import logger
@@ -27,6 +28,10 @@ from text_generation_server.utils.weights import (
if SYSTEM == "ipex":
from .fused_moe_ipex import fused_topk, grouped_topk
if SYSTEM == "cuda":
moe_kernels = load_kernel("kernels-community/moe")
fused_topk = moe_kernels.fused_topk
grouped_topk = moe_kernels.grouped_topk
else:
from moe_kernels.fused_moe import fused_topk, grouped_topk
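Whichever backend provides them, the routing helpers keep the same call signature. A hedged usage sketch of fused_topk, with the keyword arguments mirroring the call sites later in this commit and the sizes/dtypes being illustrative assumptions:

import torch

hidden_states = torch.randn(4, 4096, dtype=torch.float16, device="cuda")  # [num_tokens, hidden_size]
gating_output = torch.randn(4, 8, device="cuda")                          # [num_tokens, num_experts]

# Pick the top-2 experts per token; renormalize their weights to sum to 1.
topk_weights, topk_ids = fused_topk(
    hidden_states=hidden_states,
    gating_output=gating_output,
    topk=2,
    renormalize=True,
)
# topk_weights and topk_ids both have shape [num_tokens, topk]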

View File

@@ -1,9 +1,11 @@
from dataclasses import dataclass
from typing import List, Optional
from typing import Any, Callable, Dict, List, Optional
from hf_kernels import load_kernel
import torch
import torch.nn as nn
from text_generation_server.layers import moe
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.weights import Weights
from text_generation_server.layers.marlin.gptq import (
@@ -12,9 +14,9 @@ from text_generation_server.layers.marlin.gptq import (
)
if SYSTEM == "cuda":
from moe_kernels.fused_marlin_moe import fused_marlin_moe
moe_kernels = load_kernel("kernels-community/moe")
else:
fused_marlin_moe = None
moe_kernels = None
try:
@@ -32,7 +34,7 @@ def can_use_marlin_moe_gemm(
):
return (
SYSTEM == "cuda"
and fused_marlin_moe is not None
and moe is not None
and has_sm_8_0
and quantize in {"awq", "gptq"}
and quant_method in {"awq", "gptq"}
@@ -230,3 +232,115 @@ def _pack_weight(
moe_weight.perm[expert] = weight.perm
return moe_weight
def fused_marlin_moe(
*,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
gating_output: torch.Tensor,
g_idx1: torch.Tensor,
g_idx2: torch.Tensor,
sort_indices1: torch.Tensor,
sort_indices2: torch.Tensor,
w1_zeros: Optional[torch.Tensor] = None,
w2_zeros: Optional[torch.Tensor] = None,
is_k_full: bool,
topk: int,
renormalize: bool,
num_bits: int = 8,
override_config: Optional[Dict[str, Any]] = None,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
topk_group: Optional[int] = None,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
weights, w1 and w2, and a top-k gating mechanism.
Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
w2.
- gating_output (torch.Tensor): The output of the gating operation
(before softmax).
- g_idx1 (torch.Tensor): The first set of act_order indices.
- g_idx2 (torch.Tensor): The second set of act_order indices.
- sort_indices1 (torch.Tensor): The first act_order input permutation.
- sort_indices2 (torch.Tensor): The second act_order input permutation.
- w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1.
- w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
- num_bits (int): The number of bits in expert weights quantization.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
# Check constraints.
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
assert hidden_states.shape[1] == w1.shape[1] * 16, "Hidden size mismatch w1"
assert hidden_states.shape[1] == w2.shape[2] // (
num_bits // 2
), "Hidden size mismatch w2"
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
assert hidden_states.dtype == torch.float16
assert num_bits in [4, 8]
# DeepSeek-V2 uses grouped_topk
if use_grouped_topk:
assert topk_group is not None
assert num_expert_group is not None
topk_weights, topk_ids = moe_kernels.grouped_topk(
hidden_states=hidden_states,
gating_output=gating_output,
topk=topk,
renormalize=renormalize,
num_expert_group=num_expert_group,
topk_group=topk_group,
)
elif custom_routing_function is None:
topk_weights, topk_ids = moe_kernels.fused_topk(
hidden_states=hidden_states,
gating_output=gating_output,
topk=topk,
renormalize=renormalize,
)
else:
topk_weights, topk_ids = custom_routing_function(
hidden_states=hidden_states,
gating_output=gating_output,
topk=topk,
renormalize=renormalize,
)
return moe_kernels.fused_marlin_moe(
hidden_states=hidden_states,
w1=w1,
w2=w2,
w1_scale=w1_scale,
w2_scale=w2_scale,
gating_output=gating_output,
topk_weights=topk_weights,
topk_ids=topk_ids,
g_idx1=g_idx1,
g_idx2=g_idx2,
sort_indices1=sort_indices1,
sort_indices2=sort_indices2,
w1_zeros=w1_zeros,
w2_zeros=w2_zeros,
override_config=override_config,
num_bits=num_bits,
is_k_full=is_k_full,
)
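The asserts at the top of this wrapper encode the Marlin/GPTQ packing layout. A small worked example of the shape relationships being checked, with all sizes chosen as assumptions purely for illustration (this is not a runnable kernel invocation):

num_experts = 8      # w1.shape[0] and gating_output.shape[1]
hidden_size = 4096   # hidden_states.shape[1]
num_bits = 4         # GPTQ weight width, must be 4 or 8

# Per the first weight assert: w1.shape[1] * 16 == hidden_size
w1_dim1 = hidden_size // 16
# Per the w2 assert: w2.shape[2] // (num_bits // 2) == hidden_size
w2_dim2 = hidden_size * (num_bits // 2)

assert hidden_size == w1_dim1 * 16
assert hidden_size == w2_dim2 // (num_bits // 2)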

View File

@@ -1,5 +1,6 @@
from typing import Optional
from hf_kernels import load_kernel
import torch
import torch.nn as nn
@@ -8,8 +9,10 @@ from text_generation_server.utils.weights import UnquantizedWeight, Weights
if SYSTEM == "ipex":
from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
elif SYSTEM == "cuda":
moe_kernels = load_kernel("kernels-community/moe")
else:
from moe_kernels.fused_moe import fused_moe
import moe_kernels
class UnquantizedSparseMoELayer(nn.Module):
@@ -63,7 +66,17 @@ class UnquantizedSparseMoELayer(nn.Module):
)
def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
if SYSTEM == "ipex":
if SYSTEM == "rocm":
return moe_kernels.fused_moe(
x,
self.gate_up_proj,
self.down_proj,
gating_output,
self.topk,
renormalize=self.renormalize,
inplace=True,
)
elif SYSTEM == "ipex":
return self.ipex_fused_moe(
hidden_states=x,
router_logits=gating_output,
@@ -73,7 +86,7 @@ class UnquantizedSparseMoELayer(nn.Module):
num_expert_group=self.n_expert_group,
topk_group=self.topk_group,
)
return fused_moe(
return moe_kernels.fused_moe(
x,
w1=self.gate_up_proj,
w2=self.down_proj,

View File

@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from hf_kernels import load_kernel
import torch
import torch.distributed
@@ -25,8 +26,10 @@ from text_generation_server.utils.import_utils import SYSTEM
if SYSTEM == "ipex":
from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
elif SYSTEM == "cuda":
moe_kernels = load_kernel("kernels-community/moe")
else:
from moe_kernels.fused_moe import fused_moe
import moe_kernels
from text_generation_server.layers.attention import (
paged_attention,
@@ -510,7 +513,7 @@ class BlockSparseMoE(nn.Module):
topk_group=None,
)
else:
out = fused_moe(
out = moe_kernels.fused_moe(
x,
self.wv1,
self.w2,