mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-21 06:42:10 +00:00
66 lines
2.2 KiB
Python
66 lines
2.2 KiB
Python
# coding=utf-8
# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
||
|
from typing import Tuple
|
||
|
|
||
|
import torch
|
||
|
|
||
|
|
||
|
def grouped_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: int = 0,
    topk_group: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Pick the top-k experts per token, restricted to the best expert groups.

    The experts (last axis of ``gating_output``) are split into
    ``num_expert_group`` contiguous, equally sized groups. Each group is
    scored by the highest softmax probability among its experts, the
    ``topk_group`` best groups are kept, and the final ``topk`` experts are
    chosen among the experts of the kept groups only.

    Args:
        hidden_states: Not read by this function; kept for signature
            compatibility with callers.
        gating_output: Router logits of shape ``[n, e]`` (tokens x experts).
        topk: Number of experts to select per token.
        renormalize: If true, rescale the selected weights so that they sum
            to 1 for every token.
        num_expert_group: Number of groups the experts are split into. Must
            be positive and divide the number of experts. The default of 0
            is not a usable value and has to be overridden.
        topk_group: Number of groups kept per token. Must be in
            ``[1, num_expert_group]``. The default of 0 is not a usable
            value and has to be overridden.

    Returns:
        ``(topk_weights, topk_ids)``, both of shape ``[n, topk]``. The
        entries along the last axis are not sorted.

    Raises:
        ValueError: If the group configuration is invalid.
    """
    num_token = gating_output.shape[0]
    num_experts = gating_output.shape[-1]

    # Fail early with a clear message: a zero group count cannot be reshaped,
    # and topk_group == 0 would mask every expert and renormalize 0 / 0 to NaN.
    if num_expert_group <= 0:
        raise ValueError(
            f"num_expert_group must be a positive integer, got {num_expert_group}"
        )
    if not 0 < topk_group <= num_expert_group:
        raise ValueError(
            f"topk_group must be in [1, {num_expert_group}], got {topk_group}"
        )
    if num_experts % num_expert_group != 0:
        raise ValueError(
            f"number of experts ({num_experts}) must be divisible by "
            f"num_expert_group ({num_expert_group})"
        )
    experts_per_group = num_experts // num_expert_group

    scores = torch.softmax(gating_output, dim=-1)

    # Explicit sizes instead of -1: a -1 dimension cannot be inferred for a
    # tensor with zero elements, which made an empty batch (n == 0) fail.
    group_scores = (
        scores.view(num_token, num_expert_group, experts_per_group).max(dim=-1).values
    )  # [n, n_group]
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
        1
    ]  # [n, top_k_group]
    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]

    # Broadcast the per-group decision to every expert of the group.
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(num_token, num_expert_group, experts_per_group)
        .reshape(num_token, num_experts)
    )  # [n, e]

    # Experts outside the kept groups get score 0; they can only be returned
    # (with weight 0) when topk exceeds the number of experts in kept groups.
    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
    topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    return topk_weights, topk_ids
|
||
|
|
||
|
|
||
|
def fused_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Pick the top-k experts per token from the router logits.

    Args:
        hidden_states: Not read by this function; kept for signature
            compatibility with callers.
        gating_output: Router logits whose last axis indexes the experts.
        topk: Number of experts to select along the last axis.
        renormalize: If true, rescale the selected weights so that they sum
            to 1 along the last axis.

    Returns:
        ``(topk_weights, topk_ids)``: the float32 softmax probabilities of
        the selected experts and their indices, with the last axis of size
        ``topk``.
    """
    # Softmax is computed in float32 regardless of the logits' dtype, and over
    # the last axis so that it matches the axis used by topk and the
    # renormalization below (identical to dim=1 for 2-D logits).
    expert_probs = torch.nn.functional.softmax(
        gating_output, dim=-1, dtype=torch.float32
    )
    topk_weights, topk_ids = torch.topk(expert_probs, topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids
|