Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-09-11 12:24:53 +00:00

add moe-kernels

This commit is contained in:
parent 2264702c01
commit de35b202c4

@@ -276,6 +276,15 @@ RUN git clone https://github.com/danieldk/marlin-kernels.git && \
     git checkout ${MARLIN_KERNELS_BRANCH} && \
     python setup.py install
 
+FROM kernel-builder AS moe-kernels
+WORKDIR /usr/src
+ENV MOE_KERNELS_BRANCH=127ec5e0fbd2f22fad2d63fbb559d2449e7b5ddb
+ENV VLLM_TARGET_DEVICE=rocm
+RUN git clone https://github.com/mht-sharma/moe-kernels.git && \
+    cd moe-kernels && \
+    git checkout ${MOE_KERNELS_BRANCH} && \
+    python setup.py install
+
 FROM install_deps AS base-copy
 
 # Text Generation Inference base env

@@ -301,6 +310,9 @@ COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311
 # Copy build artifacts from marlin kernels
 COPY --from=marlin-kernels /usr/src/marlin-kernels/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
 
+# Copy build artifacts from moe kernels
+COPY --from=moe-kernels /usr/src/moe-kernels/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
+
 # Install server
 COPY proto proto
 COPY server server
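
The two Dockerfile hunks above build moe-kernels in a dedicated builder stage, pinned to a specific commit and targeting ROCm, then copy the resulting build artifacts into the Python 3.11 site-packages of the final image. A minimal smoke test run inside the built image can confirm those artifacts are importable; this is a hypothetical check, assuming only the module path moe_kernels.fused_moe that appears in the hunks below:

# Hypothetical smoke test: run inside the built image to verify that the
# moe-kernels build artifacts copied into site-packages are importable.
import importlib

for module in ("moe_kernels", "moe_kernels.fused_moe"):
    importlib.import_module(module)
    print(f"import {module}: OK")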
@@ -24,10 +24,7 @@ from text_generation_server.utils.weights import (
     UnquantizedWeight,
 )
 
-if SYSTEM == "rocm":
-    from .fused_moe_rocm import grouped_topk
-    from vllm.model_executor.layers.fused_moe import fused_topk
-elif SYSTEM == "ipex":
+if SYSTEM == "ipex":
     from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
 else:
     from moe_kernels.fused_moe import fused_topk, grouped_topk
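
With moe-kernels now built for ROCm in the image, the dedicated rocm branch and its local grouped_topk fallback are dropped: only IPEX keeps a special-cased import, and every other system falls through to moe_kernels. A standalone sketch of the same import-time dispatch pattern, where detect_system is a hypothetical stand-in for TGI's SYSTEM constant:

# Sketch of import-time backend dispatch; detect_system() is invented here
# to mimic text_generation_server.utils.import_utils.SYSTEM.
import importlib.util

def detect_system() -> str:
    if importlib.util.find_spec("intel_extension_for_pytorch") is not None:
        return "ipex"
    return "cuda"

SYSTEM = detect_system()

if SYSTEM == "ipex":
    from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
else:
    # CUDA and ROCm now share the moe_kernels implementation.
    from moe_kernels.fused_moe import fused_topk, grouped_topk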
@@ -1,52 +0,0 @@
-# coding=utf-8
-# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import Tuple
-
-import torch
-import torch.distributed
-
-
-# TODO: Remove the functions once moe_kernel are built for ROCM
-def grouped_topk(
-    hidden_states: torch.Tensor,
-    gating_output: torch.Tensor,
-    topk: int,
-    renormalize: bool,
-    num_expert_group: int = 0,
-    topk_group: int = 0,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    scores = torch.softmax(gating_output, dim=-1)
-    num_token = scores.shape[0]
-    group_scores = (
-        scores.view(num_token, num_expert_group, -1).max(dim=-1).values
-    )  # [n, n_group]
-    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
-        1
-    ]  # [n, top_k_group]
-    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
-    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
-    score_mask = (
-        group_mask.unsqueeze(-1)
-        .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
-        .reshape(num_token, -1)
-    )  # [n, e]
-    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
-    topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
-
-    if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
-
-    return topk_weights, topk_ids
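
The deleted module (the .fused_moe_rocm fallback imported above) implemented DeepSeek-style grouped routing in pure PyTorch: experts are partitioned into num_expert_group groups, scores outside the topk_group highest-scoring groups are masked to zero, and the final top-k is taken over the surviving experts. A toy invocation of that signature, with shapes invented for illustration and assuming the grouped_topk defined above (or its moe_kernels replacement) is in scope:

import torch

# 4 tokens over 8 experts arranged as 4 groups of 2: keep the 2 best
# groups per token, then pick 2 experts overall.
hidden_states = torch.randn(4, 16)  # unused by the scoring itself
gating_output = torch.randn(4, 8)   # router logits, one per expert

topk_weights, topk_ids = grouped_topk(
    hidden_states,
    gating_output,
    topk=2,
    renormalize=True,
    num_expert_group=4,
    topk_group=2,
)
print(topk_weights.shape, topk_ids.shape)  # torch.Size([4, 2]) torch.Size([4, 2])
# With renormalize=True, each row of topk_weights sums to 1.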
@@ -6,9 +6,7 @@ import torch.nn as nn
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.weights import UnquantizedWeight, Weights
 
-if SYSTEM == "rocm":
-    from vllm.model_executor.layers.fused_moe import fused_moe
-elif SYSTEM == "ipex":
+if SYSTEM == "ipex":
     from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
 else:
     from moe_kernels.fused_moe import fused_moe

@@ -23,9 +23,7 @@ from typing import Optional, List, Tuple, Any
 from text_generation_server.layers.attention.kv_cache import get_kv_scales
 from text_generation_server.utils.import_utils import SYSTEM
 
-if SYSTEM == "rocm":
-    from vllm.model_executor.layers.fused_moe import fused_moe
-elif SYSTEM == "ipex":
+if SYSTEM == "ipex":
     from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
 else:
     from moe_kernels.fused_moe import fused_moe
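
The last two hunks make the same swap: outside IPEX, fused_moe now always comes from moe_kernels rather than vLLM's ROCm build. fused_moe fuses routing, the per-expert gated-MLP GEMMs, and the weighted combine into one kernel; as a reference for what such a kernel computes, here is an unfused pure-PyTorch sketch (function name and weight layout are illustrative assumptions, not moe_kernels' actual API):

import torch
import torch.nn.functional as F

def moe_reference(hidden_states, w1, w2, gating_output, topk, renormalize=True):
    # Unfused reference: route tokens, run each expert's gated MLP, combine.
    # Assumed layout: w1 is [n_experts, 2 * inter, hidden] with gate and up
    # projections stacked; w2 is [n_experts, hidden, inter] (down projection).
    weights, ids = torch.topk(torch.softmax(gating_output, dim=-1), topk, dim=-1)
    if renormalize:
        weights = weights / weights.sum(dim=-1, keepdim=True)
    out = torch.zeros_like(hidden_states)
    for k in range(topk):
        for e in range(w1.shape[0]):
            mask = ids[:, k] == e  # tokens whose k-th routing choice is expert e
            if mask.any():
                gate, up = (hidden_states[mask] @ w1[e].T).chunk(2, dim=-1)
                h = F.silu(gate) * up  # gated activation
                out[mask] += weights[mask, k : k + 1] * (h @ w2[e].T)
    return out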