diff --git a/Dockerfile_amd b/Dockerfile_amd
index 77d4e613..6bfb2bfc 100644
--- a/Dockerfile_amd
+++ b/Dockerfile_amd
@@ -276,6 +276,15 @@ RUN git clone https://github.com/danieldk/marlin-kernels.git && \
     git checkout ${MARLIN_KERNELS_BRANCH} && \
     python setup.py install
 
+FROM kernel-builder AS moe-kernels
+WORKDIR /usr/src
+ENV MOE_KERNELS_BRANCH=127ec5e0fbd2f22fad2d63fbb559d2449e7b5ddb
+ENV VLLM_TARGET_DEVICE=rocm
+RUN git clone https://github.com/mht-sharma/moe-kernels.git && \
+    cd moe-kernels && \
+    git checkout ${MOE_KERNELS_BRANCH} && \
+    python setup.py install
+
 FROM install_deps AS base-copy
 
 # Text Generation Inference base env
@@ -301,6 +310,9 @@ COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-31
 # Copy build artifacts from marlin kernels
 COPY --from=marlin-kernels /usr/src/marlin-kernels/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
 
+# Copy build artifacts from moe kernels
+COPY --from=moe-kernels /usr/src/moe-kernels/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
+
 # Install server
 COPY proto proto
 COPY server server
diff --git a/server/text_generation_server/layers/moe/__init__.py b/server/text_generation_server/layers/moe/__init__.py
index a5ae7ff4..be40d78a 100644
--- a/server/text_generation_server/layers/moe/__init__.py
+++ b/server/text_generation_server/layers/moe/__init__.py
@@ -24,10 +24,7 @@ from text_generation_server.utils.weights import (
     UnquantizedWeight,
 )
 
-if SYSTEM == "rocm":
-    from .fused_moe_rocm import grouped_topk
-    from vllm.model_executor.layers.fused_moe import fused_topk
-elif SYSTEM == "ipex":
+if SYSTEM == "ipex":
     from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
 else:
     from moe_kernels.fused_moe import fused_topk, grouped_topk
diff --git a/server/text_generation_server/layers/moe/fused_moe_rocm.py b/server/text_generation_server/layers/moe/fused_moe_rocm.py
deleted file mode 100644
index 68accb99..00000000
--- a/server/text_generation_server/layers/moe/fused_moe_rocm.py
+++ /dev/null
@@ -1,52 +0,0 @@
-# coding=utf-8
-# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import Tuple
-
-import torch
-import torch.distributed
-
-
-# TODO: Remove the functions once moe_kernel are built for ROCM
-def grouped_topk(
-    hidden_states: torch.Tensor,
-    gating_output: torch.Tensor,
-    topk: int,
-    renormalize: bool,
-    num_expert_group: int = 0,
-    topk_group: int = 0,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    scores = torch.softmax(gating_output, dim=-1)
-    num_token = scores.shape[0]
-    group_scores = (
-        scores.view(num_token, num_expert_group, -1).max(dim=-1).values
-    )  # [n, n_group]
-    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
-        1
-    ]  # [n, top_k_group]
-    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
-    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
-    score_mask = (
-        group_mask.unsqueeze(-1)
-        .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
-        .reshape(num_token, -1)
-    )  # [n, e]
-    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
-    topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
-
-    if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
-
-    return topk_weights, topk_ids
diff --git a/server/text_generation_server/layers/moe/unquantized.py b/server/text_generation_server/layers/moe/unquantized.py
index 75af0409..3c9bcaba 100644
--- a/server/text_generation_server/layers/moe/unquantized.py
+++ b/server/text_generation_server/layers/moe/unquantized.py
@@ -6,9 +6,7 @@ import torch.nn as nn
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.weights import UnquantizedWeight, Weights
 
-if SYSTEM == "rocm":
-    from vllm.model_executor.layers.fused_moe import fused_moe
-elif SYSTEM == "ipex":
+if SYSTEM == "ipex":
     from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
 else:
     from moe_kernels.fused_moe import fused_moe
diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
index 2d1aa96c..aa032782 100644
--- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
@@ -23,9 +23,7 @@ from typing import Optional, List, Tuple, Any
 from text_generation_server.layers.attention.kv_cache import get_kv_scales
 from text_generation_server.utils.import_utils import SYSTEM
 
-if SYSTEM == "rocm":
-    from vllm.model_executor.layers.fused_moe import fused_moe
-elif SYSTEM == "ipex":
+if SYSTEM == "ipex":
     from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
 else:
     from moe_kernels.fused_moe import fused_moe
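For reviewers: the deleted `fused_moe_rocm.grouped_topk` was a plain-PyTorch fallback for DeepSeek-style grouped expert routing, which ROCm builds now get from `moe_kernels.fused_moe.grouped_topk` instead (same `else` branch as CUDA). A minimal standalone sketch of those routing semantics, with illustrative shapes (4 tokens, 8 experts, 2 expert groups) chosen for this example only:

```python
import torch

def grouped_topk(gating_output, topk, renormalize, num_expert_group, topk_group):
    # Same logic as the deleted fused_moe_rocm.grouped_topk, minus the
    # unused hidden_states argument.
    scores = torch.softmax(gating_output, dim=-1)
    num_token = scores.shape[0]
    # Score each expert group by its best expert; keep the top `topk_group` groups.
    group_scores = scores.view(num_token, num_expert_group, -1).max(dim=-1).values
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1]
    group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)
    # Zero out experts whose group was dropped, then take top-k over the rest.
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
        .reshape(num_token, -1)
    )
    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)
    topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids

# 4 tokens routed over 8 experts in 2 groups; keep 1 group, then pick 2 experts.
weights, ids = grouped_topk(torch.randn(4, 8), topk=2, renormalize=True,
                            num_expert_group=2, topk_group=1)
print(weights.shape, ids.shape)  # torch.Size([4, 2]) torch.Size([4, 2])
```

Any behavioral drift between this reference and the kernel now built from `moe-kernels` would surface as changed expert selection on ROCm, so it may be worth a quick parity check against `moe_kernels.fused_moe.grouped_topk` on one of the affected MoE models (e.g. DBRX).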