Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-09-11 04:14:52 +00:00

Use hub kernels for MoE/GPTQ-Marlin MoE

This commit is contained in:
parent aab6141b92
commit 758ff3c598
@@ -97,7 +97,6 @@ buildPythonPackage {
     hf-transfer
     loguru
     mamba-ssm
-    moe-kernels
     opentelemetry-api
     opentelemetry-exporter-otlp
     opentelemetry-instrumentation-grpc
File diff suppressed because it is too large
@@ -38,6 +38,7 @@ requires = ["hf-kernels", "setuptools"]
 build-backend = "setuptools.build_meta"
 
 [tool.kernels.dependencies]
+"kernels-community/moe" = ">=0.0.3"
 "kernels-community/quantization" = ">=0.0.2"
 
 [project.scripts]
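The hub kernels pinned above are consumed at runtime through hf_kernels.load_kernel, as the Python hunks later in this commit do. A minimal sketch, assuming a CUDA machine with hf-kernels installed and access to the Hugging Face Hub; the attribute names mirror the ones used in the diff, and the version-resolution behaviour is an assumption:

import torch  # the loaded kernel operates on torch tensors

from hf_kernels import load_kernel

# Assumption: load_kernel fetches the "kernels-community/moe" build pinned in
# [tool.kernels.dependencies] and exposes it as a module-like object.
moe_kernels = load_kernel("kernels-community/moe")

fused_topk = moe_kernels.fused_topk              # standard top-k routing
grouped_topk = moe_kernels.grouped_topk          # DeepSeek-V2 style grouped routing
fused_moe = moe_kernels.fused_moe                # unquantized fused MoE forward
fused_marlin_moe = moe_kernels.fused_marlin_moe  # GPTQ/AWQ Marlin MoE GEMM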
@@ -67,7 +68,6 @@ quantize = [
     "texttable>=1.6.7,<2",
     "datasets>=2.21,<3",
 ]
-moe = [ "moe-kernels" ]
 attention = [ "attention-kernels" ]
 gen = [
     "grpcio-tools>=1.69.0",
@@ -81,12 +81,6 @@ attention-kernels = [
     { url = "https://github.com/danieldk/attention-kernels/releases/download/v0.1.1/attention_kernels-0.1.1+cu123torch2.5-cp311-cp311-linux_x86_64.whl", marker = "python_version == '3.11'" },
     { url = "https://github.com/danieldk/attention-kernels/releases/download/v0.1.1/attention_kernels-0.1.1+cu123torch2.5-cp312-cp312-linux_x86_64.whl", marker = "python_version == '3.12'" },
 ]
-moe-kernels = [
-    { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.7.0/moe_kernels-0.7.0+cu123torch2.5-cp39-cp39-linux_x86_64.whl", marker = "python_version == '3.9'" },
-    { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.7.0/moe_kernels-0.7.0+cu123torch2.5-cp310-cp310-linux_x86_64.whl", marker = "python_version == '3.10'" },
-    { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.7.0/moe_kernels-0.7.0+cu123torch2.5-cp311-cp311-linux_x86_64.whl", marker = "python_version == '3.11'" },
-    { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.7.0/moe_kernels-0.7.0+cu123torch2.5-cp312-cp312-linux_x86_64.whl", marker = "python_version == '3.12'" },
-]
 
 [tool.pytest.ini_options]
 markers = ["private: marks tests as requiring an admin hf token (deselect with '-m \"not private\"')"]
@@ -1,5 +1,6 @@
 from typing import Optional, Protocol, runtime_checkable
 
+from hf_kernels import load_kernel
 import torch
 import torch.nn as nn
 from loguru import logger
@@ -27,6 +28,10 @@ from text_generation_server.utils.weights import (
 
 if SYSTEM == "ipex":
     from .fused_moe_ipex import fused_topk, grouped_topk
+if SYSTEM == "cuda":
+    moe_kernels = load_kernel("kernels-community/moe")
+    fused_topk = moe_kernels.fused_topk
+    grouped_topk = moe_kernels.grouped_topk
 else:
     from moe_kernels.fused_moe import fused_topk, grouped_topk
 
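For context on what the fused_topk entry point used above computes, here is an illustrative eager-mode reference of top-k routing: softmax over the router logits, pick the top-k experts, optionally renormalize their weights. This is a sketch of the routing math only, not the hub kernel, whose exact semantics may differ:

import torch


def reference_topk(gating_output: torch.Tensor, topk: int, renormalize: bool):
    # gating_output: [num_tokens, num_experts] router logits
    scores = torch.softmax(gating_output, dim=-1, dtype=torch.float32)
    topk_weights, topk_ids = torch.topk(scores, topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids


logits = torch.randn(4, 8)  # 4 tokens, 8 experts
weights, ids = reference_topk(logits, topk=2, renormalize=True)
print(weights.shape, ids.shape)  # torch.Size([4, 2]) torch.Size([4, 2])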
@@ -1,9 +1,11 @@
 from dataclasses import dataclass
-from typing import List, Optional
+from typing import Any, Callable, Dict, List, Optional
 
+from hf_kernels import load_kernel
 import torch
 import torch.nn as nn
 
+from text_generation_server.layers import moe
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.weights import Weights
 from text_generation_server.layers.marlin.gptq import (
@@ -12,9 +14,9 @@ from text_generation_server.layers.marlin.gptq import (
 )
 
 if SYSTEM == "cuda":
-    from moe_kernels.fused_marlin_moe import fused_marlin_moe
+    moe_kernels = load_kernel("kernels-community/moe")
 else:
-    fused_marlin_moe = None
+    moe_kernels = None
 
 
 try:
@@ -32,7 +34,7 @@ def can_use_marlin_moe_gemm(
 ):
     return (
         SYSTEM == "cuda"
-        and fused_marlin_moe is not None
+        and moe is not None
         and has_sm_8_0
         and quantize in {"awq", "gptq"}
         and quant_method in {"awq", "gptq"}
@@ -230,3 +232,115 @@ def _pack_weight(
         moe_weight.perm[expert] = weight.perm
 
     return moe_weight
+
+
+def fused_marlin_moe(
+    *,
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    gating_output: torch.Tensor,
+    g_idx1: torch.Tensor,
+    g_idx2: torch.Tensor,
+    sort_indices1: torch.Tensor,
+    sort_indices2: torch.Tensor,
+    w1_zeros: Optional[torch.Tensor] = None,
+    w2_zeros: Optional[torch.Tensor] = None,
+    is_k_full: bool,
+    topk: int,
+    renormalize: bool,
+    num_bits: int = 8,
+    override_config: Optional[Dict[str, Any]] = None,
+    use_grouped_topk: bool = False,
+    num_expert_group: Optional[int] = None,
+    custom_routing_function: Optional[Callable] = None,
+    topk_group: Optional[int] = None,
+) -> torch.Tensor:
+    """
+    This function computes a Mixture of Experts (MoE) layer using two sets of
+    weights, w1 and w2, and a top-k gating mechanism.
+
+    Parameters:
+    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
+    - w1 (torch.Tensor): The first set of expert weights.
+    - w2 (torch.Tensor): The second set of expert weights.
+    - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1.
+    - w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2.
+    - gating_output (torch.Tensor): The output of the gating operation
+      (before softmax).
+    - g_idx1 (torch.Tensor): The first set of act_order indices.
+    - g_idx2 (torch.Tensor): The second set of act_order indices.
+    - sort_indices1 (torch.Tensor): The first act_order input permutation.
+    - sort_indices2 (torch.Tensor): The second act_order input permutation.
+    - w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1.
+    - w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2.
+    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
+    - override_config (Optional[Dict[str, Any]]): Optional override
+      for the kernel configuration.
+    - num_bits (int): The number of bits in expert weights quantization.
+
+    Returns:
+    - torch.Tensor: The output tensor after applying the MoE layer.
+    """
+    # Check constraints.
+    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
+    assert hidden_states.shape[1] == w1.shape[1] * 16, "Hidden size mismatch w1"
+    assert hidden_states.shape[1] == w2.shape[2] // (
+        num_bits // 2
+    ), "Hidden size mismatch w2"
+    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
+    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
+    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
+    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
+    assert hidden_states.dtype == torch.float16
+    assert num_bits in [4, 8]
+
+    # DeepSeek-V2 uses grouped_topk
+    if use_grouped_topk:
+        assert topk_group is not None
+        assert num_expert_group is not None
+        topk_weights, topk_ids = moe_kernels.grouped_topk(
+            hidden_states=hidden_states,
+            gating_output=gating_output,
+            topk=topk,
+            renormalize=renormalize,
+            num_expert_group=num_expert_group,
+            topk_group=topk_group,
+        )
+    elif custom_routing_function is None:
+        topk_weights, topk_ids = moe_kernels.fused_topk(
+            hidden_states=hidden_states,
+            gating_output=gating_output,
+            topk=topk,
+            renormalize=renormalize,
+        )
+    else:
+        topk_weights, topk_ids = custom_routing_function(
+            hidden_states=hidden_states,
+            gating_output=gating_output,
+            topk=topk,
+            renormalize=renormalize,
+        )
+    return moe_kernels.fused_marlin_moe(
+        hidden_states=hidden_states,
+        w1=w1,
+        w2=w2,
+        w1_scale=w1_scale,
+        w2_scale=w2_scale,
+        gating_output=gating_output,
+        topk_weights=topk_weights,
+        topk_ids=topk_ids,
+        g_idx1=g_idx1,
+        g_idx2=g_idx2,
+        sort_indices1=sort_indices1,
+        sort_indices2=sort_indices2,
+        w1_zeros=w1_zeros,
+        w2_zeros=w2_zeros,
+        override_config=override_config,
+        num_bits=num_bits,
+        is_k_full=is_k_full,
+    )
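The grouped_topk branch above is the DeepSeek-V2 style router: experts are split into num_expert_group groups, the best topk_group groups are kept, and the final top-k is taken only among experts in those groups. A hedged, eager-mode sketch of that grouping idea (again not the hub kernel, which may score or renormalize differently):

import torch


def reference_grouped_topk(
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: int,
    topk_group: int,
):
    scores = torch.softmax(gating_output, dim=-1)
    num_tokens, num_experts = scores.shape
    # Score each group by its best expert and keep the topk_group best groups.
    group_scores = scores.view(num_tokens, num_expert_group, -1).max(dim=-1).values
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1).indices
    group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1.0)
    # Mask out experts outside the selected groups, then do a normal top-k.
    expert_mask = group_mask.repeat_interleave(num_experts // num_expert_group, dim=1)
    masked_scores = scores.masked_fill(expert_mask == 0, 0.0)
    topk_weights, topk_ids = torch.topk(masked_scores, k=topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids


logits = torch.randn(2, 16)  # 2 tokens, 16 experts in 4 groups
w, ids = reference_grouped_topk(
    logits, topk=4, renormalize=True, num_expert_group=4, topk_group=2
)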
@@ -1,5 +1,6 @@
 from typing import Optional
 
+from hf_kernels import load_kernel
 import torch
 import torch.nn as nn
 
@@ -8,8 +9,10 @@ from text_generation_server.utils.weights import UnquantizedWeight, Weights
 
 if SYSTEM == "ipex":
     from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
+elif SYSTEM == "cuda":
+    moe_kernels = load_kernel("kernels-community/moe")
 else:
-    from moe_kernels.fused_moe import fused_moe
+    import moe_kernels
 
 
 class UnquantizedSparseMoELayer(nn.Module):
@@ -63,7 +66,17 @@ class UnquantizedSparseMoELayer(nn.Module):
         )
 
     def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
-        if SYSTEM == "ipex":
+        if SYSTEM == "rocm":
+            return moe_kernels.fused_moe(
+                x,
+                self.gate_up_proj,
+                self.down_proj,
+                gating_output,
+                self.topk,
+                renormalize=self.renormalize,
+                inplace=True,
+            )
+        elif SYSTEM == "ipex":
             return self.ipex_fused_moe(
                 hidden_states=x,
                 router_logits=gating_output,
@@ -73,7 +86,7 @@ class UnquantizedSparseMoELayer(nn.Module):
             num_expert_group=self.n_expert_group,
             topk_group=self.topk_group,
         )
-        return fused_moe(
+        return moe_kernels.fused_moe(
            x,
            w1=self.gate_up_proj,
            w2=self.down_proj,
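For reference, the fused_moe call above computes, per token, a weighted sum over its top-k experts of a gated expert MLP. An illustrative eager-mode sketch under assumptions about the weight layout (gate_up_proj: [experts, 2*inter, hidden], down_proj: [experts, hidden, inter], gate stored before up) and the activation (SiLU gating); the fused kernel does this in a single pass and may differ in such details:

import torch
import torch.nn.functional as F


def reference_fused_moe(x, gate_up_proj, down_proj, gating_output, topk, renormalize):
    # Route each token to its top-k experts.
    scores = torch.softmax(gating_output, dim=-1)
    weights, ids = torch.topk(scores, topk, dim=-1)
    if renormalize:
        weights = weights / weights.sum(dim=-1, keepdim=True)
    out = torch.zeros_like(x)
    for t in range(x.shape[0]):
        for k in range(topk):
            e = ids[t, k]
            gate, up = (gate_up_proj[e] @ x[t]).chunk(2)  # assumed gate-then-up layout
            out[t] += weights[t, k] * (down_proj[e] @ (F.silu(gate) * up))
    return out


x = torch.randn(3, 64)                 # 3 tokens, hidden size 64
gate_up = torch.randn(8, 2 * 128, 64)  # 8 experts, intermediate size 128
down = torch.randn(8, 64, 128)
logits = torch.randn(3, 8)
print(reference_fused_moe(x, gate_up, down, logits, topk=2, renormalize=True).shape)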
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from hf_kernels import load_kernel
 import torch
 import torch.distributed
 
@@ -25,8 +26,10 @@ from text_generation_server.utils.import_utils import SYSTEM
 
 if SYSTEM == "ipex":
     from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
+elif SYSTEM == "cuda":
+    moe_kernels = load_kernel("kernels-community/moe")
 else:
-    from moe_kernels.fused_moe import fused_moe
+    import moe_kernels
 
 from text_generation_server.layers.attention import (
     paged_attention,
@@ -510,7 +513,7 @@ class BlockSparseMoE(nn.Module):
                 topk_group=None,
             )
         else:
-            out = fused_moe(
+            out = moe_kernels.fused_moe(
                 x,
                 self.wv1,
                 self.w2,