mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-04-19 22:02:06 +00:00

Add fp8 support moe models (#2928)

* Add fp8 support moe models
* flatten condition

commit 4ef2e045c9
parent 73b7cf83f6
@@ -7,7 +7,7 @@ from text_generation_server.layers.fp8 import (
     Fp8Weight,
     _load_scalar_or_matrix_scale,
     requantize_with_max_scale,
-    normalize_e4m3fn_to_e4m3fnuz,
+    normalize_e4m3fn_to_native_float8,
 )
 from text_generation_server.utils.weights import Weights, WeightsLoader
 from text_generation_server.utils.import_utils import SYSTEM
@@ -148,7 +148,7 @@ class W8ANFpLoader(WeightsLoader):
         )

         if self.load_weight_scale and SYSTEM == "rocm":
-            w, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
+            w, weight_scale, input_scale = normalize_e4m3fn_to_native_float8(
                 w, weight_scale, input_scale
             )

@@ -68,12 +68,12 @@ def get_fp8_linear(force_w8a16: bool = False) -> Type[torch.nn.Module]:
     return Fp8Linear


-def normalize_e4m3fn_to_e4m3fnuz(
+def normalize_e4m3fn_to_native_float8(
     weight: torch.Tensor,
     weight_scale: torch.Tensor,
     input_scale: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
-    if weight.dtype == torch.float8_e4m3fn:
+    if weight.dtype == torch.float8_e4m3fn and SYSTEM == "rocm":
         # The bits pattern 10000000(-128) represents zero in e4m3fn
         # but NaN in e4m3fnuz. So here we set it to 0.
         # https://onnx.ai/onnx/technical/float8.html
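
For readers without the full file: the renamed helper only rewrites weights when running on ROCm, where the hardware-native FP8 format is e4m3fnuz rather than e4m3fn; on CUDA it now passes tensors through unchanged. A minimal sketch of what that guarded branch typically does is shown below. The body is assumed from the comments visible in the hunk and from the usual e4m3fn-to-e4m3fnuz conversion; it is not copied from this diff.

import torch

# Sketch only (assumed, not part of the diff): e4m3fn -> e4m3fnuz on ROCm.
def _normalize_to_e4m3fnuz_sketch(weight, weight_scale, input_scale=None):
    # Reinterpret the fp8 payload as int8 so the bit pattern can be patched.
    as_int8 = weight.view(torch.int8)
    # 0b10000000 (-128) encodes zero in e4m3fn but NaN in e4m3fnuz; clear it.
    as_int8[as_int8 == -128] = 0
    weight = as_int8.view(torch.float8_e4m3fnuz)
    # The same bit pattern represents half the value in e4m3fnuz, so the scales
    # are doubled to keep dequantized values identical.
    weight_scale = weight_scale * 2.0
    if input_scale is not None:
        input_scale = input_scale * 2.0
    return weight, weight_scale, input_scale
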
@@ -172,7 +172,7 @@ def fp8_quantize(
     qweight = qweight.to(qdtype)

     if SYSTEM == "rocm":
-        qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(qweight, scale)
+        qweight, scale, _ = normalize_e4m3fn_to_native_float8(qweight, scale)

     return qweight, scale

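
From the caller's side, dynamic quantization through this helper is unchanged; only the dtype of the result differs by platform. The snippet below is an illustrative usage sketch (the `scalar=True` keyword mirrors how the new MoE loader calls it elsewhere in this diff; the tensor itself is made up).

import torch
from text_generation_server.layers.fp8 import fp8_quantize

weight_bf16 = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")
# Per-tensor dynamic quantization; on ROCm the returned qweight is
# float8_e4m3fnuz with the scale already doubled, on CUDA it stays float8_e4m3fn.
qweight, scale = fp8_quantize(weight_bf16, scalar=True)
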
@@ -295,7 +295,7 @@ class HybridFP8UnquantLoader(WeightsLoader):
         )

         if SYSTEM == "rocm":
-            w, scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
+            w, scale, input_scale = normalize_e4m3fn_to_native_float8(
                 w, scale, input_scale
             )

@@ -390,8 +390,8 @@ class Fp8Linear(torch.nn.Module):
         if CUTLASS_FP8_AVAILABLE:
             log_once(logger.info, "Using cutlass w8a8 kernels")
         if SYSTEM == "rocm" and qweight.dtype == torch.float8_e4m3fn:
-            qweight, scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
-                weight=qweight, weight_scale=scale
+            qweight, scale, input_scale = normalize_e4m3fn_to_native_float8(
+                weight=qweight, weight_scale=scale, input_scale=input_scale
             )

         self.dtype = dtype
@@ -16,6 +16,7 @@ from text_generation_server.layers.moe.gptq_marlin import (
     can_use_marlin_moe_gemm,
 )
 from text_generation_server.layers.moe.unquantized import UnquantizedSparseMoELayer
+from text_generation_server.layers.moe.fp8 import FP8SparseMoELayer
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.log import log_once
 from text_generation_server.utils.weights import (
@@ -203,12 +204,16 @@ class SparseMoELayer(nn.Module):
         down_proj_name: str = "down_proj",
     ):
         super().__init__()
-        if (
-            isinstance(weights.loader, DefaultWeightsLoader)
-            and isinstance(weights.loader.weight_class, UnquantizedWeight)
-        ) or isinstance(weights.loader, HybridFP8UnquantLoader):
+        if isinstance(weights.loader, DefaultWeightsLoader) and isinstance(
+            weights.loader.weight_class, UnquantizedWeight
+        ):
             cls = UnquantizedSparseMoELayer
+        elif isinstance(weights.loader, HybridFP8UnquantLoader):
+            cls = (
+                FP8SparseMoELayer
+                if weights.loader.to_fp8
+                else UnquantizedSparseMoELayer
+            )
         elif isinstance(
             weights.loader, GPTQMarlinWeightsLoader
         ) and can_use_marlin_moe_gemm(
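
In practice this dispatch means an FP8-checkpointed MoE model picks up the new kernel path automatically: a HybridFP8UnquantLoader with to_fp8 enabled now resolves to FP8SparseMoELayer instead of falling back to the unquantized layer. The instantiation below is a hypothetical usage sketch; the keyword arguments come from the constructor shown in the new file further down, while the prefix, expert count, and input tensors are illustrative.

# Hypothetical usage sketch; `weights` is assumed to be backed by a
# HybridFP8UnquantLoader with to_fp8=True, so `cls` resolves to FP8SparseMoELayer.
moe = SparseMoELayer(
    n_expert_group=None,
    n_experts=8,
    prefix="model.layers.0.mlp.experts",
    renormalize=True,
    topk=2,
    topk_group=None,
    weights=weights,
)
# x: (num_tokens, hidden_size), router_logits: (num_tokens, n_experts)
out = moe(x, gating_output=router_logits)
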
server/text_generation_server/layers/moe/fp8.py (new file, 162 lines)
@@ -0,0 +1,162 @@
from typing import Optional

import torch
import torch.nn as nn

from text_generation_server.utils.weights import Weights
from text_generation_server.layers.fp8 import (
    Fp8Weight,
    fp8_quantize,
    quant_dtype,
    normalize_e4m3fn_to_native_float8,
)
from moe_kernels.fused_moe import fused_moe


class FP8SparseMoELayer(nn.Module):
    def __init__(
        self,
        *,
        n_expert_group: Optional[int],
        n_experts: int,
        prefix: str,
        renormalize: bool,
        topk: int,
        topk_group: Optional[int],
        weights: Weights,
        gate_proj_name: str = "gate_proj",
        up_proj_name: str = "up_proj",
        down_proj_name: str = "down_proj",
    ):
        super().__init__()

        assert (n_expert_group is None) == (
            topk_group is None
        ), "n_expert_group and topk_group must both be None or have some value"

        self.n_expert_group = n_expert_group
        self.topk = topk
        self.topk_group = topk_group
        self.renormalize = renormalize

        (
            self.gate_up_proj,
            self.gate_up_proj_weight_scale,
            self.gate_up_proj_input_scale,
        ) = _load_expert_multi_weights_col(
            prefix=prefix,
            n_experts=n_experts,
            gate_proj_name=gate_proj_name,
            up_proj_name=up_proj_name,
            weights=weights,
        )

        self.down_proj, self.down_proj_weight_scale, self.down_proj_input_scale = (
            _load_expert_weights_row(
                prefix=prefix,
                n_experts=n_experts,
                name=down_proj_name,
                weights=weights,
            )
        )

    def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
        return fused_moe(
            x,
            w1=self.gate_up_proj,
            w2=self.down_proj,
            gating_output=gating_output,
            topk=self.topk,
            renormalize=self.renormalize,
            inplace=True,
            use_grouped_topk=self.n_expert_group is not None,
            num_expert_group=self.n_expert_group,
            topk_group=self.topk_group,
            use_fp8_w8a8=True,
            w1_scale=self.gate_up_proj_weight_scale,
            w2_scale=self.down_proj_weight_scale,
            a1_scale=self.gate_up_proj_input_scale,
            a2_scale=self.down_proj_input_scale,
        )


def _load_expert_weights(
    get_weight_fn,
    *,
    prefix: str,
    n_experts: int,
    name: str,
    weights: Weights,
) -> torch.Tensor:
    all_weight = None
    all_weight_scales = None
    max_input_scale = None

    for i in range(n_experts):
        weight = get_weight_fn(prefix, i, name, weights)

        assert isinstance(weight, Fp8Weight)

        if all_weight is None:
            all_weight = torch.empty(
                (n_experts,) + weight.weight.shape,
                dtype=quant_dtype,
                device=weight.weight.device,
            )
        if all_weight_scales is None:
            all_weight_scales = torch.empty(
                (n_experts,),
                dtype=torch.float32,
                device=weight.weight.device,
            )

        if weight.weight.dtype in {torch.float8_e4m3fn, torch.float8_e4m3fnuz}:
            all_weight[i], all_weight_scales[i], current_input_scale = (
                normalize_e4m3fn_to_native_float8(
                    weight.weight, weight.weight_scale, weight.input_scale
                )
            )
            if current_input_scale is not None:
                if max_input_scale is None or current_input_scale > max_input_scale:
                    max_input_scale = current_input_scale
        else:
            all_weight[i], all_weight_scales[i] = fp8_quantize(
                weight.weight, scalar=True
            )

    assert all_weight is not None

    return all_weight, all_weight_scales, max_input_scale


def _load_expert_multi_weights_col(
    *,
    prefix: str,
    n_experts: int,
    gate_proj_name: str,
    up_proj_name: str,
    weights: Weights,
) -> torch.Tensor:
    def get_weight_fn(prefix, i, name, weights):
        return weights.get_multi_weights_col(
            [f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0
        )

    return _load_expert_weights(
        get_weight_fn, prefix=prefix, n_experts=n_experts, name=None, weights=weights
    )


def _load_expert_weights_row(
    *,
    prefix: str,
    n_experts: int,
    name: str,
    weights: Weights,
) -> torch.Tensor:
    def get_weight_fn(prefix, i, name, weights):
        return weights.get_weights_row(f"{prefix}.{i}.{name}")

    return _load_expert_weights(
        get_weight_fn, prefix=prefix, n_experts=n_experts, name=name, weights=weights
    )
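
The loaders above stack each expert's tensors into one batched FP8 tensor per projection, with one float32 scale per expert and, when the checkpoint provides activation scales, a single max input scale shared across experts. A rough shape sketch follows, assuming a Mixtral-style checkpoint without tensor parallelism; the concrete sizes are illustrative, not taken from the diff.

# Illustrative shapes (hidden_size=4096, intermediate_size=14336, 8 experts):
# gate_up_proj               -> (8, 2 * 14336, 4096)   fp8 (e4m3fn or e4m3fnuz)
# gate_up_proj_weight_scale  -> (8,)                    float32, one scale per expert
# gate_up_proj_input_scale   -> scalar max over experts, or None
# down_proj                  -> (8, 4096, 14336)        fp8
# down_proj_weight_scale     -> (8,)                    float32
# These are the w1/w2, w1_scale/w2_scale, and a1_scale/a2_scale arguments that
# forward() passes to moe_kernels.fused_moe with use_fp8_w8a8=True.
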
@@ -58,17 +58,7 @@ class UnquantizedSparseMoELayer(nn.Module):
         )

     def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
-        if SYSTEM == "rocm":
-            return fused_moe(
-                x,
-                self.gate_up_proj,
-                self.down_proj,
-                gating_output,
-                self.topk,
-                renormalize=self.renormalize,
-                inplace=True,
-            )
-        elif SYSTEM == "ipex":
+        if SYSTEM == "ipex":
             return self.ipex_fused_moe(
                 hidden_states=x,
                 router_logits=gating_output,
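
The dedicated ROCm branch can be dropped because the shared default path, which the hunk cuts off after the ipex branch, calls the same moe_kernels fused_moe on both CUDA and ROCm. The continuation below is an assumption inferred from the attributes used elsewhere in this diff; it is not shown in the hunk itself.

        # Assumed continuation (not part of the hunk): the shared default path.
        return fused_moe(
            x,
            w1=self.gate_up_proj,
            w2=self.down_proj,
            gating_output=gating_output,
            topk=self.topk,
            renormalize=self.renormalize,
            inplace=True,
            use_grouped_topk=self.n_expert_group is not None,
            num_expert_group=self.n_expert_group,
            topk_group=self.topk_group,
        )
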