mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-24 00:12:08 +00:00
fix moe in quantization path (#2935)
update ipex xpu to support moe for mixtral Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent
2dfe3b3ee6
commit
1d3c9beba8
@ -118,8 +118,9 @@ ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/intel/oneapi/pti/0.9/lib:/opt/conda/li
|
||||
ENV CCL_ZE_IPC_EXCHANGE=sockets
|
||||
#ENV TORCH_LLM_ALLREDUCE=1
|
||||
#ENV CCL_TOPO_FABRIC_VERTEX_CONNECTION_CHECK=0
|
||||
ENV TORCH_DEVICE_BACKEND_AUTOLOAD=0
|
||||
|
||||
RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout 033af6f63745ac748cccdadee5c6140c7971edf6
|
||||
RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout 1ccf72b2d11cd00b47aef6d6cd054c088aa6f083
|
||||
RUN cd intel-extension-for-pytorch && git submodule update --init --recursive && USE_AOT_DEVLIST='pvc,ats-m150' BUILD_SEPARATE_OPS=OFF BUILD_WITH_CPU=OFF USE_XETLA=ON python setup.py install && rm -rf /usr/src/intel-extension-for-pytorch
|
||||
|
||||
# Install benchmarker
|
||||
|
@ -25,7 +25,7 @@ from text_generation_server.utils.weights import (
|
||||
)
|
||||
|
||||
if SYSTEM == "ipex":
|
||||
from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
|
||||
from .fused_moe_ipex import fused_topk, grouped_topk
|
||||
else:
|
||||
from moe_kernels.fused_moe import fused_topk, grouped_topk
|
||||
|
||||
@ -139,10 +139,6 @@ class DenseMoELayer(nn.Module):
|
||||
)
|
||||
for i in range(self.n_experts)
|
||||
]
|
||||
if SYSTEM == "ipex":
|
||||
self.ipex_fused_moe = GatedMLPMOE(
|
||||
W13=self.gate_proj, W2=self.down_proj, W3=self.up_proj, use_prepack=True
|
||||
)
|
||||
|
||||
self.process_group = weights.process_group
|
||||
|
||||
@ -155,17 +151,6 @@ class DenseMoELayer(nn.Module):
|
||||
input_shape = x.shape
|
||||
x = x.view(-1, input_shape[-1])
|
||||
|
||||
if SYSTEM == "ipex":
|
||||
return self.ipex_fused_moe(
|
||||
hidden_states=x,
|
||||
router_logits=gating_output,
|
||||
top_k=self.topk,
|
||||
renormalize=self.renormalize,
|
||||
use_grouped_topk=self.n_expert_group is not None,
|
||||
num_expert_group=self.n_expert_group,
|
||||
topk_group=self.topk_group,
|
||||
)
|
||||
|
||||
if self.n_expert_group is not None and self.topk_group is not None:
|
||||
topk_weights, topk_ids = grouped_topk(
|
||||
x,
|
||||
|
65
server/text_generation_server/layers/moe/fused_moe_ipex.py
Normal file
65
server/text_generation_server/layers/moe/fused_moe_ipex.py
Normal file
@ -0,0 +1,65 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def grouped_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: int = 0,
    topk_group: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Select the top-k experts per token, restricted to the best expert groups.

    The router logits are turned into probabilities with a softmax over the
    last dimension. Experts are split into ``num_expert_group`` equally sized
    contiguous groups; each group is ranked by its highest-scoring expert and
    only the ``topk_group`` best groups stay eligible. The final ``topk``
    experts are then picked from the eligible ones.

    Args:
        hidden_states: Not read by this function; kept so the signature
            matches the other top-k routing helpers.
        gating_output: Router logits, one row per token and one column per
            expert.
        topk: Number of experts to return per token.
        renormalize: When true, rescale the selected weights so they sum to
            one for every token.
        num_expert_group: Number of groups the expert axis is divided into.
        topk_group: Number of groups kept per token.

    Returns:
        A ``(topk_weights, topk_ids)`` pair. Both are requested with
        ``sorted=False``, so the order within a row is not guaranteed.
    """
    scores = torch.softmax(gating_output, dim=-1)
    num_token = scores.shape[0]

    # View the experts as [n, n_group, experts_per_group] and rank each group
    # by its best expert.
    grouped_scores = scores.view(num_token, num_expert_group, -1)
    group_scores = grouped_scores.max(dim=-1).values  # [n, n_group]
    group_idx = torch.topk(
        group_scores, k=topk_group, dim=-1, sorted=False
    ).indices  # [n, top_k_group]

    # 1 for groups that were kept, 0 for the rest.
    group_mask = torch.zeros_like(group_scores).scatter_(
        1, group_idx, 1
    )  # [n, n_group]

    # Broadcast the per-group flag to every expert of that group.
    score_mask = (
        group_mask.unsqueeze(-1).expand_as(grouped_scores).reshape(num_token, -1)
    )  # [n, e]

    # Experts of discarded groups get a zero score before the final top-k.
    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
    topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    return topk_weights, topk_ids
|
||||
|
||||
|
||||
def fused_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Select the top-k experts per token from the router logits.

    Args:
        hidden_states: Not read by this function; kept so the signature
            matches the other top-k routing helpers.
        gating_output: Router logits; the softmax is taken over dimension 1.
        topk: Number of experts to return per token.
        renormalize: When true, rescale the selected weights so they sum to
            one for every token.

    Returns:
        A ``(topk_weights, topk_ids)`` pair. The weights are float32
        regardless of the dtype of ``gating_output``.
    """
    # Routing probabilities are computed in float32.
    routing_probs = torch.softmax(gating_output, dim=1, dtype=torch.float32)
    selected = routing_probs.topk(topk, dim=-1)
    topk_weights = selected.values
    topk_ids = selected.indices

    if renormalize:
        # In-place division, applied to the tensor that is returned.
        topk_weights.div_(topk_weights.sum(dim=-1, keepdim=True))

    return topk_weights, topk_ids
|
Loading…
Reference in New Issue
Block a user