Add fp8 support for MoE models (#2928)

* Add fp8 support for MoE models

* flatten condition
Mohit Sharma authored on 2025-01-29 18:26:32 +05:30, committed by GitHub
parent 73b7cf83f6
commit 4ef2e045c9
5 changed files with 181 additions and 24 deletions


@@ -7,7 +7,7 @@ from text_generation_server.layers.fp8 import (
Fp8Weight,
_load_scalar_or_matrix_scale,
requantize_with_max_scale,
- normalize_e4m3fn_to_e4m3fnuz,
+ normalize_e4m3fn_to_native_float8,
)
from text_generation_server.utils.weights import Weights, WeightsLoader
from text_generation_server.utils.import_utils import SYSTEM
@@ -148,7 +148,7 @@ class W8ANFpLoader(WeightsLoader):
)
if self.load_weight_scale and SYSTEM == "rocm":
- w, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
+ w, weight_scale, input_scale = normalize_e4m3fn_to_native_float8(
w, weight_scale, input_scale
)


@@ -68,12 +68,12 @@ def get_fp8_linear(force_w8a16: bool = False) -> Type[torch.nn.Module]:
return Fp8Linear
- def normalize_e4m3fn_to_e4m3fnuz(
+ def normalize_e4m3fn_to_native_float8(
weight: torch.Tensor,
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
- if weight.dtype == torch.float8_e4m3fn:
+ if weight.dtype == torch.float8_e4m3fn and SYSTEM == "rocm":
# The bit pattern 10000000 (-128) represents zero in e4m3fn
# but NaN in e4m3fnuz. So here we set it to 0.
# https://onnx.ai/onnx/technical/float8.html
@@ -172,7 +172,7 @@ def fp8_quantize(
qweight = qweight.to(qdtype)
if SYSTEM == "rocm":
- qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(qweight, scale)
+ qweight, scale, _ = normalize_e4m3fn_to_native_float8(qweight, scale)
return qweight, scale
@@ -295,7 +295,7 @@ class HybridFP8UnquantLoader(WeightsLoader):
)
if SYSTEM == "rocm":
- w, scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
+ w, scale, input_scale = normalize_e4m3fn_to_native_float8(
w, scale, input_scale
)
@@ -390,8 +390,8 @@ class Fp8Linear(torch.nn.Module):
if CUTLASS_FP8_AVAILABLE:
log_once(logger.info, "Using cutlass w8a8 kernels")
if SYSTEM == "rocm" and qweight.dtype == torch.float8_e4m3fn:
- qweight, scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
- weight=qweight, weight_scale=scale
+ qweight, scale, input_scale = normalize_e4m3fn_to_native_float8(
+ weight=qweight, weight_scale=scale, input_scale=input_scale
)
self.dtype = dtype
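As the hunks above show, the conversion to the ROCm-native e4m3fnuz format is now gated on SYSTEM == "rocm", and the helper was renamed to normalize_e4m3fn_to_native_float8, presumably to reflect that it targets whichever FP8 variant the platform supports natively. A minimal sketch of why the converter has to rewrite the 0x80 bit pattern called out in the comment above (illustrative only; this snippet is not part of the diff):

    import torch

    # 0b10000000 encodes -0.0 in float8_e4m3fn, but the same bits encode NaN in
    # float8_e4m3fnuz (the ROCm-native FP8 format), so the converter maps it to +0.0
    # before reinterpreting the storage.
    raw = torch.tensor([0x80], dtype=torch.uint8)
    print(raw.view(torch.float8_e4m3fn).float())    # tensor([-0.])
    print(raw.view(torch.float8_e4m3fnuz).float())  # tensor([nan])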


@@ -16,6 +16,7 @@ from text_generation_server.layers.moe.gptq_marlin import (
can_use_marlin_moe_gemm,
)
from text_generation_server.layers.moe.unquantized import UnquantizedSparseMoELayer
+ from text_generation_server.layers.moe.fp8 import FP8SparseMoELayer
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import (
@@ -203,12 +204,16 @@ class SparseMoELayer(nn.Module):
down_proj_name: str = "down_proj",
):
super().__init__()
- if (
- isinstance(weights.loader, DefaultWeightsLoader)
- and isinstance(weights.loader.weight_class, UnquantizedWeight)
- ) or isinstance(weights.loader, HybridFP8UnquantLoader):
+ if isinstance(weights.loader, DefaultWeightsLoader) and isinstance(
+ weights.loader.weight_class, UnquantizedWeight
+ ):
cls = UnquantizedSparseMoELayer
+ elif isinstance(weights.loader, HybridFP8UnquantLoader):
+ cls = (
+ FP8SparseMoELayer
+ if weights.loader.to_fp8
+ else UnquantizedSparseMoELayer
+ )
elif isinstance(
weights.loader, GPTQMarlinWeightsLoader
) and can_use_marlin_moe_gemm(

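With the condition flattened, checkpoints served through HybridFP8UnquantLoader are routed to the new FP8SparseMoELayer only when the loader is actually producing FP8 weights (its to_fp8 flag); plain unquantized checkpoints keep UnquantizedSparseMoELayer. A compact, illustrative restatement of that dispatch (the helper function is made up, and the import paths for the loader classes are a best guess from where they appear elsewhere in this commit):

    from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
    from text_generation_server.layers.moe.fp8 import FP8SparseMoELayer
    from text_generation_server.layers.moe.unquantized import UnquantizedSparseMoELayer
    from text_generation_server.utils.weights import DefaultWeightsLoader, UnquantizedWeight

    def moe_layer_cls(loader):
        # Plain unquantized checkpoints keep the existing layer.
        if isinstance(loader, DefaultWeightsLoader) and isinstance(
            loader.weight_class, UnquantizedWeight
        ):
            return UnquantizedSparseMoELayer
        # Loaders that emit FP8 weights take the new fused FP8 path.
        if isinstance(loader, HybridFP8UnquantLoader):
            return FP8SparseMoELayer if loader.to_fp8 else UnquantizedSparseMoELayer
        raise NotImplementedError("GPTQ-Marlin and other loaders are handled by later branches")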

@@ -0,0 +1,162 @@
from typing import Optional
import torch
import torch.nn as nn
from text_generation_server.utils.weights import Weights
from text_generation_server.layers.fp8 import (
Fp8Weight,
fp8_quantize,
quant_dtype,
normalize_e4m3fn_to_native_float8,
)
from moe_kernels.fused_moe import fused_moe
class FP8SparseMoELayer(nn.Module):
def __init__(
self,
*,
n_expert_group: Optional[int],
n_experts: int,
prefix: str,
renormalize: bool,
topk: int,
topk_group: Optional[int],
weights: Weights,
gate_proj_name: str = "gate_proj",
up_proj_name: str = "up_proj",
down_proj_name: str = "down_proj",
):
super().__init__()
assert (n_expert_group is None) == (
topk_group is None
), "n_expert_group and topk_group must both be None or have some value"
self.n_expert_group = n_expert_group
self.topk = topk
self.topk_group = topk_group
self.renormalize = renormalize
(
self.gate_up_proj,
self.gate_up_proj_weight_scale,
self.gate_up_proj_input_scale,
) = _load_expert_multi_weights_col(
prefix=prefix,
n_experts=n_experts,
gate_proj_name=gate_proj_name,
up_proj_name=up_proj_name,
weights=weights,
)
self.down_proj, self.down_proj_weight_scale, self.down_proj_input_scale = (
_load_expert_weights_row(
prefix=prefix,
n_experts=n_experts,
name=down_proj_name,
weights=weights,
)
)
def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
return fused_moe(
x,
w1=self.gate_up_proj,
w2=self.down_proj,
gating_output=gating_output,
topk=self.topk,
renormalize=self.renormalize,
inplace=True,
use_grouped_topk=self.n_expert_group is not None,
num_expert_group=self.n_expert_group,
topk_group=self.topk_group,
use_fp8_w8a8=True,
w1_scale=self.gate_up_proj_weight_scale,
w2_scale=self.down_proj_weight_scale,
a1_scale=self.gate_up_proj_input_scale,
a2_scale=self.down_proj_input_scale,
)
def _load_expert_weights(
get_weight_fn,
*,
prefix: str,
n_experts: int,
name: str,
weights: Weights,
) -> torch.Tensor:
all_weight = None
all_weight_scales = None
max_input_scale = None
for i in range(n_experts):
weight = get_weight_fn(prefix, i, name, weights)
assert isinstance(weight, Fp8Weight)
if all_weight is None:
all_weight = torch.empty(
(n_experts,) + weight.weight.shape,
dtype=quant_dtype,
device=weight.weight.device,
)
if all_weight_scales is None:
all_weight_scales = torch.empty(
(n_experts,),
dtype=torch.float32,
device=weight.weight.device,
)
if weight.weight.dtype in {torch.float8_e4m3fn, torch.float8_e4m3fnuz}:
all_weight[i], all_weight_scales[i], current_input_scale = (
normalize_e4m3fn_to_native_float8(
weight.weight, weight.weight_scale, weight.input_scale
)
)
if current_input_scale is not None:
if max_input_scale is None or current_input_scale > max_input_scale:
max_input_scale = current_input_scale
else:
all_weight[i], all_weight_scales[i] = fp8_quantize(
weight.weight, scalar=True
)
assert all_weight is not None
return all_weight, all_weight_scales, max_input_scale
def _load_expert_multi_weights_col(
*,
prefix: str,
n_experts: int,
gate_proj_name: str,
up_proj_name: str,
weights: Weights,
) -> torch.Tensor:
def get_weight_fn(prefix, i, name, weights):
return weights.get_multi_weights_col(
[f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0
)
return _load_expert_weights(
get_weight_fn, prefix=prefix, n_experts=n_experts, name=None, weights=weights
)
def _load_expert_weights_row(
*,
prefix: str,
n_experts: int,
name: str,
weights: Weights,
) -> torch.Tensor:
def get_weight_fn(prefix, i, name, weights):
return weights.get_weights_row(f"{prefix}.{i}.{name}")
return _load_expert_weights(
get_weight_fn, prefix=prefix, n_experts=n_experts, name=name, weights=weights
)
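The new FP8SparseMoELayer loads each expert's gate/up and down projections as FP8 tensors with a per-expert weight scale, keeps only the largest per-expert input scale (a single, conservative activation scale shared across experts), and passes everything to moe_kernels' fused_moe with use_fp8_w8a8=True. A rough usage sketch follows; the shapes, the prefix, and the weights object are invented, only the constructor and forward signatures come from the code above:

    import torch

    # Hypothetical setup: `weights` would be a Weights instance backed by an FP8
    # checkpoint (or one being quantized to FP8 at load time); the prefix is invented.
    moe = FP8SparseMoELayer(
        n_expert_group=None,
        n_experts=8,
        prefix="model.layers.0.mlp.experts",
        renormalize=True,
        topk=2,
        topk_group=None,
        weights=weights,
    )

    x = torch.randn(16, 4096, dtype=torch.bfloat16, device="cuda")           # (tokens, hidden)
    router_logits = torch.randn(16, 8, dtype=torch.float32, device="cuda")   # (tokens, experts)
    y = moe(x, gating_output=router_logits)  # same shape and dtype as x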


@@ -58,17 +58,7 @@ class UnquantizedSparseMoELayer(nn.Module):
)
def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
- if SYSTEM == "rocm":
- return fused_moe(
- x,
- self.gate_up_proj,
- self.down_proj,
- gating_output,
- self.topk,
- renormalize=self.renormalize,
- inplace=True,
- )
- elif SYSTEM == "ipex":
+ if SYSTEM == "ipex":
return self.ipex_fused_moe(
hidden_states=x,
router_logits=gating_output,