mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-22 23:42:06 +00:00
194 lines
5.8 KiB
Python
194 lines
5.8 KiB
Python
|
# coding=utf-8
|
||
|
# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
|
||
|
from typing import Optional, Tuple, Dict, Any
|
||
|
|
||
|
import torch
|
||
|
import torch.distributed
|
||
|
|
||
|
|
||
|
# TODO: Remove these functions once moe_kernels are built for ROCm.
|
||
|
def grouped_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: int = 0,
    topk_group: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Select the top-k experts per token, restricted to the best expert groups.

    The experts (last dimension of ``gating_output``) are split into
    ``num_expert_group`` equally sized, contiguous groups. Each group is scored
    by the maximum softmax probability among its experts. Only the
    ``topk_group`` best-scoring groups are kept, and the final ``topk``
    experts are chosen from those groups.

    Args:
        hidden_states: Not read by this function.
        gating_output: Router logits, expected 2-D ``[num_tokens, num_experts]``
            (``shape[0]`` is used as the token count).
        topk: Number of experts to select per token.
        renormalize: If true, rescale the selected weights so that they sum
            to 1 for each token.
        num_expert_group: Number of expert groups. A value ``<= 0`` (the
            default) disables grouping: all experts form a single group.
        topk_group: Number of groups to keep per token. A value ``<= 0`` (the
            default) keeps every group.

    Returns:
        ``(topk_weights, topk_ids)``, each of shape ``[num_tokens, topk]``.
        The entries are not sorted (``torch.topk(..., sorted=False)``).
    """
    # With the defaults (0), ``view`` below would raise; "no groups" is
    # equivalent to one group that contains every expert.
    if num_expert_group <= 0:
        num_expert_group = 1
    # ``k=0`` would mask out every expert, giving zero weights (or NaN after
    # renormalisation); keeping all groups is the meaningful fallback.
    if topk_group <= 0:
        topk_group = num_expert_group

    scores = torch.softmax(gating_output, dim=-1)
    num_token = scores.shape[0]
    group_size = scores.shape[-1] // num_expert_group

    # Score each group by its best expert: [n, n_group]
    group_scores = scores.view(num_token, num_expert_group, -1).max(dim=-1).values
    # Indices of the groups to keep: [n, topk_group]
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1]
    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
    group_mask.scatter_(1, group_idx, 1)
    # Broadcast the per-group mask to every expert of that group: [n, e]
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(num_token, num_expert_group, group_size)
        .reshape(num_token, -1)
    )
    # Softmax outputs are >= 0, so zero-filling the discarded experts ranks
    # them no higher than any kept expert.
    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
    topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    return topk_weights, topk_ids
|
||
|
|
||
|
|
||
|
def get_default_config(
    M: int,
    E: int,
    N: int,
    K: int,
    topk: int,
    dtype: Optional[str],
) -> Dict[str, int]:
    """Return the fallback Triton tiling config for the fused MoE kernels.

    Only ``M`` and ``E`` affect the result. ``N``, ``K``, ``topk`` and
    ``dtype`` are accepted but not read.

    Args:
        M: Number of rows (tokens) in the activation matrix.
        E: Number of experts.
        N: Unused.
        K: Unused.
        topk: Unused.
        dtype: Unused.

    Returns:
        A new dict with the keys ``BLOCK_SIZE_M``, ``BLOCK_SIZE_N``,
        ``BLOCK_SIZE_K`` and ``GROUP_SIZE_M``. A smaller tile set is used
        when ``M <= E``.
    """
    small_problem = M <= E
    block_m, block_n, block_k, group_m = (
        (16, 32, 64, 1) if small_problem else (64, 64, 32, 8)
    )
    return {
        "BLOCK_SIZE_M": block_m,
        "BLOCK_SIZE_N": block_n,
        "BLOCK_SIZE_K": block_k,
        "GROUP_SIZE_M": group_m,
    }
|
||
|
|
||
|
|
||
|
def fused_experts(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    inplace: bool = False,
    override_config: Optional[Dict[str, Any]] = None,
    use_fp8: bool = False,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
):
    """Run a mixture-of-experts MLP with vLLM's fused Triton MoE kernels.

    The pipeline has four steps:

    1. Project each (token, selected expert) pair through ``w1`` with
       ``invoke_fused_moe_kernel``.
    2. Apply ``ops.silu_and_mul`` to the result.
    3. Project through ``w2`` with a second kernel call.
    4. Sum over the selected experts.

    The second kernel call passes ``True`` where the first passes ``False``.
    Presumably this flag multiplies in ``topk_weights``; confirm against
    vLLM's ``invoke_fused_moe_kernel``.

    Args:
        hidden_states: 2-D ``[M, hidden]`` activations. Must be contiguous and
            float32, float16 or bfloat16.
        w1: 3-D ``[E, N, hidden]`` first-layer expert weights (contiguous).
        w2: 3-D contiguous second-layer expert weights. ``w2.shape[1]`` is
            the per-token output size. It consumes the ``N // 2``-wide
            activation; presumably ``[E, out, N // 2]`` (TODO confirm).
        topk_weights: Routing weights, same shape as ``topk_ids``.
        topk_ids: 2-D expert indices. ``topk_ids.shape[1]`` is the number of
            experts per token.
        inplace: If true, write the result into ``hidden_states`` via
            ``out=``. This assumes ``w2.shape[1] == hidden``.
        override_config: Kernel tiling config. When truthy, it is used as-is
            and the tuned/default lookup is skipped.
        use_fp8: Forwarded to the kernels; also selects the "float8" config
            variant.
        w1_scale, w2_scale, a1_scale, a2_scale: Optional scales forwarded to
            the first (``w1_scale``/``a1_scale``) and second
            (``w2_scale``/``a2_scale``) kernel calls.

    Returns:
        Tensor of shape ``[M, w2.shape[1]]``. When ``inplace`` is true, this
        is ``hidden_states`` itself.

    Raises:
        AssertionError: If the shape, contiguity or dtype checks fail.
    """
    # Check constraints.
    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
    assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]

    # Imported inside the function, presumably so this module can be imported
    # without triton/vllm installed.
    import triton.language as tl
    from vllm import _custom_ops as ops
    from vllm.model_executor.layers.fused_moe.fused_moe import (
        get_moe_configs,
        invoke_fused_moe_kernel,
        moe_align_block_size,
    )

    M, _ = hidden_states.shape
    E, N, _ = w1.shape
    top_k = topk_ids.shape[1]
    config_dtype = "float8" if use_fp8 else None

    if override_config:
        config = override_config
    else:
        # First try to load a tuned config map from file.
        configs = get_moe_configs(E, w2.shape[2], config_dtype)
        if configs:
            # Use the entry whose key is numerically closest to M.
            config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
        else:
            # Otherwise fall back to the built-in default config.
            config = get_default_config(M, E, N, w1.shape[2], top_k, config_dtype)

    intermediate_cache1 = torch.empty(
        (M, top_k, N),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
    # silu_and_mul halves the last dimension: N -> N // 2.
    intermediate_cache2 = torch.empty(
        (M * top_k, N // 2),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
    intermediate_cache3 = torch.empty(
        (M, top_k, w2.shape[1]),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

    sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
        topk_ids, config["BLOCK_SIZE_M"], E
    )

    # Map every dtype accepted by the assert above to its own Triton compute
    # type. In particular, float32 inputs must not be downcast to float16.
    compute_type = {
        torch.bfloat16: tl.bfloat16,
        torch.float16: tl.float16,
        torch.float32: tl.float32,
    }[hidden_states.dtype]

    invoke_fused_moe_kernel(
        hidden_states,
        w1,
        intermediate_cache1,
        a1_scale,
        w1_scale,
        topk_weights,
        topk_ids,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        False,
        top_k,
        config,
        compute_type=compute_type,
        use_fp8=use_fp8,
    )

    ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))

    invoke_fused_moe_kernel(
        intermediate_cache2,
        w2,
        intermediate_cache3,
        a2_scale,
        w2_scale,
        topk_weights,
        topk_ids,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        True,
        1,
        config,
        compute_type=compute_type,
        use_fp8=use_fp8,
    )

    # Reduce over the selected experts: [M, top_k, out] -> [M, out].
    if inplace:
        return torch.sum(intermediate_cache3, dim=1, out=hidden_states)
    return torch.sum(intermediate_cache3, dim=1)
|