Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 19:34:53 +00:00)
Fix Mixtral MoE after upgrading the vllm_hpu_extension ops git revision
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
parent 9281be20c0
commit a184ce3876
@@ -4,7 +4,9 @@ import torch
 import torch.nn as nn

 from text_generation_server.utils.weights import UnquantizedWeight, Weights
-from vllm_hpu_extension.ops import DynamicFusedMOE
+from vllm_hpu_extension.ops import VllmMixtureOfExpertsOp
+import habana_frameworks.torch as htorch
+import torch.nn.functional as F


 class UnquantizedSparseMoELayer(nn.Module):
@@ -53,13 +55,29 @@ class UnquantizedSparseMoELayer(nn.Module):
             weights=weights,
         )

-        self.hpu_fused_moe = DynamicFusedMOE(n_experts)
+        self.MoeOp = VllmMixtureOfExpertsOp(n_experts, 0, n_experts - 1)
         for i in range(n_experts):
-            self.hpu_fused_moe.MoeOp.w13_list[i].set_weight(self.gate_up_proj[i])
-            self.hpu_fused_moe.MoeOp.w2_list[i].set_weight(self.down_proj[i])
+            self.MoeOp.w13_list[i].set_weight(self.gate_up_proj[i])
+            self.MoeOp.w2_list[i].set_weight(self.down_proj[i])

     def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
-        return self.hpu_fused_moe(x, gating_output, self.topk)
+        htorch.core.mark_step()
+        routing_weights = F.softmax(gating_output, dim=1, dtype=torch.float32)
+        routing_weights, selected_experts = torch.topk(
+            routing_weights, self.topk, dim=-1
+        )
+        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+        routing_weights = routing_weights.to(x.dtype)
+
+        final_hidden_states = self.MoeOp(
+            hidden_states=x,
+            expert_routing_table=selected_experts,
+            router_weights=routing_weights,
+            permuted_weights=True,
+            activation="silu",
+        )
+
+        return final_hidden_states.view(-1, x.shape[1])


 def _load_expert_multi_weights_col(
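For context on what changed: the old code passed the raw gating output to DynamicFusedMOE and let it do the routing internally, while the new forward() computes standard top-k softmax routing explicitly and hands the selected experts and normalized weights to VllmMixtureOfExpertsOp. The sketch below is a plain-PyTorch reference of that computation, runnable on CPU; the dense per-expert loop stands in for the fused HPU op, and the assumption that gate_up_proj[i] stacks the gate and up projections along the output dimension (SwiGLU layout) is mine, not stated in the diff.

```python
# Minimal CPU reference for the routing + expert computation that the
# patched forward() delegates to VllmMixtureOfExpertsOp on HPU. The
# per-expert loop is a stand-in for the fused kernel, not the kernel
# itself. Assumes gate_up_proj[i] stacks [gate; up] along dim 0.
import torch
import torch.nn.functional as F


def moe_forward_reference(
    x: torch.Tensor,              # [tokens, hidden]
    gating_output: torch.Tensor,  # [tokens, n_experts]
    gate_up_proj: torch.Tensor,   # [n_experts, 2 * inter, hidden] (assumed layout)
    down_proj: torch.Tensor,      # [n_experts, hidden, inter]
    topk: int,
) -> torch.Tensor:
    # Same routing math as the new forward(): softmax in fp32,
    # top-k selection, then renormalize the surviving weights.
    routing_weights = F.softmax(gating_output, dim=1, dtype=torch.float32)
    routing_weights, selected_experts = torch.topk(routing_weights, topk, dim=-1)
    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
    routing_weights = routing_weights.to(x.dtype)

    inter = gate_up_proj.shape[1] // 2
    out = torch.zeros_like(x)
    for expert in range(gate_up_proj.shape[0]):
        # Find which (token, slot) pairs routed to this expert.
        token_idx, k_idx = torch.where(selected_experts == expert)
        if token_idx.numel() == 0:
            continue
        h = x[token_idx] @ gate_up_proj[expert].T  # [n, 2 * inter]
        gate, up = h[:, :inter], h[:, inter:]
        h = F.silu(gate) * up                      # SwiGLU ("silu" activation)
        h = h @ down_proj[expert].T                # [n, hidden]
        # Scatter the weighted expert outputs back to their tokens.
        out.index_add_(0, token_idx, h * routing_weights[token_idx, k_idx, None])
    return out
```

On HPU the fused op fuses this routing-table dispatch and the SwiGLU MLPs into one graph-friendly call, which is why forward() only prepares selected_experts and routing_weights; htorch.core.mark_step() in the patch is Habana's graph-break marker and has no CPU analogue here.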