diff --git a/Dockerfile_intel b/Dockerfile_intel
index 5edd8951..0f0d4383 100644
--- a/Dockerfile_intel
+++ b/Dockerfile_intel
@@ -118,8 +118,9 @@ ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/intel/oneapi/pti/0.9/lib:/opt/conda/li
 ENV CCL_ZE_IPC_EXCHANGE=sockets
 #ENV TORCH_LLM_ALLREDUCE=1
 #ENV CCL_TOPO_FABRIC_VERTEX_CONNECTION_CHECK=0
+ENV TORCH_DEVICE_BACKEND_AUTOLOAD=0
 
-RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout 033af6f63745ac748cccdadee5c6140c7971edf6
+RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout 1ccf72b2d11cd00b47aef6d6cd054c088aa6f083
 RUN cd intel-extension-for-pytorch && git submodule update --init --recursive && USE_AOT_DEVLIST='pvc,ats-m150' BUILD_SEPARATE_OPS=OFF BUILD_WITH_CPU=OFF USE_XETLA=ON python setup.py install && rm -rf /usr/src/intel-extension-for-pytorch
 
 # Install benchmarker
diff --git a/server/text_generation_server/layers/moe/__init__.py b/server/text_generation_server/layers/moe/__init__.py
index be40d78a..f0bd76aa 100644
--- a/server/text_generation_server/layers/moe/__init__.py
+++ b/server/text_generation_server/layers/moe/__init__.py
@@ -25,7 +25,7 @@ from text_generation_server.utils.weights import (
 )
 
 if SYSTEM == "ipex":
-    from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
+    from .fused_moe_ipex import fused_topk, grouped_topk
 else:
     from moe_kernels.fused_moe import fused_topk, grouped_topk
 
@@ -139,10 +139,6 @@ class DenseMoELayer(nn.Module):
             )
             for i in range(self.n_experts)
         ]
-        if SYSTEM == "ipex":
-            self.ipex_fused_moe = GatedMLPMOE(
-                W13=self.gate_proj, W2=self.down_proj, W3=self.up_proj, use_prepack=True
-            )
 
         self.process_group = weights.process_group
 
@@ -155,17 +151,6 @@ class DenseMoELayer(nn.Module):
         input_shape = x.shape
         x = x.view(-1, input_shape[-1])
 
-        if SYSTEM == "ipex":
-            return self.ipex_fused_moe(
-                hidden_states=x,
-                router_logits=gating_output,
-                top_k=self.topk,
-                renormalize=self.renormalize,
-                use_grouped_topk=self.n_expert_group is not None,
-                num_expert_group=self.n_expert_group,
-                topk_group=self.topk_group,
-            )
-
         if self.n_expert_group is not None and self.topk_group is not None:
             topk_weights, topk_ids = grouped_topk(
                 x,
diff --git a/server/text_generation_server/layers/moe/fused_moe_ipex.py b/server/text_generation_server/layers/moe/fused_moe_ipex.py
new file mode 100644
index 00000000..e26ff877
--- /dev/null
+++ b/server/text_generation_server/layers/moe/fused_moe_ipex.py
@@ -0,0 +1,65 @@
+# coding=utf-8
+# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Tuple
+
+import torch
+
+
+def grouped_topk(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    num_expert_group: int = 0,
+    topk_group: int = 0,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    scores = torch.softmax(gating_output, dim=-1)
+    num_token = scores.shape[0]
+    group_scores = (
+        scores.view(num_token, num_expert_group, -1).max(dim=-1).values
+    )  # [n, n_group]
+    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
+        1
+    ]  # [n, top_k_group]
+    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
+    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
+    score_mask = (
+        group_mask.unsqueeze(-1)
+        .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
+        .reshape(num_token, -1)
+    )  # [n, e]
+    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
+    topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
+
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+
+    return topk_weights, topk_ids
+
+
+def fused_topk(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    topk_weights = torch.nn.functional.softmax(
+        gating_output, dim=1, dtype=torch.float32
+    )
+    topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)
+    if renormalize:
+        topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
+    return topk_weights, topk_ids
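
Not part of the diff above: a minimal usage sketch of the two routing helpers added in `fused_moe_ipex.py`, run on toy tensors. It assumes `text_generation_server` is importable (the two functions can also be copied out verbatim); the tensor shapes and the `topk`/group settings are illustrative values, not ones used by the server.

```python
import torch

from text_generation_server.layers.moe.fused_moe_ipex import fused_topk, grouped_topk

num_tokens, num_experts = 4, 8
hidden_states = torch.randn(num_tokens, 16)            # not used by either helper, kept for API parity
gating_output = torch.randn(num_tokens, num_experts)   # router logits, one score per expert

# Plain top-k routing: softmax over experts, then keep the k best per token.
weights, ids = fused_topk(hidden_states, gating_output, topk=2, renormalize=True)
print(weights.shape, ids.shape)  # torch.Size([4, 2]) torch.Size([4, 2])

# Group-limited routing (DeepSeek-style): experts are split into groups and
# only experts from the top `topk_group` groups per token are eligible.
weights, ids = grouped_topk(
    hidden_states,
    gating_output,
    topk=2,
    renormalize=True,
    num_expert_group=4,   # 8 experts -> 4 groups of 2 (illustrative)
    topk_group=2,         # keep the 2 highest-scoring groups per token
)
print(weights.shape, ids.shape)  # torch.Size([4, 2]) torch.Size([4, 2])
```

This mirrors how `DenseMoELayer.forward` in `__init__.py` dispatches: `grouped_topk` when `n_expert_group`/`topk_group` are set, `fused_topk` otherwise.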