mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-19 13:52:07 +00:00
* style * update torch * ix issues * fix clone * revert mkl * added custom PA * style * fix style * style * hide env vart * fix mixtral model * add skinny kernel and merge fixes * fixed style * fix issue for sliding window models * addressed review comments * fix import * improved error messag * updated default value * remove import * fix imports after rebase * float16 dep * improve dockerfile * cleaned dockerfile
243 lines
7.6 KiB
Python
243 lines
7.6 KiB
Python
from typing import Optional, Protocol, runtime_checkable
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from loguru import logger
|
|
from transformers.activations import ACT2FN
|
|
|
|
from text_generation_server.layers import (
|
|
TensorParallelColumnLinear,
|
|
TensorParallelRowLinear,
|
|
)
|
|
from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
|
|
from text_generation_server.layers.moe.unquantized import UnquantizedSparseMoELayer
|
|
from text_generation_server.utils.import_utils import SYSTEM
|
|
from text_generation_server.utils.log import log_once
|
|
from text_generation_server.utils.weights import (
|
|
DefaultWeightsLoader,
|
|
UnquantizedWeight,
|
|
Weights,
|
|
)
|
|
|
|
if SYSTEM == "rocm":
|
|
from .fused_moe_rocm import grouped_topk
|
|
from vllm.model_executor.layers.fused_moe import fused_topk
|
|
elif SYSTEM != "ipex":
|
|
from moe_kernels.fused_moe import fused_topk, grouped_topk
|
|
|
|
|
|
# NOTE: we are using a protocol here, because multiple inheritance is not nice.
# We need `Module`, and a `Module` -> some abstract class -> some concrete
# class inheritance chain is wacky.
|
|
|
|
|
|
@runtime_checkable
class MoELayer(Protocol):
    """Structural (duck-typed) interface for MoE layers.

    Implementations are not required to inherit from this class. They only
    need to provide a compatible keyword-only constructor and a ``forward``
    that takes the hidden states plus a keyword-only ``gating_output``.
    """

    def __init__(
        self,
        *,
        n_expert_group: Optional[int],
        n_experts: int,
        prefix: str,
        renormalize: bool,
        topk: int,
        topk_group: Optional[int],
        weights: Weights,
        gate_proj_name: str = "gate_proj",
        up_proj_name: str = "up_proj",
        down_proj_name: str = "down_proj",
        hidden_act: str = "silu",
    ):
        """Build the layer; implementations read expert weights for ``prefix`` from ``weights``."""

    def forward(
        self,
        x: torch.Tensor,
        *,
        gating_output: torch.Tensor,
    ) -> torch.Tensor:
        """Apply the experts selected by ``gating_output`` to ``x``."""
|
|
|
|
|
|
class DenseMoELayer(nn.Module):
    """
    Layer for MoE that applies *all* experts to each token and then weights
    their outputs based on the calculated routing. This layer is much slower
    than `SparseMoELayer` and should only be used when no fused kernels are
    available (e.g. for unsupported quantizers).

    NOTE(review): the down projections are called with ``reduce=False`` and
    ``self.process_group`` is stored but not used in ``forward``, so any
    cross-rank reduction under tensor parallelism appears to be left to the
    caller. Confirm this against the call sites.
    """

    def __init__(
        self,
        *,
        n_expert_group: Optional[int],
        n_experts: int,
        prefix: str,
        renormalize: bool,
        topk: int,
        topk_group: Optional[int],
        weights: Weights,
        gate_proj_name: str = "gate_proj",
        up_proj_name: str = "up_proj",
        down_proj_name: str = "down_proj",
        hidden_act: str = "silu",
    ):
        """
        Args:
            n_expert_group: passed to ``grouped_topk`` as ``num_expert_group``;
                ``None`` selects plain ``fused_topk`` routing. Must be ``None``
                exactly when ``topk_group`` is ``None``.
            n_experts: number of experts. Expert ``i`` is loaded from
                ``f"{prefix}.{i}.{<proj_name>}"``.
            prefix: weight-name prefix of the expert list.
            renormalize: forwarded to the top-k routing function.
            topk: number of experts selected per token.
            topk_group: passed to ``grouped_topk``, or ``None``.
            weights: source of the expert weights; its ``process_group`` is kept.
            gate_proj_name, up_proj_name, down_proj_name: per-expert projection
                names.
            hidden_act: activation applied to the gate projection.
                - Names containing "gelu" use torch GELU. The tanh approximation
                  is used for "gelu_fast" and "gelu_pytorch_tanh".
                - Names containing "silu" use SiLU.
                - Anything else is looked up in ``ACT2FN``.
        """
        super().__init__()

        log_once(
            logger.info,
            "No fused layers are available for this model type, using (slower) dense MoE layer",
        )

        assert (n_expert_group is None) == (
            topk_group is None
        ), "n_expert_group and topk_group must both be None or have some value"

        self.n_expert_group = n_expert_group
        self.n_experts = n_experts
        self.renormalize = renormalize
        self.topk = topk
        self.topk_group = topk_group

        if "gelu" in hidden_act:
            self.act = lambda x: torch.nn.functional.gelu(
                x,
                approximate=(
                    "tanh"
                    if hidden_act in ["gelu_fast", "gelu_pytorch_tanh"]
                    else "none"
                ),
            )
        elif "silu" in hidden_act:
            self.act = torch.nn.functional.silu
        else:
            self.act = ACT2FN[hidden_act]

        # nn.ModuleList (rather than a plain Python list) registers every expert
        # projection as a submodule. Module traversal (`.modules()`,
        # `.parameters()`, `.to()`, `state_dict()`) then sees the expert weights.
        self.gate_proj = nn.ModuleList(
            TensorParallelColumnLinear.load(
                None,
                prefix=f"{prefix}.{i}.{gate_proj_name}",
                weights=weights,
                bias=False,
            )
            for i in range(self.n_experts)
        )
        self.up_proj = nn.ModuleList(
            TensorParallelColumnLinear.load(
                None,
                prefix=f"{prefix}.{i}.{up_proj_name}",
                weights=weights,
                bias=False,
            )
            for i in range(self.n_experts)
        )
        self.down_proj = nn.ModuleList(
            TensorParallelRowLinear.load(
                None,
                prefix=f"{prefix}.{i}.{down_proj_name}",
                weights=weights,
                bias=False,
            )
            for i in range(self.n_experts)
        )

        self.process_group = weights.process_group

    def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
        """
        Apply every expert to every token and mix the results by routing weight.

        Args:
            x: hidden states of shape ``(..., model_dim)``, typically
                ``(sequence_length, model_dim)``. Leading dimensions are
                flattened into a single token dimension.
            gating_output: router outputs with one row per token, i.e.
                ``(..., n_experts)`` with the same leading dimensions as ``x``.

        Returns:
            A tensor with the same shape as ``x``.
        """
        input_shape = x.shape
        # Flatten leading dims into a token dim. The router output is flattened
        # the same way so that its rows stay aligned with the rows of `x`.
        x = x.view(-1, input_shape[-1])
        gating_output = gating_output.reshape(-1, gating_output.shape[-1])

        if self.n_expert_group is not None and self.topk_group is not None:
            topk_weights, topk_ids = grouped_topk(
                x,
                gating_output,
                self.topk,
                renormalize=self.renormalize,
                num_expert_group=self.n_expert_group,
                topk_group=self.topk_group,
            )
        else:
            topk_weights, topk_ids = fused_topk(
                x, gating_output, self.topk, self.renormalize
            )

        # Dense (tokens, n_experts) mixing matrix. It is zero everywhere except
        # at each token's selected experts, which hold their routing weight.
        routing_weights = torch.zeros(
            topk_ids.shape[0], self.n_experts, dtype=x.dtype, device=x.device
        )
        routing_weights.scatter_(1, topk_ids.long(), topk_weights.to(x.dtype))

        out = torch.zeros_like(x)
        for i in range(self.n_experts):
            h = self.act(self.gate_proj[i](x)) * self.up_proj[i](x)
            # reduce=False: no cross-rank reduction here (see class docstring).
            h = self.down_proj[i](h, reduce=False)
            # Unselected experts contribute with weight 0.
            out += h * routing_weights[:, i : i + 1]

        # Undo the flattening so callers get back the shape they passed in.
        return out.reshape(input_shape)
|
|
|
|
|
|
class SparseMoELayer(nn.Module):
    """
    Layer for MoE that uses fused kernels to only apply the active experts
    for each token (rather than applying all experts and selecting the
    outputs of active experts).

    This class is a dispatcher. It selects a concrete implementation from the
    weights loader and delegates ``forward`` to it.
    """

    def __init__(
        self,
        *,
        n_expert_group: Optional[int],
        n_experts: int,
        prefix: str,
        renormalize: bool,
        topk: int,
        topk_group: Optional[int],
        weights: Weights,
        gate_proj_name: str = "gate_proj",
        up_proj_name: str = "up_proj",
        down_proj_name: str = "down_proj",
    ):
        """
        All arguments are forwarded unchanged to the selected implementation.

        Raises:
            ValueError: if ``weights.loader`` is not supported
                (see ``SparseMoELayer.is_supported``).
        """
        super().__init__()

        # Reuse `is_supported` so the constructor and callers that probe
        # support up front cannot disagree about which loaders are accepted.
        if not self.is_supported(weights):
            raise ValueError(
                f"Unsupported weights loader: {weights.loader}, sparse MoE is only supported for unquantized and FP8 weights"
            )

        # Every loader accepted by `is_supported` is currently served by the
        # unquantized implementation.
        # TODO: once GPTQ-Marlin MoE is wired up, dispatch here, e.g.
        #   if isinstance(weights.loader, GPTQMarlinWeightsLoader) and weights.loader.sym:
        #       cls = GPTQMarlinSparseMoELayer
        cls = UnquantizedSparseMoELayer

        self.moe = cls(
            n_expert_group=n_expert_group,
            n_experts=n_experts,
            prefix=prefix,
            renormalize=renormalize,
            topk=topk,
            topk_group=topk_group,
            weights=weights,
            gate_proj_name=gate_proj_name,
            up_proj_name=up_proj_name,
            down_proj_name=down_proj_name,
        )

    def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
        """Delegate to the selected sparse MoE implementation."""
        return self.moe(x, gating_output=gating_output)

    @staticmethod
    def is_supported(weights: Weights) -> bool:
        """
        Return whether the fused sparse MoE path can handle ``weights``.

        Accepted loaders:
            - ``HybridFP8UnquantLoader``.
            - ``DefaultWeightsLoader`` whose ``weight_class`` is
              ``UnquantizedWeight`` (a subclass, or an instance).
        """
        loader = weights.loader
        if isinstance(loader, HybridFP8UnquantLoader):
            return True
        # TODO: once GPTQ-Marlin MoE is wired up:
        #   if isinstance(loader, GPTQMarlinWeightsLoader): return True
        if not isinstance(loader, DefaultWeightsLoader):
            return False

        weight_class = loader.weight_class
        # NOTE(review): `weight_class` is presumably a class (as its name says)
        # rather than an instance. For a class,
        # `isinstance(weight_class, UnquantizedWeight)` is always False, so
        # accept a subclass as well as an instance.
        if isinstance(weight_class, type):
            return issubclass(weight_class, UnquantizedWeight)
        return isinstance(weight_class, UnquantizedWeight)
|