diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py
index 4e83ec9d..c4df0213 100644
--- a/server/text_generation_server/layers/fp8.py
+++ b/server/text_generation_server/layers/fp8.py
@@ -63,7 +63,7 @@ def normalize_e4m3fn_to_e4m3fnuz(
     weight_scale: torch.Tensor,
     input_scale: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
-    if weight.dtype == torch.float8_e4m3fn:
+    if weight.dtype == torch.float8_e4m3fn and SYSTEM == "rocm":
         # The bits pattern 10000000(-128) represents zero in e4m3fn
         # but NaN in e4m3fnuz. So here we set it to 0.
         # https://onnx.ai/onnx/technical/float8.html
@@ -381,7 +381,7 @@ class Fp8Linear(torch.nn.Module):
             log_once(logger.info, "Using cutlass w8a8 kernels")
         if SYSTEM == "rocm" and qweight.dtype == torch.float8_e4m3fn:
             qweight, scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
-                weight=qweight, weight_scale=scale
+                weight=qweight, weight_scale=scale, input_scale=input_scale
             )
 
         self.dtype = dtype
diff --git a/server/text_generation_server/layers/moe/__init__.py b/server/text_generation_server/layers/moe/__init__.py
index be40d78a..ae9ca6fc 100644
--- a/server/text_generation_server/layers/moe/__init__.py
+++ b/server/text_generation_server/layers/moe/__init__.py
@@ -16,6 +16,7 @@ from text_generation_server.layers.moe.gptq_marlin import (
     can_use_marlin_moe_gemm,
 )
 from text_generation_server.layers.moe.unquantized import UnquantizedSparseMoELayer
+from text_generation_server.layers.moe.fp8 import FP8SparseMoELayer
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.log import log_once
 from text_generation_server.utils.weights import (
@@ -218,12 +219,17 @@ class SparseMoELayer(nn.Module):
         down_proj_name: str = "down_proj",
     ):
         super().__init__()
-
         if (
             isinstance(weights.loader, DefaultWeightsLoader)
             and isinstance(weights.loader.weight_class, UnquantizedWeight)
         ) or isinstance(weights.loader, HybridFP8UnquantLoader):
-            cls = UnquantizedSparseMoELayer
+            if (
+                isinstance(weights.loader, HybridFP8UnquantLoader)
+                and weights.loader.to_fp8
+            ):
+                cls = FP8SparseMoELayer
+            else:
+                cls = UnquantizedSparseMoELayer
         elif isinstance(
             weights.loader, GPTQMarlinWeightsLoader
         ) and can_use_marlin_moe_gemm(
diff --git a/server/text_generation_server/layers/moe/fp8.py b/server/text_generation_server/layers/moe/fp8.py
new file mode 100644
index 00000000..4d516fd4
--- /dev/null
+++ b/server/text_generation_server/layers/moe/fp8.py
@@ -0,0 +1,162 @@
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+from text_generation_server.utils.weights import Weights
+from text_generation_server.layers.fp8 import (
+    Fp8Weight,
+    fp8_quantize,
+    quant_dtype,
+    normalize_e4m3fn_to_e4m3fnuz,
+)
+from moe_kernels.fused_moe import fused_moe
+
+
+class FP8SparseMoELayer(nn.Module):
+    def __init__(
+        self,
+        *,
+        n_expert_group: Optional[int],
+        n_experts: int,
+        prefix: str,
+        renormalize: bool,
+        topk: int,
+        topk_group: Optional[int],
+        weights: Weights,
+        gate_proj_name: str = "gate_proj",
+        up_proj_name: str = "up_proj",
+        down_proj_name: str = "down_proj",
+    ):
+        super().__init__()
+
+        assert (n_expert_group is None) == (
+            topk_group is None
+        ), "n_expert_group and topk_group must both be None or have some value"
+
+        self.n_expert_group = n_expert_group
+        self.topk = topk
+        self.topk_group = topk_group
+        self.renormalize = renormalize
+
+        (
+            self.gate_up_proj,
+            self.gate_up_proj_weight_scale,
+            self.gate_up_proj_input_scale,
+        ) = _load_expert_multi_weights_col(
+            prefix=prefix,
+            n_experts=n_experts,
+            gate_proj_name=gate_proj_name,
+            up_proj_name=up_proj_name,
+            weights=weights,
+        )
+
+        self.down_proj, self.down_proj_weight_scale, self.down_proj_input_scale = (
+            _load_expert_weights_row(
+                prefix=prefix,
+                n_experts=n_experts,
+                name=down_proj_name,
+                weights=weights,
+            )
+        )
+
+    def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
+        return fused_moe(
+            x,
+            w1=self.gate_up_proj,
+            w2=self.down_proj,
+            gating_output=gating_output,
+            topk=self.topk,
+            renormalize=self.renormalize,
+            inplace=True,
+            use_grouped_topk=self.n_expert_group is not None,
+            num_expert_group=self.n_expert_group,
+            topk_group=self.topk_group,
+            use_fp8_w8a8=True,
+            w1_scale=self.gate_up_proj_weight_scale,
+            w2_scale=self.down_proj_weight_scale,
+            a1_scale=self.gate_up_proj_input_scale,
+            a2_scale=self.down_proj_input_scale,
+        )
+
+
+def _load_expert_weights(
+    get_weight_fn,
+    *,
+    prefix: str,
+    n_experts: int,
+    name: str,
+    weights: Weights,
+) -> torch.Tensor:
+    all_weight = None
+    all_weight_scales = None
+    max_input_scale = None
+
+    for i in range(n_experts):
+        weight = get_weight_fn(prefix, i, name, weights)
+
+        assert isinstance(weight, Fp8Weight)
+
+        if all_weight is None:
+            all_weight = torch.empty(
+                (n_experts,) + weight.weight.shape,
+                dtype=quant_dtype,
+                device=weight.weight.device,
+            )
+        if all_weight_scales is None:
+            all_weight_scales = torch.empty(
+                (n_experts,),
+                dtype=torch.float32,
+                device=weight.weight.device,
+            )
+
+        if weight.weight.dtype in {torch.float8_e4m3fn, torch.float8_e4m3fnuz}:
+            all_weight[i], all_weight_scales[i], current_input_scale = (
+                normalize_e4m3fn_to_e4m3fnuz(
+                    weight.weight, weight.weight_scale, weight.input_scale
+                )
+            )
+            if current_input_scale is not None:
+                if max_input_scale is None or current_input_scale > max_input_scale:
+                    max_input_scale = current_input_scale
+        else:
+            all_weight[i], all_weight_scales[i] = fp8_quantize(
+                weight.weight, scalar=True
+            )
+
+    assert all_weight is not None
+
+    return all_weight, all_weight_scales, max_input_scale
+
+
+def _load_expert_multi_weights_col(
+    *,
+    prefix: str,
+    n_experts: int,
+    gate_proj_name: str,
+    up_proj_name: str,
+    weights: Weights,
+) -> torch.Tensor:
+    def get_weight_fn(prefix, i, name, weights):
+        return weights.get_multi_weights_col(
+            [f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0
+        )
+
+    return _load_expert_weights(
+        get_weight_fn, prefix=prefix, n_experts=n_experts, name=None, weights=weights
+    )
+
+
+def _load_expert_weights_row(
+    *,
+    prefix: str,
+    n_experts: int,
+    name: str,
+    weights: Weights,
+) -> torch.Tensor:
+    def get_weight_fn(prefix, i, name, weights):
+        return weights.get_weights_row(f"{prefix}.{i}.{name}")
+
+    return _load_expert_weights(
+        get_weight_fn, prefix=prefix, n_experts=n_experts, name=name, weights=weights
+    )
diff --git a/server/text_generation_server/layers/moe/unquantized.py b/server/text_generation_server/layers/moe/unquantized.py
index 3c9bcaba..32326653 100644
--- a/server/text_generation_server/layers/moe/unquantized.py
+++ b/server/text_generation_server/layers/moe/unquantized.py
@@ -58,17 +58,7 @@ class UnquantizedSparseMoELayer(nn.Module):
         )
 
     def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
-        if SYSTEM == "rocm":
-            return fused_moe(
-                x,
-                self.gate_up_proj,
-                self.down_proj,
-                gating_output,
-                self.topk,
-                renormalize=self.renormalize,
-                inplace=True,
-            )
-        elif SYSTEM == "ipex":
+        if SYSTEM == "ipex":
             return self.ipex_fused_moe(
                 hidden_states=x,
                 router_logits=gating_output,