# coding=utf-8
# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Dict, Any

import torch
import torch.distributed


# TODO: Remove these functions once the MoE kernels are built for ROCm.
def grouped_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: int = 0,
    topk_group: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Group-limited top-k routing: pick the best expert groups per token,
    then select the top-k experts only from within those groups."""
    scores = torch.softmax(gating_output, dim=-1)
    num_token = scores.shape[0]
    group_scores = (
        scores.view(num_token, num_expert_group, -1).max(dim=-1).values
    )  # [n, n_group]
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
        1
    ]  # [n, topk_group]
    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
        .reshape(num_token, -1)
    )  # [n, e]
    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
    topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids


def get_default_config(
    M: int,
    E: int,
    N: int,
    K: int,
    topk: int,
    dtype: Optional[str],
) -> Dict[str, int]:
    """Fallback Triton tile sizes used when no tuned MoE config is found."""
    config = {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 32,
        "GROUP_SIZE_M": 8,
    }
    # Small batches (fewer tokens than experts) use a smaller M tile.
    if M <= E:
        config = {
            "BLOCK_SIZE_M": 16,
            "BLOCK_SIZE_N": 32,
            "BLOCK_SIZE_K": 64,
            "GROUP_SIZE_M": 1,
        }
    return config


def fused_experts(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    inplace: bool = False,
    override_config: Optional[Dict[str, Any]] = None,
    use_fp8: bool = False,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
):
    # Check constraints.
    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
    assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]

    # Imported lazily so the module can be loaded without Triton/vLLM installed.
    import triton.language as tl
    from vllm import _custom_ops as ops
    from vllm.model_executor.layers.fused_moe.fused_moe import (
        get_moe_configs,
        invoke_fused_moe_kernel,
        moe_align_block_size,
    )

    M, _ = hidden_states.shape
    E, N, _ = w1.shape

    if override_config:
        config = override_config
    else:
        # First try to load an optimal config from file.
        configs = get_moe_configs(E, w2.shape[2], "float8" if use_fp8 else None)

        if configs:
            # If an optimal configuration map has been found, look up the
            # config closest to the current batch size M.
            config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
        else:
            # Else use the default config.
            config = get_default_config(
                M, E, N, w1.shape[2], topk_ids.shape[1], "float8" if use_fp8 else None
            )

    intermediate_cache1 = torch.empty(
        (M, topk_ids.shape[1], N),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
    intermediate_cache2 = torch.empty(
        (M * topk_ids.shape[1], N // 2),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
    intermediate_cache3 = torch.empty(
        (M, topk_ids.shape[1], w2.shape[1]),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

    sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
        topk_ids, config["BLOCK_SIZE_M"], E
    )
    compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16

    # First grouped GEMM: hidden_states @ w1 for each token's selected experts.
    invoke_fused_moe_kernel(
        hidden_states,
        w1,
        intermediate_cache1,
        a1_scale,
        w1_scale,
        topk_weights,
        topk_ids,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        False,
        topk_ids.shape[1],
        config,
        compute_type=compute_type,
        use_fp8=use_fp8,
    )

    # SwiGLU activation: silu(gate) * up over the two halves of w1's output.
    ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))

    # Second grouped GEMM: activated intermediate @ w2, scaled by routing weights.
    invoke_fused_moe_kernel(
        intermediate_cache2,
        w2,
        intermediate_cache3,
        a2_scale,
        w2_scale,
        topk_weights,
        topk_ids,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        True,
        1,
        config,
        compute_type=compute_type,
        use_fp8=use_fp8,
    )

    # Reduce over the top-k experts to produce one output row per token.
    if inplace:
        return torch.sum(
            intermediate_cache3.view(*intermediate_cache3.shape),
            dim=1,
            out=hidden_states,
        )
    return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)
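

if __name__ == "__main__":
    # Illustrative sketch (not part of the upstream module): exercise grouped_topk
    # on random CPU tensors. The shapes below (4 tokens, 8 experts split into
    # 2 groups) are assumptions chosen purely for demonstration.
    tokens, experts, groups = 4, 8, 2
    hidden = torch.randn(tokens, 16)
    gating = torch.randn(tokens, experts)
    weights, ids = grouped_topk(
        hidden,
        gating,
        topk=2,
        renormalize=True,
        num_expert_group=groups,
        topk_group=1,
    )
    # Each token keeps 2 experts drawn from its single best expert group,
    # with the surviving routing weights renormalized to sum to 1.
    print(weights.shape, ids.shape)  # torch.Size([4, 2]) torch.Size([4, 2])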