text-generation-inference/server/text_generation_server/layers/moe/__init__.py
Mohit Sharma f9e561eced
Update ROCM libs and improvements (#2579)
* style

* update torch

* fix issues

* fix clone

* revert mkl

* added custom PA

* style

* fix style

* style

* hide env var

* fix mixtral model

* add skinny kernel and merge fixes

* fixed style

* fix issue for sliding window models

* addressed review comments

* fix import

* improved error message

* updated default value

* remove import

* fix imports after rebase

* float16 dep

* improve dockerfile

* cleaned dockerfile
2024-09-30 10:54:32 +02:00

243 lines
7.6 KiB
Python

from typing import Optional, Protocol, runtime_checkable
import torch
import torch.nn as nn
from loguru import logger
from transformers.activations import ACT2FN
from text_generation_server.layers import (
TensorParallelColumnLinear,
TensorParallelRowLinear,
)
from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
from text_generation_server.layers.moe.unquantized import UnquantizedSparseMoELayer
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import (
DefaultWeightsLoader,
UnquantizedWeight,
Weights,
)
if SYSTEM == "rocm":
from .fused_moe_rocm import grouped_topk
from vllm.model_executor.layers.fused_moe import fused_topk
elif SYSTEM != "ipex":
from moe_kernels.fused_moe import fused_topk, grouped_topk
# NOTE: we are using a protocol here, because multiple inheritance is not nice.
# We need `Module`, and `Module` -> some abstract class -> some concrete
# class inheritance is wacky.
@runtime_checkable
class MoELayer(Protocol):
    """
    Structural interface for mixture-of-experts layers.

    This is a `typing.Protocol` rather than a base class, so implementations
    only have to provide methods with matching signatures; they do not
    inherit from it. Because it is `runtime_checkable`, `isinstance` checks
    against it are permitted.
    """

    def __init__(
        self,
        *,
        n_expert_group: Optional[int],
        n_experts: int,
        prefix: str,
        renormalize: bool,
        topk: int,
        topk_group: Optional[int],
        weights: Weights,
        gate_proj_name: str = "gate_proj",
        up_proj_name: str = "up_proj",
        down_proj_name: str = "down_proj",
        hidden_act: str = "silu",
    ):
        """
        Construct the layer from keyword-only arguments.

        The routing configuration (`n_experts`, `topk`, `renormalize` and the
        optional `n_expert_group`/`topk_group` pair) is given together with
        the `weights` store, the `prefix` under which the expert weights are
        looked up, the names of the three projections and the name of the
        activation function.
        """

    def forward(
        self, x: torch.Tensor, *, gating_output: torch.Tensor
    ) -> torch.Tensor:
        """
        Apply the layer to `x`, routing according to `gating_output`, and
        return the resulting tensor.
        """
class DenseMoELayer(nn.Module):
    """
    Fallback MoE layer that evaluates *every* expert for each token and then
    combines the expert outputs using the computed routing weights.

    It is much slower than `SparseMoELayer` and is meant only for cases where
    no fused kernels are available (e.g. for unsupported quantizers).
    """

    def __init__(
        self,
        *,
        n_expert_group: Optional[int],
        n_experts: int,
        prefix: str,
        renormalize: bool,
        topk: int,
        topk_group: Optional[int],
        weights: Weights,
        gate_proj_name: str = "gate_proj",
        up_proj_name: str = "up_proj",
        down_proj_name: str = "down_proj",
        hidden_act: str = "silu",
    ):
        """
        Load the gate/up/down projections of all `n_experts` experts.

        Expert `i`'s projections are loaded from
        `{prefix}.{i}.{gate_proj_name}`, `{prefix}.{i}.{up_proj_name}` and
        `{prefix}.{i}.{down_proj_name}`, all without bias.
        `n_expert_group` and `topk_group` must either both be `None` or both
        be set; when set, grouped top-k routing is used in `forward`.
        """
        super().__init__()

        log_once(
            logger.info,
            "No fused layers are available for this model type, using (slower) dense MoE layer",
        )

        # Grouped routing needs both settings; one without the other is a
        # configuration error.
        groups_unset = n_expert_group is None
        assert groups_unset == (
            topk_group is None
        ), "n_expert_group and topk_group must both be None or have some value"

        self.n_expert_group = n_expert_group
        self.n_experts = n_experts
        self.renormalize = renormalize
        self.topk = topk
        self.topk_group = topk_group

        if "gelu" in hidden_act:
            # Only these two gelu variants use the tanh approximation.
            approximate = (
                "tanh" if hidden_act in ("gelu_fast", "gelu_pytorch_tanh") else "none"
            )
            self.act = lambda x: torch.nn.functional.gelu(x, approximate=approximate)
        elif "silu" in hidden_act:
            self.act = torch.nn.functional.silu
        else:
            self.act = ACT2FN[hidden_act]

        def load_per_expert(linear_cls, proj_name):
            # One linear layer per expert, in expert-index order.
            return [
                linear_cls.load(
                    None,
                    prefix=f"{prefix}.{expert}.{proj_name}",
                    weights=weights,
                    bias=False,
                )
                for expert in range(self.n_experts)
            ]

        self.gate_proj = load_per_expert(TensorParallelColumnLinear, gate_proj_name)
        self.up_proj = load_per_expert(TensorParallelColumnLinear, up_proj_name)
        self.down_proj = load_per_expert(TensorParallelRowLinear, down_proj_name)

        self.process_group = weights.process_group

    def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
        """
        x: (sequence_length, model_dim)
        gating_output: (sequence_length, n_experts)

        Returns the routing-weighted sum of all expert outputs, with `x`
        flattened to two dimensions. The down projections are invoked with
        `reduce=False`, and no reduction over `self.process_group` happens
        in this method.
        """
        # Collapse any leading dimensions so that rows correspond to tokens.
        x = x.view(-1, x.shape[-1])

        grouped = self.n_expert_group is not None and self.topk_group is not None
        if grouped:
            topk_weights, topk_ids = grouped_topk(
                x,
                gating_output,
                self.topk,
                renormalize=self.renormalize,
                num_expert_group=self.n_expert_group,
                topk_group=self.topk_group,
            )
        else:
            topk_weights, topk_ids = fused_topk(
                x, gating_output, self.topk, self.renormalize
            )
        topk_weights = topk_weights.to(x.dtype)

        # Dense (tokens, n_experts) routing matrix: zero for experts that
        # were not selected, the top-k weight for those that were.
        routing = x.new_zeros((topk_ids.shape[0], self.n_experts))
        routing.scatter_(1, topk_ids.long(), topk_weights)

        out = torch.zeros_like(x)
        experts = zip(self.gate_proj, self.up_proj, self.down_proj)
        for expert, (gate, up, down) in enumerate(experts):
            hidden = self.act(gate(x)) * up(x)
            expert_out = down(hidden, reduce=False)
            out += expert_out * routing[:, expert].unsqueeze(-1)

        return out
class SparseMoELayer(nn.Module):
    """
    Layer for MoE that uses fused kernels to only apply the active experts
    for each token (rather than applying all experts and selecting the
    outputs of active experts).

    The actual computation is delegated to a loader-specific implementation
    stored in `self.moe`; currently that is always
    `UnquantizedSparseMoELayer`.
    """

    def __init__(
        self,
        *,
        n_expert_group: Optional[int],
        n_experts: int,
        prefix: str,
        renormalize: bool,
        topk: int,
        topk_group: Optional[int],
        weights: Weights,
        gate_proj_name: str = "gate_proj",
        up_proj_name: str = "up_proj",
        down_proj_name: str = "down_proj",
    ):
        """
        Build the underlying sparse MoE implementation for `weights`.

        All arguments are forwarded unchanged to the selected implementation.

        Raises:
            ValueError: if `weights.loader` is not supported, i.e. when
                `SparseMoELayer.is_supported(weights)` is false.
        """
        super().__init__()

        # Use the same predicate that callers use to pick between the sparse
        # and dense layers, so the two checks cannot drift apart.
        if not self.is_supported(weights):
            raise ValueError(
                f"Unsupported weights loader: {weights.loader}, sparse MoE is only supported for unquantized and GPTQ weights"
            )

        # TODO: once GPTQ-Marlin MoE is wired up, select
        # GPTQMarlinSparseMoELayer for a symmetric GPTQMarlinWeightsLoader.
        moe_cls = UnquantizedSparseMoELayer

        self.moe = moe_cls(
            n_expert_group=n_expert_group,
            n_experts=n_experts,
            prefix=prefix,
            renormalize=renormalize,
            topk=topk,
            topk_group=topk_group,
            weights=weights,
            gate_proj_name=gate_proj_name,
            up_proj_name=up_proj_name,
            down_proj_name=down_proj_name,
        )

    def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
        """Run the wrapped sparse MoE implementation on `x`."""
        return self.moe(x, gating_output=gating_output)

    @staticmethod
    def is_supported(weights: Weights) -> bool:
        """
        Return whether a sparse MoE layer can be built for `weights`.

        Supported loaders are `HybridFP8UnquantLoader` and
        `DefaultWeightsLoader` configured for `UnquantizedWeight`.
        """
        loader = weights.loader
        if isinstance(loader, HybridFP8UnquantLoader):
            return True

        # TODO: once GPTQ-Marlin MoE is wired up, also accept
        # GPTQMarlinWeightsLoader here.
        if not isinstance(loader, DefaultWeightsLoader):
            return False

        weight_class = loader.weight_class
        # NOTE(review): `weight_class` presumably holds the weight *class*
        # itself (confirm against DefaultWeightsLoader). `isinstance` on a
        # class object is always False, so classes need `issubclass`.
        if isinstance(weight_class, type):
            return issubclass(weight_class, UnquantizedWeight)
        # Keep accepting an instance, as the previous check did.
        return isinstance(weight_class, UnquantizedWeight)