diff --git a/backends/gaudi/server/text_generation_server/layers/moe/unquantized.py b/backends/gaudi/server/text_generation_server/layers/moe/unquantized.py
index ec158398..58709ec3 100644
--- a/backends/gaudi/server/text_generation_server/layers/moe/unquantized.py
+++ b/backends/gaudi/server/text_generation_server/layers/moe/unquantized.py
@@ -4,7 +4,9 @@ import torch
 import torch.nn as nn
 
 from text_generation_server.utils.weights import UnquantizedWeight, Weights
-from vllm_hpu_extension.ops import DynamicFusedMOE
+from vllm_hpu_extension.ops import VllmMixtureOfExpertsOp
+import habana_frameworks.torch as htorch
+import torch.nn.functional as F
 
 
 class UnquantizedSparseMoELayer(nn.Module):
@@ -53,13 +55,29 @@ class UnquantizedSparseMoELayer(nn.Module):
             weights=weights,
         )
 
-        self.hpu_fused_moe = DynamicFusedMOE(n_experts)
+        self.MoeOp = VllmMixtureOfExpertsOp(n_experts, 0, n_experts - 1)
         for i in range(n_experts):
-            self.hpu_fused_moe.MoeOp.w13_list[i].set_weight(self.gate_up_proj[i])
-            self.hpu_fused_moe.MoeOp.w2_list[i].set_weight(self.down_proj[i])
+            self.MoeOp.w13_list[i].set_weight(self.gate_up_proj[i])
+            self.MoeOp.w2_list[i].set_weight(self.down_proj[i])
 
     def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
-        return self.hpu_fused_moe(x, gating_output, self.topk)
+        htorch.core.mark_step()
+        routing_weights = F.softmax(gating_output, dim=1, dtype=torch.float32)
+        routing_weights, selected_experts = torch.topk(
+            routing_weights, self.topk, dim=-1
+        )
+        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+        routing_weights = routing_weights.to(x.dtype)
+
+        final_hidden_states = self.MoeOp(
+            hidden_states=x,
+            expert_routing_table=selected_experts,
+            router_weights=routing_weights,
+            permuted_weights=True,
+            activation="silu",
+        )
+
+        return final_hidden_states.view(-1, x.shape[1])
 
 
 def _load_expert_multi_weights_col(
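
The diff swaps the `DynamicFusedMOE` wrapper for a direct `VllmMixtureOfExpertsOp` and moves the top-k routing computation out of the fused op and into `forward()`; the `htorch.core.mark_step()` call delimits the lazy-mode graph before the MoE dispatch on HPU. For reference, below is a minimal CPU sketch of the same routing math, with a naive per-token expert loop standing in for the fused Gaudi op. The `(n_experts, 2 * intermediate, hidden)` / `(n_experts, hidden, intermediate)` weight layouts and the SwiGLU gate/up split are illustrative assumptions, not the op's actual contract.

```python
import torch
import torch.nn.functional as F


def naive_moe_forward(
    x: torch.Tensor,              # (num_tokens, hidden)
    gating_output: torch.Tensor,  # (num_tokens, n_experts) router logits
    gate_up_proj: torch.Tensor,   # assumed (n_experts, 2 * intermediate, hidden)
    down_proj: torch.Tensor,      # assumed (n_experts, hidden, intermediate)
    topk: int,
) -> torch.Tensor:
    # Same routing math as the new forward(): softmax over expert logits,
    # keep the top-k experts per token, renormalize kept weights to sum to 1.
    routing_weights = F.softmax(gating_output, dim=1, dtype=torch.float32)
    routing_weights, selected_experts = torch.topk(routing_weights, topk, dim=-1)
    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
    routing_weights = routing_weights.to(x.dtype)

    # Naive dispatch: loop over tokens and their selected experts. This is
    # the work VllmMixtureOfExpertsOp performs as one fused op on Gaudi.
    out = torch.zeros_like(x)
    for t in range(x.shape[0]):
        for k in range(topk):
            e = selected_experts[t, k]
            gate, up = (x[t] @ gate_up_proj[e].t()).chunk(2, dim=-1)
            expert_out = (F.silu(gate) * up) @ down_proj[e].t()  # SwiGLU MLP
            out[t] += routing_weights[t, k] * expert_out
    return out


# Example: 4 tokens, hidden=8, intermediate=16, 4 experts, top-2 routing
x = torch.randn(4, 8)
logits = torch.randn(4, 4)
w13 = torch.randn(4, 32, 8)
w2 = torch.randn(4, 8, 16)
y = naive_moe_forward(x, logits, w13, w2, topk=2)
```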