mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 11:24:53 +00:00
mixtral moe fix after upgrade vllm extension ops git
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent
9281be20c0
commit
a184ce3876
@ -4,7 +4,9 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from text_generation_server.utils.weights import UnquantizedWeight, Weights
|
||||
from vllm_hpu_extension.ops import DynamicFusedMOE
|
||||
from vllm_hpu_extension.ops import VllmMixtureOfExpertsOp
|
||||
import habana_frameworks.torch as htorch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class UnquantizedSparseMoELayer(nn.Module):
|
||||
@ -53,13 +55,29 @@ class UnquantizedSparseMoELayer(nn.Module):
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
self.hpu_fused_moe = DynamicFusedMOE(n_experts)
|
||||
self.MoeOp = VllmMixtureOfExpertsOp(n_experts, 0, n_experts - 1)
|
||||
for i in range(n_experts):
|
||||
self.hpu_fused_moe.MoeOp.w13_list[i].set_weight(self.gate_up_proj[i])
|
||||
self.hpu_fused_moe.MoeOp.w2_list[i].set_weight(self.down_proj[i])
|
||||
self.MoeOp.w13_list[i].set_weight(self.gate_up_proj[i])
|
||||
self.MoeOp.w2_list[i].set_weight(self.down_proj[i])
|
||||
|
||||
def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
|
||||
return self.hpu_fused_moe(x, gating_output, self.topk)
|
||||
htorch.core.mark_step()
|
||||
routing_weights = F.softmax(gating_output, dim=1, dtype=torch.float32)
|
||||
routing_weights, selected_experts = torch.topk(
|
||||
routing_weights, self.topk, dim=-1
|
||||
)
|
||||
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
|
||||
routing_weights = routing_weights.to(x.dtype)
|
||||
|
||||
final_hidden_states = self.MoeOp(
|
||||
hidden_states=x,
|
||||
expert_routing_table=selected_experts,
|
||||
router_weights=routing_weights,
|
||||
permuted_weights=True,
|
||||
activation="silu",
|
||||
)
|
||||
|
||||
return final_hidden_states.view(-1, x.shape[1])
|
||||
|
||||
|
||||
def _load_expert_multi_weights_col(
|
||||
|
Loading…
Reference in New Issue
Block a user