Mirror of https://github.com/huggingface/text-generation-inference.git
(synced 2025-09-11 20:34:54 +00:00)

Commit: de35b202c4 ("add moe-kernels")
Parent: 2264702c01
@@ -276,6 +276,15 @@ RUN git clone https://github.com/danieldk/marlin-kernels.git && \
     git checkout ${MARLIN_KERNELS_BRANCH} && \
     python setup.py install
 
+FROM kernel-builder AS moe-kernels
+WORKDIR /usr/src
+ENV MOE_KERNELS_BRANCH=127ec5e0fbd2f22fad2d63fbb559d2449e7b5ddb
+ENV VLLM_TARGET_DEVICE=rocm
+RUN git clone https://github.com/mht-sharma/moe-kernels.git && \
+    cd moe-kernels && \
+    git checkout ${MOE_KERNELS_BRANCH} && \
+    python setup.py install
+
 FROM install_deps AS base-copy
 
 # Text Generation Inference base env
@@ -301,6 +310,9 @@ COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-31
 # Copy build artifacts from marlin kernels
 COPY --from=marlin-kernels /usr/src/marlin-kernels/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
 
+# Copy build artifacts from moe kernels
+COPY --from=moe-kernels /usr/src/moe-kernels/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
+
 # Install server
 COPY proto proto
 COPY server server
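The two Dockerfile hunks above build the moe-kernels extension in a dedicated kernel-builder stage and copy its build artifacts into the conda site-packages of the final image. A minimal smoke test (hypothetical, run inside the built container) to confirm the package ended up importable; the three symbols checked are exactly the ones the Python diffs below import from it:

# Hypothetical smoke test: verify the copied moe-kernels artifacts are importable
# from the final image's Python environment.
from moe_kernels.fused_moe import fused_moe, fused_topk, grouped_topk

print("moe_kernels OK:", fused_moe.__name__, fused_topk.__name__, grouped_topk.__name__)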
@@ -24,10 +24,7 @@ from text_generation_server.utils.weights import (
     UnquantizedWeight,
 )
 
-if SYSTEM == "rocm":
-    from .fused_moe_rocm import grouped_topk
-    from vllm.model_executor.layers.fused_moe import fused_topk
-elif SYSTEM == "ipex":
+if SYSTEM == "ipex":
     from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
 else:
     from moe_kernels.fused_moe import fused_topk, grouped_topk
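After this hunk, both routers come from moe_kernels on CUDA and ROCm alike; only the ipex path keeps its own implementation. A rough sketch of calling the plain top-k router, assuming fused_topk keeps the softmax-then-top-k contract of the vLLM function it replaces (shapes, device, and values are illustrative only):

import torch
from moe_kernels.fused_moe import fused_topk  # assumed to follow the vLLM-style fused_topk API

# Illustrative shapes: 4 tokens routed over 8 experts, keeping 2 experts per token.
hidden_states = torch.randn(4, 1024, device="cuda", dtype=torch.float16)
gating_output = torch.randn(4, 8, device="cuda", dtype=torch.float32)  # router logits

# Softmax over the router logits, then per-token top-k; renormalize rescales kept weights to sum to 1.
topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk=2, renormalize=True)
print(topk_weights.shape, topk_ids.shape)  # both [4, 2]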
@@ -1,52 +0,0 @@
-# coding=utf-8
-# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import Tuple
-
-import torch
-import torch.distributed
-
-
-# TODO: Remove the functions once moe_kernel are built for ROCM
-def grouped_topk(
-    hidden_states: torch.Tensor,
-    gating_output: torch.Tensor,
-    topk: int,
-    renormalize: bool,
-    num_expert_group: int = 0,
-    topk_group: int = 0,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    scores = torch.softmax(gating_output, dim=-1)
-    num_token = scores.shape[0]
-    group_scores = (
-        scores.view(num_token, num_expert_group, -1).max(dim=-1).values
-    )  # [n, n_group]
-    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
-        1
-    ]  # [n, top_k_group]
-    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
-    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
-    score_mask = (
-        group_mask.unsqueeze(-1)
-        .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
-        .reshape(num_token, -1)
-    )  # [n, e]
-    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
-    topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
-
-    if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
-
-    return topk_weights, topk_ids
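The deleted module was a pure-PyTorch stand-in for group-limited (DeepSeek-style) routing, kept only until moe-kernels could be built for ROCm (see its TODO). The replacement is now imported from moe_kernels.fused_moe; the sketch below assumes it keeps the same signature and semantics as the deleted fallback above: experts are split into num_expert_group groups, routing is restricted to the top topk_group groups, and the final top-k is taken over the surviving experts.

import torch
from moe_kernels.fused_moe import grouped_topk  # assumed to match the deleted fallback's signature

num_tokens, num_experts, num_groups = 4, 64, 8
hidden_states = torch.randn(num_tokens, 1024)
gating_output = torch.randn(num_tokens, num_experts)  # router logits

topk_weights, topk_ids = grouped_topk(
    hidden_states,
    gating_output,
    topk=6,                       # experts kept per token
    renormalize=True,             # rescale the kept weights to sum to 1
    num_expert_group=num_groups,  # 64 experts split into 8 groups of 8
    topk_group=3,                 # only the 3 best-scoring groups stay eligible
)
print(topk_weights.shape, topk_ids.shape)  # both [4, 6]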
@@ -6,9 +6,7 @@ import torch.nn as nn
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.weights import UnquantizedWeight, Weights
 
-if SYSTEM == "rocm":
-    from vllm.model_executor.layers.fused_moe import fused_moe
-elif SYSTEM == "ipex":
+if SYSTEM == "ipex":
     from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
 else:
     from moe_kernels.fused_moe import fused_moe
@@ -23,9 +23,7 @@ from typing import Optional, List, Tuple, Any
 from text_generation_server.layers.attention.kv_cache import get_kv_scales
 from text_generation_server.utils.import_utils import SYSTEM
 
-if SYSTEM == "rocm":
-    from vllm.model_executor.layers.fused_moe import fused_moe
-elif SYSTEM == "ipex":
+if SYSTEM == "ipex":
     from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
 else:
     from moe_kernels.fused_moe import fused_moe
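The last two hunks apply the same cleanup to files that only need the full fused MoE forward: the ROCm branch that pulled fused_moe from vLLM is gone, and every non-ipex system now goes through moe_kernels. A rough usage sketch, assuming moe_kernels keeps a vLLM-style fused_moe signature (per-expert weight tensors, gate and up projections packed together in w1); shapes and device are illustrative only:

import torch
from moe_kernels.fused_moe import fused_moe  # assumed to follow the vLLM-style fused_moe API

num_tokens, hidden, inter, num_experts = 4, 1024, 2816, 8
device, dtype = "cuda", torch.float16

hidden_states = torch.randn(num_tokens, hidden, device=device, dtype=dtype)
w1 = torch.randn(num_experts, 2 * inter, hidden, device=device, dtype=dtype)  # fused gate+up projection
w2 = torch.randn(num_experts, hidden, inter, device=device, dtype=dtype)      # down projection
gating_output = torch.randn(num_tokens, num_experts, device=device, dtype=torch.float32)  # router logits

# Routes each token to its top-2 experts and returns the weighted, combined expert outputs.
out = fused_moe(hidden_states, w1, w2, gating_output, topk=2, renormalize=True)
print(out.shape)  # [4, 1024]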