From 1194cdb1ba090507cfcbc079a9c8c0bd84fef1f7 Mon Sep 17 00:00:00 2001
From: Mohit Sharma
Date: Tue, 10 Dec 2024 10:58:01 +0000
Subject: [PATCH] remove grouped_topk

---
 .../layers/moe/fused_moe_rocm.py | 52 -------------------
 1 file changed, 52 deletions(-)
 delete mode 100644 server/text_generation_server/layers/moe/fused_moe_rocm.py

diff --git a/server/text_generation_server/layers/moe/fused_moe_rocm.py b/server/text_generation_server/layers/moe/fused_moe_rocm.py
deleted file mode 100644
index 68accb99..00000000
--- a/server/text_generation_server/layers/moe/fused_moe_rocm.py
+++ /dev/null
@@ -1,52 +0,0 @@
-# coding=utf-8
-# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import Tuple
-
-import torch
-import torch.distributed
-
-
-# TODO: Remove the functions once moe_kernel are built for ROCM
-def grouped_topk(
-    hidden_states: torch.Tensor,
-    gating_output: torch.Tensor,
-    topk: int,
-    renormalize: bool,
-    num_expert_group: int = 0,
-    topk_group: int = 0,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    scores = torch.softmax(gating_output, dim=-1)
-    num_token = scores.shape[0]
-    group_scores = (
-        scores.view(num_token, num_expert_group, -1).max(dim=-1).values
-    )  # [n, n_group]
-    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
-        1
-    ]  # [n, top_k_group]
-    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
-    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
-    score_mask = (
-        group_mask.unsqueeze(-1)
-        .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
-        .reshape(num_token, -1)
-    )  # [n, e]
-    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
-    topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
-
-    if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
-
-    return topk_weights, topk_ids
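
For context, the deleted grouped_topk implements grouped expert routing in the
style of the DeepSeek MoE models credited in the header: softmax over the
gating logits, select the topk_group best expert groups by their maximum
expert score, zero out every expert outside those groups, then take the top-k
experts from what remains. A minimal standalone check of the routine as it
appeared above; the sizes, tensors, and assertions are illustrative and not
part of the patch:

    import torch

    # Hypothetical sizes: 4 tokens routed over 8 experts split into
    # 4 groups, keeping the top-2 groups and then the top-2 experts.
    num_tokens, num_experts = 4, 8
    num_expert_group, topk_group, topk = 4, 2, 2

    hidden_states = torch.randn(num_tokens, 16)  # unused by the routine itself
    gating_output = torch.randn(num_tokens, num_experts)

    weights, ids = grouped_topk(
        hidden_states,
        gating_output,
        topk=topk,
        renormalize=True,
        num_expert_group=num_expert_group,
        topk_group=topk_group,
    )

    # One weight and one expert index per selected expert, per token;
    # with renormalize=True the weights sum to 1 for each token.
    assert weights.shape == (num_tokens, topk)
    assert ids.shape == (num_tokens, topk)
    assert torch.allclose(weights.sum(dim=-1), torch.ones(num_tokens))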