flatten condition

Mohit Sharma 2025-01-29 11:13:22 +00:00
parent 16162602c2
commit 7be2a5f346
4 changed files with 18 additions and 19 deletions

View File

@@ -7,7 +7,7 @@ from text_generation_server.layers.fp8 import (
     Fp8Weight,
     _load_scalar_or_matrix_scale,
     requantize_with_max_scale,
-    normalize_e4m3fn_to_e4m3fnuz,
+    normalize_e4m3fn_to_native_float8,
 )
 from text_generation_server.utils.weights import Weights, WeightsLoader
 from text_generation_server.utils.import_utils import SYSTEM
@@ -148,7 +148,7 @@ class W8ANFpLoader(WeightsLoader):
             )
 
         if self.load_weight_scale and SYSTEM == "rocm":
-            w, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
+            w, weight_scale, input_scale = normalize_e4m3fn_to_native_float8(
                 w, weight_scale, input_scale
             )

View File

@@ -58,7 +58,7 @@ def get_fp8_linear(force_w8a16: bool = False) -> Type[torch.nn.Module]:
     return Fp8Linear
 
 
-def normalize_e4m3fn_to_e4m3fnuz(
+def normalize_e4m3fn_to_native_float8(
     weight: torch.Tensor,
     weight_scale: torch.Tensor,
     input_scale: Optional[torch.Tensor] = None,
@@ -162,7 +162,7 @@ def fp8_quantize(
     qweight = qweight.to(qdtype)
 
     if SYSTEM == "rocm":
-        qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(qweight, scale)
+        qweight, scale, _ = normalize_e4m3fn_to_native_float8(qweight, scale)
 
     return qweight, scale
@@ -285,7 +285,7 @@ class HybridFP8UnquantLoader(WeightsLoader):
             )
 
         if SYSTEM == "rocm":
-            w, scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
+            w, scale, input_scale = normalize_e4m3fn_to_native_float8(
                 w, scale, input_scale
             )
@@ -380,7 +380,7 @@ class Fp8Linear(torch.nn.Module):
         if CUTLASS_FP8_AVAILABLE:
            log_once(logger.info, "Using cutlass w8a8 kernels")
 
         if SYSTEM == "rocm" and qweight.dtype == torch.float8_e4m3fn:
-            qweight, scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
+            qweight, scale, input_scale = normalize_e4m3fn_to_native_float8(
                 weight=qweight, weight_scale=scale, input_scale=input_scale
             )
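
The renamed helper converts FP8 tensors stored as float8_e4m3fn into the float8 variant that is native to the current hardware; on ROCm that is float8_e4m3fnuz, which is what the old name spelled out. The sketch below only illustrates that kind of conversion under the e4m3fn/e4m3fnuz assumptions and is not this repository's implementation:

# Illustrative sketch only, not the repository's implementation.
# Assumes a torch build that exposes float8_e4m3fn and float8_e4m3fnuz.
from typing import Optional, Tuple

import torch


def normalize_e4m3fn_to_e4m3fnuz_sketch(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    assert weight.dtype == torch.float8_e4m3fn
    # Bit pattern 0x80 is negative zero in e4m3fn but NaN in e4m3fnuz,
    # so clear it before reinterpreting the storage.
    as_int8 = weight.view(torch.int8).clone()
    as_int8[as_int8 == -128] = 0
    weight = as_int8.view(torch.float8_e4m3fnuz)
    # For the same bit pattern, an e4m3fnuz value is half the e4m3fn value,
    # so the scales are doubled to keep dequantized values unchanged.
    weight_scale = weight_scale * 2.0
    if input_scale is not None:
        input_scale = input_scale * 2.0
    return weight, weight_scale, input_scale

Framed this way, the new name says "native float8" instead of naming the ROCm-specific dtype, while the CUDA path keeps float8_e4m3fn untouched.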

View File

@@ -219,17 +219,16 @@ class SparseMoELayer(nn.Module):
         down_proj_name: str = "down_proj",
     ):
         super().__init__()
-        if (
-            isinstance(weights.loader, DefaultWeightsLoader)
-            and isinstance(weights.loader.weight_class, UnquantizedWeight)
-        ) or isinstance(weights.loader, HybridFP8UnquantLoader):
-            if (
-                isinstance(weights.loader, HybridFP8UnquantLoader)
-                and weights.loader.to_fp8
-            ):
-                cls = FP8SparseMoELayer
-            else:
-                cls = UnquantizedSparseMoELayer
+        if isinstance(weights.loader, DefaultWeightsLoader) and isinstance(
+            weights.loader.weight_class, UnquantizedWeight
+        ):
+            cls = UnquantizedSparseMoELayer
+        elif isinstance(weights.loader, HybridFP8UnquantLoader):
+            cls = (
+                FP8SparseMoELayer
+                if weights.loader.to_fp8
+                else UnquantizedSparseMoELayer
+            )
         elif isinstance(
             weights.loader, GPTQMarlinWeightsLoader
         ) and can_use_marlin_moe_gemm(
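
The flattened branches above keep the same class selection as the old nested form. A minimal, self-contained check of that equivalence, using stand-in stub classes rather than the real loaders and MoE layers, could look like this:

# Equivalence check with stub classes; the real loader and layer types from
# the diff are only mimicked here.
class UnquantizedWeight: ...
class SomeOtherWeight: ...

class DefaultWeightsLoader:
    def __init__(self, weight_class):
        self.weight_class = weight_class

class HybridFP8UnquantLoader:
    def __init__(self, to_fp8):
        self.to_fp8 = to_fp8

def pick_old(loader):
    # Nested form from the removed lines.
    if (
        isinstance(loader, DefaultWeightsLoader)
        and isinstance(loader.weight_class, UnquantizedWeight)
    ) or isinstance(loader, HybridFP8UnquantLoader):
        if isinstance(loader, HybridFP8UnquantLoader) and loader.to_fp8:
            return "FP8SparseMoELayer"
        return "UnquantizedSparseMoELayer"
    return None

def pick_new(loader):
    # Flattened form from the added lines.
    if isinstance(loader, DefaultWeightsLoader) and isinstance(
        loader.weight_class, UnquantizedWeight
    ):
        return "UnquantizedSparseMoELayer"
    elif isinstance(loader, HybridFP8UnquantLoader):
        return "FP8SparseMoELayer" if loader.to_fp8 else "UnquantizedSparseMoELayer"
    return None

for loader in (
    DefaultWeightsLoader(UnquantizedWeight()),
    DefaultWeightsLoader(SomeOtherWeight()),
    HybridFP8UnquantLoader(to_fp8=True),
    HybridFP8UnquantLoader(to_fp8=False),
):
    assert pick_old(loader) == pick_new(loader)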

View File

@@ -8,7 +8,7 @@ from text_generation_server.layers.fp8 import (
     Fp8Weight,
     fp8_quantize,
     quant_dtype,
-    normalize_e4m3fn_to_e4m3fnuz,
+    normalize_e4m3fn_to_native_float8,
 )
 
 from moe_kernels.fused_moe import fused_moe
@@ -112,7 +112,7 @@ def _load_expert_weights(
         if weight.weight.dtype in {torch.float8_e4m3fn, torch.float8_e4m3fnuz}:
             all_weight[i], all_weight_scales[i], current_input_scale = (
-                normalize_e4m3fn_to_e4m3fnuz(
+                normalize_e4m3fn_to_native_float8(
                     weight.weight, weight.weight_scale, weight.input_scale
                 )
             )
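
For reference, the hunk above applies the same normalization once per expert while the expert weights are being gathered. A rough sketch of that gathering pattern, with the conversion helper passed in and every name here being illustrative rather than the repository's API:

# Hedged sketch of per-expert gathering; "normalize" stands in for the
# renamed helper and "expert_weights" is a hypothetical input structure.
import torch

def gather_expert_weights_sketch(expert_weights, normalize):
    # expert_weights: iterable of (weight, weight_scale, input_scale) tuples,
    # one entry per expert.
    all_weight, all_weight_scales = [], []
    for weight, weight_scale, input_scale in expert_weights:
        if weight.dtype in {torch.float8_e4m3fn, torch.float8_e4m3fnuz}:
            weight, weight_scale, input_scale = normalize(
                weight, weight_scale, input_scale
            )
        all_weight.append(weight)
        all_weight_scales.append(weight_scale)
    return all_weight, all_weight_scales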