mixtral moe fix after upgrade vllm extension ops git

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
Wang, Yi A 2025-05-15 19:22:57 -07:00
parent 9281be20c0
commit a184ce3876

View File

@ -4,7 +4,9 @@ import torch
import torch.nn as nn
from text_generation_server.utils.weights import UnquantizedWeight, Weights
from vllm_hpu_extension.ops import DynamicFusedMOE
from vllm_hpu_extension.ops import VllmMixtureOfExpertsOp
import habana_frameworks.torch as htorch
import torch.nn.functional as F
class UnquantizedSparseMoELayer(nn.Module):
@ -53,13 +55,29 @@ class UnquantizedSparseMoELayer(nn.Module):
weights=weights,
)
self.hpu_fused_moe = DynamicFusedMOE(n_experts)
self.MoeOp = VllmMixtureOfExpertsOp(n_experts, 0, n_experts - 1)
for i in range(n_experts):
self.hpu_fused_moe.MoeOp.w13_list[i].set_weight(self.gate_up_proj[i])
self.hpu_fused_moe.MoeOp.w2_list[i].set_weight(self.down_proj[i])
self.MoeOp.w13_list[i].set_weight(self.gate_up_proj[i])
self.MoeOp.w2_list[i].set_weight(self.down_proj[i])
def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
return self.hpu_fused_moe(x, gating_output, self.topk)
htorch.core.mark_step()
routing_weights = F.softmax(gating_output, dim=1, dtype=torch.float32)
routing_weights, selected_experts = torch.topk(
routing_weights, self.topk, dim=-1
)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
routing_weights = routing_weights.to(x.dtype)
final_hidden_states = self.MoeOp(
hidden_states=x,
expert_routing_table=selected_experts,
router_weights=routing_weights,
permuted_weights=True,
activation="silu",
)
return final_hidden_states.view(-1, x.shape[1])
def _load_expert_multi_weights_col(