huggingface/text-generation-inference (mirror of https://github.com/huggingface/text-generation-inference.git)

Use hub kernels for MoE/GPTQ-Marlin MoE

commit 758ff3c598 (parent aab6141b92)
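In short: instead of installing pinned `moe-kernels` wheels, the server now loads the MoE kernels from the Hugging Face Hub at runtime through `hf-kernels`, with the kernel declared under `[tool.kernels.dependencies]` in pyproject.toml. A minimal sketch of the pattern this commit applies on CUDA; it assumes `hf-kernels` is installed, a CUDA device is available, and the Hub kernel can be downloaded or read from cache, and the tensor shapes and values below are illustrative assumptions rather than part of the diff:

import torch
from hf_kernels import load_kernel

# Fetch (or load from the local cache) the kernel pinned in
# [tool.kernels.dependencies]: "kernels-community/moe" = ">=0.0.3".
moe_kernels = load_kernel("kernels-community/moe")

# Illustrative inputs: 4 tokens, hidden size 64, 8 experts, top-2 routing.
hidden_states = torch.randn(4, 64, dtype=torch.float16, device="cuda")
gating_output = torch.randn(4, 8, dtype=torch.float16, device="cuda")

# Same routing entry point the diff re-exports as fused_topk on CUDA.
topk_weights, topk_ids = moe_kernels.fused_topk(
    hidden_states=hidden_states,
    gating_output=gating_output,
    topk=2,
    renormalize=True,
)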
@@ -97,7 +97,6 @@ buildPythonPackage {
     hf-transfer
     loguru
     mamba-ssm
-    moe-kernels
     opentelemetry-api
     opentelemetry-exporter-otlp
     opentelemetry-instrumentation-grpc
(File diff suppressed because it is too large.)
@@ -38,6 +38,7 @@ requires = ["hf-kernels", "setuptools"]
 build-backend = "setuptools.build_meta"
 
 [tool.kernels.dependencies]
+"kernels-community/moe" = ">=0.0.3"
 "kernels-community/quantization" = ">=0.0.2"
 
 [project.scripts]
@@ -67,7 +68,6 @@ quantize = [
     "texttable>=1.6.7,<2",
     "datasets>=2.21,<3",
 ]
-moe = [ "moe-kernels" ]
 attention = [ "attention-kernels" ]
 gen = [
     "grpcio-tools>=1.69.0",
@@ -81,12 +81,6 @@ attention-kernels = [
     { url = "https://github.com/danieldk/attention-kernels/releases/download/v0.1.1/attention_kernels-0.1.1+cu123torch2.5-cp311-cp311-linux_x86_64.whl", marker = "python_version == '3.11'" },
     { url = "https://github.com/danieldk/attention-kernels/releases/download/v0.1.1/attention_kernels-0.1.1+cu123torch2.5-cp312-cp312-linux_x86_64.whl", marker = "python_version == '3.12'" },
 ]
-moe-kernels = [
-    { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.7.0/moe_kernels-0.7.0+cu123torch2.5-cp39-cp39-linux_x86_64.whl", marker = "python_version == '3.9'" },
-    { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.7.0/moe_kernels-0.7.0+cu123torch2.5-cp310-cp310-linux_x86_64.whl", marker = "python_version == '3.10'" },
-    { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.7.0/moe_kernels-0.7.0+cu123torch2.5-cp311-cp311-linux_x86_64.whl", marker = "python_version == '3.11'" },
-    { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.7.0/moe_kernels-0.7.0+cu123torch2.5-cp312-cp312-linux_x86_64.whl", marker = "python_version == '3.12'" },
-]
 
 [tool.pytest.ini_options]
 markers = ["private: marks tests as requiring an admin hf token (deselect with '-m \"not private\"')"]
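With the MoE kernel now resolved from the Hub via `[tool.kernels.dependencies]`, the per-Python-version `moe-kernels` wheel pins and the `moe` optional dependency group are no longer needed; only `attention-kernels` keeps its pinned wheel URLs in this diff.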
@@ -1,5 +1,6 @@
 from typing import Optional, Protocol, runtime_checkable
 
+from hf_kernels import load_kernel
 import torch
 import torch.nn as nn
 from loguru import logger
@@ -27,6 +28,10 @@ from text_generation_server.utils.weights import (
 
 if SYSTEM == "ipex":
     from .fused_moe_ipex import fused_topk, grouped_topk
+if SYSTEM == "cuda":
+    moe_kernels = load_kernel("kernels-community/moe")
+    fused_topk = moe_kernels.fused_topk
+    grouped_topk = moe_kernels.grouped_topk
 else:
     from moe_kernels.fused_moe import fused_topk, grouped_topk
 
@@ -1,9 +1,11 @@
 from dataclasses import dataclass
-from typing import List, Optional
+from typing import Any, Callable, Dict, List, Optional
 
+from hf_kernels import load_kernel
 import torch
 import torch.nn as nn
 
+from text_generation_server.layers import moe
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.weights import Weights
 from text_generation_server.layers.marlin.gptq import (
@@ -12,9 +14,9 @@ from text_generation_server.layers.marlin.gptq import (
 )
 
 if SYSTEM == "cuda":
-    from moe_kernels.fused_marlin_moe import fused_marlin_moe
+    moe_kernels = load_kernel("kernels-community/moe")
 else:
-    fused_marlin_moe = None
+    moe_kernels = None
 
 
 try:
@@ -32,7 +34,7 @@ def can_use_marlin_moe_gemm(
 ):
     return (
         SYSTEM == "cuda"
-        and fused_marlin_moe is not None
+        and moe is not None
         and has_sm_8_0
         and quantize in {"awq", "gptq"}
         and quant_method in {"awq", "gptq"}
@@ -230,3 +232,115 @@ def _pack_weight(
         moe_weight.perm[expert] = weight.perm
 
     return moe_weight
+
+
+def fused_marlin_moe(
+    *,
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    gating_output: torch.Tensor,
+    g_idx1: torch.Tensor,
+    g_idx2: torch.Tensor,
+    sort_indices1: torch.Tensor,
+    sort_indices2: torch.Tensor,
+    w1_zeros: Optional[torch.Tensor] = None,
+    w2_zeros: Optional[torch.Tensor] = None,
+    is_k_full: bool,
+    topk: int,
+    renormalize: bool,
+    num_bits: int = 8,
+    override_config: Optional[Dict[str, Any]] = None,
+    use_grouped_topk: bool = False,
+    num_expert_group: Optional[int] = None,
+    custom_routing_function: Optional[Callable] = None,
+    topk_group: Optional[int] = None,
+) -> torch.Tensor:
+    """
+    This function computes a Mixture of Experts (MoE) layer using two sets of
+    weights, w1 and w2, and a top-k gating mechanism.
+
+    Parameters:
+    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
+    - w1 (torch.Tensor): The first set of expert weights.
+    - w2 (torch.Tensor): The second set of expert weights.
+    - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1.
+    - w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2.
+    - gating_output (torch.Tensor): The output of the gating operation
+      (before softmax).
+    - g_idx1 (torch.Tensor): The first set of act_order indices.
+    - g_idx2 (torch.Tensor): The second set of act_order indices.
+    - sort_indices1 (torch.Tensor): The first act_order input permutation.
+    - sort_indices2 (torch.Tensor): The second act_order input permutation.
+    - w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1.
+    - w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2.
+    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
+    - override_config (Optional[Dict[str, Any]]): Optional override
+      for the kernel configuration.
+    - num_bits (int): The number of bits in expert weights quantization.
+
+    Returns:
+    - torch.Tensor: The output tensor after applying the MoE layer.
+    """
+    # Check constraints.
+    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
+    assert hidden_states.shape[1] == w1.shape[1] * 16, "Hidden size mismatch w1"
+    assert hidden_states.shape[1] == w2.shape[2] // (
+        num_bits // 2
+    ), "Hidden size mismatch w2"
+    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
+    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
+    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
+    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
+    assert hidden_states.dtype == torch.float16
+    assert num_bits in [4, 8]
+
+    # DeepSeekV2 uses grouped_topk
+    if use_grouped_topk:
+        assert topk_group is not None
+        assert num_expert_group is not None
+        topk_weights, topk_ids = moe_kernels.grouped_topk(
+            hidden_states=hidden_states,
+            gating_output=gating_output,
+            topk=topk,
+            renormalize=renormalize,
+            num_expert_group=num_expert_group,
+            topk_group=topk_group,
+        )
+    elif custom_routing_function is None:
+        topk_weights, topk_ids = moe_kernels.fused_topk(
+            hidden_states=hidden_states,
+            gating_output=gating_output,
+            topk=topk,
+            renormalize=renormalize,
+        )
+    else:
+        topk_weights, topk_ids = custom_routing_function(
+            hidden_states=hidden_states,
+            gating_output=gating_output,
+            topk=topk,
+            renormalize=renormalize,
+        )
+    return moe_kernels.fused_marlin_moe(
+        hidden_states=hidden_states,
+        w1=w1,
+        w2=w2,
+        w1_scale=w1_scale,
+        w2_scale=w2_scale,
+        gating_output=gating_output,
+        topk_weights=topk_weights,
+        topk_ids=topk_ids,
+        g_idx1=g_idx1,
+        g_idx2=g_idx2,
+        sort_indices1=sort_indices1,
+        sort_indices2=sort_indices2,
+        w1_zeros=w1_zeros,
+        w2_zeros=w2_zeros,
+        override_config=override_config,
+        num_bits=num_bits,
+        is_k_full=is_k_full,
+    )
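This added wrapper keeps the name and keyword-only signature of the old `moe_kernels.fused_marlin_moe` entry point, so the GPTQ-Marlin MoE layer in this file can apparently keep calling `fused_marlin_moe(...)` unchanged, while the routing step (grouped top-k, plain fused top-k, or a custom routing function) and the Marlin GEMM now go through the Hub-loaded `kernels-community/moe` kernel.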
@@ -1,5 +1,6 @@
 from typing import Optional
 
+from hf_kernels import load_kernel
 import torch
 import torch.nn as nn
 
@@ -8,8 +9,10 @@ from text_generation_server.utils.weights import UnquantizedWeight, Weights
 
 if SYSTEM == "ipex":
     from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
+elif SYSTEM == "cuda":
+    moe_kernels = load_kernel("kernels-community/moe")
 else:
-    from moe_kernels.fused_moe import fused_moe
+    import moe_kernels
 
 
 class UnquantizedSparseMoELayer(nn.Module):
@@ -63,7 +66,17 @@ class UnquantizedSparseMoELayer(nn.Module):
         )
 
     def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
-        if SYSTEM == "ipex":
+        if SYSTEM == "rocm":
+            return moe_kernels.fused_moe(
+                x,
+                self.gate_up_proj,
+                self.down_proj,
+                gating_output,
+                self.topk,
+                renormalize=self.renormalize,
+                inplace=True,
+            )
+        elif SYSTEM == "ipex":
             return self.ipex_fused_moe(
                 hidden_states=x,
                 router_logits=gating_output,
@@ -73,7 +86,7 @@ class UnquantizedSparseMoELayer(nn.Module):
                 num_expert_group=self.n_expert_group,
                 topk_group=self.topk_group,
             )
-        return fused_moe(
+        return moe_kernels.fused_moe(
             x,
             w1=self.gate_up_proj,
             w2=self.down_proj,
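For orientation, a rough smoke test of the un-quantized path above. This is a sketch only: the expert-weight layout (gate_up as [n_experts, 2 * intermediate, hidden], down as [n_experts, hidden, intermediate]) and all sizes are assumptions, not taken from this diff; the positional argument order mirrors the ROCm branch of forward() shown above.

import torch
from hf_kernels import load_kernel

moe_kernels = load_kernel("kernels-community/moe")

n_tokens, hidden, intermediate, n_experts, topk = 4, 64, 128, 8, 2
x = torch.randn(n_tokens, hidden, dtype=torch.float16, device="cuda")
gate_up_proj = torch.randn(n_experts, 2 * intermediate, hidden, dtype=torch.float16, device="cuda")
down_proj = torch.randn(n_experts, hidden, intermediate, dtype=torch.float16, device="cuda")
gating_output = torch.randn(n_tokens, n_experts, dtype=torch.float16, device="cuda")

# Same call pattern as the diff: fused_moe(x, w1, w2, gating_output, topk, ...).
out = moe_kernels.fused_moe(
    x,
    gate_up_proj,
    down_proj,
    gating_output,
    topk,
    renormalize=True,
    inplace=True,
)
print(out.shape)  # expected: (n_tokens, hidden)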
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from hf_kernels import load_kernel
 import torch
 import torch.distributed
 
@@ -25,8 +26,10 @@ from text_generation_server.utils.import_utils import SYSTEM
 
 if SYSTEM == "ipex":
     from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
+elif SYSTEM == "cuda":
+    moe_kernels = load_kernel("kernels-community/moe")
 else:
-    from moe_kernels.fused_moe import fused_moe
+    import moe_kernels
 
 from text_generation_server.layers.attention import (
     paged_attention,
@@ -510,7 +513,7 @@ class BlockSparseMoE(nn.Module):
                 topk_group=None,
             )
         else:
-            out = fused_moe(
+            out = moe_kernels.fused_moe(
                 x,
                 self.wv1,
                 self.w2,