Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-04-22 15:32:08 +00:00
* Move to moe-kernels package and switch to common MoE layer

  This change introduces the new `moe-kernels` package:

  - Add `moe-kernels` as a dependency.
  - Introduce a `SparseMoELayer` module that can be used by MoE models.
  - Port over Mixtral and Deepseek.

* Make `cargo check` pass
* Update runner
77 lines
2.5 KiB
Python
from typing import Optional

import torch
import torch.nn as nn

from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
from text_generation_server.layers.moe.unquantized import UnquantizedSparseMoELayer
from text_generation_server.utils.weights import (
    DefaultWeightsLoader,
    UnquantizedWeight,
    Weights,
)

class SparseMoELayer(nn.Module):
    """
    Layer for MoE that uses fused kernels to only apply the active experts
    for each token (rather than applying all experts and selecting the
    outputs of active experts).
    """
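    # For intuition, a dense (unfused) MoE computes something roughly like the
    # sketch below (illustrative only, not what this layer runs):
    #
    #   probs = torch.softmax(gating_output, dim=-1)   # [n_tokens, n_experts]
    #   out = sum(probs[:, i, None] * expert(x) for i, expert in enumerate(experts))
    #
    # The fused kernels instead route each token only through its top-k
    # selected experts and skip the inactive ones entirely.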

    def __init__(
        self,
        *,
        n_expert_group: Optional[int],
        n_experts: int,
        prefix: str,
        renormalize: bool,
        topk: int,
        topk_group: Optional[int],
        weights: Weights,
        gate_proj_name: str = "gate_proj",
        up_proj_name: str = "up_proj",
        down_proj_name: str = "down_proj",
    ):
        super().__init__()

        if (
            isinstance(weights.loader, DefaultWeightsLoader)
            and isinstance(weights.loader.weight_class, UnquantizedWeight)
        ) or isinstance(weights.loader, HybridFP8UnquantLoader):
            cls = UnquantizedSparseMoELayer
        # Once we wire up GPTQ-Marlin MoE:
        # elif isinstance(weights.loader, GPTQMarlinWeightsLoader) and weights.loader.sym:
        #     cls = GPTQMarlinSparseMoELayer
        else:
            raise ValueError(
                f"Unsupported weights loader: {weights.loader}, sparse MoE is only supported for unquantized and GPTQ weights"
            )

        self.moe = cls(
            n_expert_group=n_expert_group,
            n_experts=n_experts,
            prefix=prefix,
            renormalize=renormalize,
            topk=topk,
            topk_group=topk_group,
            weights=weights,
            gate_proj_name=gate_proj_name,
            up_proj_name=up_proj_name,
            down_proj_name=down_proj_name,
        )
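
    # Shape note (an assumption based on typical MoE routing, not asserted by
    # this module): `x` is [n_tokens, hidden_size] and `gating_output` holds
    # the router logits, [n_tokens, n_experts]; top-k expert selection happens
    # inside the fused MoE layer.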
    def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
        return self.moe(x, gating_output=gating_output)

    @staticmethod
    def is_supported(weights: Weights) -> bool:
        return (
            (
                isinstance(weights.loader, DefaultWeightsLoader)
                and isinstance(weights.loader.weight_class, UnquantizedWeight)
            )
            or isinstance(weights.loader, HybridFP8UnquantLoader)
            # Once we wire up GPTQ-Marlin MoE:
            # or isinstance(weights.loader, GPTQMarlinWeightsLoader)
        )
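
# Hypothetical usage sketch (not part of this module): a model's decoder layer
# could gate on `is_supported` before constructing the layer. The variable
# names (`weights`, `hidden_states`, `router_logits`) and the Mixtral-style
# values below are illustrative assumptions, not a fixed API.
#
#   if SparseMoELayer.is_supported(weights):
#       moe = SparseMoELayer(
#           n_expert_group=None,   # expert grouping is a Deepseek-style option
#           n_experts=8,
#           prefix="model.layers.0.moe",
#           renormalize=True,      # renormalize top-k routing weights
#           topk=2,
#           topk_group=None,
#           weights=weights,
#       )
#       out = moe(hidden_states, gating_output=router_logits)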