From 7be2a5f3464e39a91f11ecb545d431b36b5afe1a Mon Sep 17 00:00:00 2001
From: Mohit Sharma
Date: Wed, 29 Jan 2025 11:13:22 +0000
Subject: [PATCH] flatten condition

---
 .../layers/compressed_tensors/w8an_fp.py      |  4 ++--
 server/text_generation_server/layers/fp8.py   |  8 +++----
 .../layers/moe/__init__.py                    | 21 +++++++++----------
 .../text_generation_server/layers/moe/fp8.py  |  4 ++--
 4 files changed, 18 insertions(+), 19 deletions(-)

diff --git a/server/text_generation_server/layers/compressed_tensors/w8an_fp.py b/server/text_generation_server/layers/compressed_tensors/w8an_fp.py
index ebcc06d6..42c5e633 100644
--- a/server/text_generation_server/layers/compressed_tensors/w8an_fp.py
+++ b/server/text_generation_server/layers/compressed_tensors/w8an_fp.py
@@ -7,7 +7,7 @@ from text_generation_server.layers.fp8 import (
     Fp8Weight,
     _load_scalar_or_matrix_scale,
     requantize_with_max_scale,
-    normalize_e4m3fn_to_e4m3fnuz,
+    normalize_e4m3fn_to_native_float8,
 )
 from text_generation_server.utils.weights import Weights, WeightsLoader
 from text_generation_server.utils.import_utils import SYSTEM
@@ -148,7 +148,7 @@ class W8ANFpLoader(WeightsLoader):
             )
 
         if self.load_weight_scale and SYSTEM == "rocm":
-            w, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
+            w, weight_scale, input_scale = normalize_e4m3fn_to_native_float8(
                 w, weight_scale, input_scale
             )
 
diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py
index c4df0213..67b33c98 100644
--- a/server/text_generation_server/layers/fp8.py
+++ b/server/text_generation_server/layers/fp8.py
@@ -58,7 +58,7 @@ def get_fp8_linear(force_w8a16: bool = False) -> Type[torch.nn.Module]:
     return Fp8Linear
 
 
-def normalize_e4m3fn_to_e4m3fnuz(
+def normalize_e4m3fn_to_native_float8(
     weight: torch.Tensor,
     weight_scale: torch.Tensor,
     input_scale: Optional[torch.Tensor] = None,
@@ -162,7 +162,7 @@ def fp8_quantize(
     qweight = qweight.to(qdtype)
 
     if SYSTEM == "rocm":
-        qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(qweight, scale)
+        qweight, scale, _ = normalize_e4m3fn_to_native_float8(qweight, scale)
 
     return qweight, scale
 
@@ -285,7 +285,7 @@ class HybridFP8UnquantLoader(WeightsLoader):
                 )
 
             if SYSTEM == "rocm":
-                w, scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
+                w, scale, input_scale = normalize_e4m3fn_to_native_float8(
                     w, scale, input_scale
                 )
 
@@ -380,7 +380,7 @@ class Fp8Linear(torch.nn.Module):
         if CUTLASS_FP8_AVAILABLE:
             log_once(logger.info, "Using cutlass w8a8 kernels")
         if SYSTEM == "rocm" and qweight.dtype == torch.float8_e4m3fn:
-            qweight, scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
+            qweight, scale, input_scale = normalize_e4m3fn_to_native_float8(
                 weight=qweight, weight_scale=scale, input_scale=input_scale
             )
 
diff --git a/server/text_generation_server/layers/moe/__init__.py b/server/text_generation_server/layers/moe/__init__.py
index ae9ca6fc..3b227e96 100644
--- a/server/text_generation_server/layers/moe/__init__.py
+++ b/server/text_generation_server/layers/moe/__init__.py
@@ -219,17 +219,16 @@ class SparseMoELayer(nn.Module):
         down_proj_name: str = "down_proj",
     ):
         super().__init__()
-        if (
-            isinstance(weights.loader, DefaultWeightsLoader)
-            and isinstance(weights.loader.weight_class, UnquantizedWeight)
-        ) or isinstance(weights.loader, HybridFP8UnquantLoader):
-            if (
-                isinstance(weights.loader, HybridFP8UnquantLoader)
-                and weights.loader.to_fp8
-            ):
-                cls = FP8SparseMoELayer
-            else:
-                cls = UnquantizedSparseMoELayer
+        if isinstance(weights.loader, DefaultWeightsLoader) and isinstance(
+            weights.loader.weight_class, UnquantizedWeight
+        ):
+            cls = UnquantizedSparseMoELayer
+        elif isinstance(weights.loader, HybridFP8UnquantLoader):
+            cls = (
+                FP8SparseMoELayer
+                if weights.loader.to_fp8
+                else UnquantizedSparseMoELayer
+            )
         elif isinstance(
             weights.loader, GPTQMarlinWeightsLoader
         ) and can_use_marlin_moe_gemm(
diff --git a/server/text_generation_server/layers/moe/fp8.py b/server/text_generation_server/layers/moe/fp8.py
index 4d516fd4..7ccddb5b 100644
--- a/server/text_generation_server/layers/moe/fp8.py
+++ b/server/text_generation_server/layers/moe/fp8.py
@@ -8,7 +8,7 @@ from text_generation_server.layers.fp8 import (
     Fp8Weight,
     fp8_quantize,
     quant_dtype,
-    normalize_e4m3fn_to_e4m3fnuz,
+    normalize_e4m3fn_to_native_float8,
 )
 
 from moe_kernels.fused_moe import fused_moe
@@ -112,7 +112,7 @@ def _load_expert_weights(
 
         if weight.weight.dtype in {torch.float8_e4m3fn, torch.float8_e4m3fnuz}:
             all_weight[i], all_weight_scales[i], current_input_scale = (
-                normalize_e4m3fn_to_e4m3fnuz(
+                normalize_e4m3fn_to_native_float8(
                     weight.weight, weight.weight_scale, weight.input_scale
                 )
             )
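-- 
For context only (outside the diff above): the renamed helper is only called on
ROCm (SYSTEM == "rocm") and, per its previous name, converts weights quantized
as float8_e4m3fn into the float8_e4m3fnuz layout that MI300-class GPUs support
natively. A minimal sketch of such a conversion is shown below under stated
assumptions; the function name is hypothetical and this is an illustration, not
the repository's implementation.

    import torch
    from typing import Optional, Tuple

    def e4m3fn_to_e4m3fnuz_sketch(
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
        input_scale: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        # e4m3fnuz uses an exponent bias of 8 instead of 7, so the same bit
        # pattern decodes to half the e4m3fn value; doubling the scales keeps
        # the dequantized tensor unchanged.
        assert weight.dtype == torch.float8_e4m3fn
        as_int8 = weight.view(torch.int8)
        # Bit pattern 0b10000000 is -0.0 in e4m3fn but NaN in e4m3fnuz,
        # so remap it to 0 before reinterpreting the bytes.
        as_int8[as_int8 == -128] = 0
        weight = as_int8.view(torch.float8_e4m3fnuz)
        weight_scale = weight_scale * 2.0
        if input_scale is not None:
            input_scale = input_scale * 2.0
        return weight, weight_scale, input_scale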