Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 12:24:53 +00:00)

Commit 7be2a5f346: flatten condition
Parent: 16162602c2
@@ -7,7 +7,7 @@ from text_generation_server.layers.fp8 import (
     Fp8Weight,
     _load_scalar_or_matrix_scale,
     requantize_with_max_scale,
-    normalize_e4m3fn_to_e4m3fnuz,
+    normalize_e4m3fn_to_native_float8,
 )
 from text_generation_server.utils.weights import Weights, WeightsLoader
 from text_generation_server.utils.import_utils import SYSTEM
@@ -148,7 +148,7 @@ class W8ANFpLoader(WeightsLoader):
         )
 
         if self.load_weight_scale and SYSTEM == "rocm":
-            w, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
+            w, weight_scale, input_scale = normalize_e4m3fn_to_native_float8(
                 w, weight_scale, input_scale
             )
 
@@ -58,7 +58,7 @@ def get_fp8_linear(force_w8a16: bool = False) -> Type[torch.nn.Module]:
     return Fp8Linear
 
 
-def normalize_e4m3fn_to_e4m3fnuz(
+def normalize_e4m3fn_to_native_float8(
     weight: torch.Tensor,
     weight_scale: torch.Tensor,
     input_scale: Optional[torch.Tensor] = None,
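For context, the renamed helper converts float8_e4m3fn tensors to the float8 format that is native on the target hardware; on ROCm this is float8_e4m3fnuz. A minimal sketch of that standard conversion follows, assuming the usual recipe (clear the bit pattern that e4m3fnuz reads as NaN, then double the scales because the same bits encode half the value); it is an illustration, not the body from this commit.

from typing import Optional, Tuple

import torch


def normalize_fn_to_fnuz_sketch(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    # 0x80 is -0.0 in e4m3fn but NaN in e4m3fnuz, so clear it before
    # reinterpreting the raw bits as e4m3fnuz.
    as_int8 = weight.view(torch.int8)
    as_int8[as_int8 == -128] = 0
    weight = as_int8.view(torch.float8_e4m3fnuz)
    # The same bit pattern represents half the value in e4m3fnuz, so the
    # scales are doubled to keep dequantized values unchanged.
    weight_scale = weight_scale * 2.0
    if input_scale is not None:
        input_scale = input_scale * 2.0
    return weight, weight_scale, input_scale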
@@ -162,7 +162,7 @@ def fp8_quantize(
     qweight = qweight.to(qdtype)
 
     if SYSTEM == "rocm":
-        qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(qweight, scale)
+        qweight, scale, _ = normalize_e4m3fn_to_native_float8(qweight, scale)
 
     return qweight, scale
 
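A hedged usage sketch of this quantization path follows. The (qweight, scale) return pair is taken from the hunk above; calling fp8_quantize with only the weight and relying on defaults for the remaining parameters is an assumption about the full signature.

import torch

from text_generation_server.layers.fp8 import fp8_quantize

# Quantize a bf16 weight to FP8; on ROCm the function itself re-normalizes
# the result to the native float8 dtype, per the hunk above.
w = torch.randn(1024, 1024, dtype=torch.bfloat16)
qweight, scale = fp8_quantize(w)
print(qweight.dtype, scale)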
@@ -285,7 +285,7 @@ class HybridFP8UnquantLoader(WeightsLoader):
             )
 
             if SYSTEM == "rocm":
-                w, scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
+                w, scale, input_scale = normalize_e4m3fn_to_native_float8(
                     w, scale, input_scale
                 )
 
@@ -380,7 +380,7 @@ class Fp8Linear(torch.nn.Module):
         if CUTLASS_FP8_AVAILABLE:
             log_once(logger.info, "Using cutlass w8a8 kernels")
         if SYSTEM == "rocm" and qweight.dtype == torch.float8_e4m3fn:
-            qweight, scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
+            qweight, scale, input_scale = normalize_e4m3fn_to_native_float8(
                 weight=qweight, weight_scale=scale, input_scale=input_scale
             )
 
@@ -219,17 +219,16 @@ class SparseMoELayer(nn.Module):
         down_proj_name: str = "down_proj",
     ):
         super().__init__()
-        if (
-            isinstance(weights.loader, DefaultWeightsLoader)
-            and isinstance(weights.loader.weight_class, UnquantizedWeight)
-        ) or isinstance(weights.loader, HybridFP8UnquantLoader):
-            if (
-                isinstance(weights.loader, HybridFP8UnquantLoader)
-                and weights.loader.to_fp8
-            ):
-                cls = FP8SparseMoELayer
-            else:
-                cls = UnquantizedSparseMoELayer
+        if isinstance(weights.loader, DefaultWeightsLoader) and isinstance(
+            weights.loader.weight_class, UnquantizedWeight
+        ):
+            cls = UnquantizedSparseMoELayer
+        elif isinstance(weights.loader, HybridFP8UnquantLoader):
+            cls = (
+                FP8SparseMoELayer
+                if weights.loader.to_fp8
+                else UnquantizedSparseMoELayer
+            )
         elif isinstance(
             weights.loader, GPTQMarlinWeightsLoader
         ) and can_use_marlin_moe_gemm(
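The hunk above is the change the commit title refers to: the nested HybridFP8UnquantLoader check inside a combined or-condition is flattened into separate if/elif branches, with a conditional expression picking between the FP8 and unquantized MoE layers. A self-contained toy with the same dispatch shape, using stand-in class names rather than the repository's loaders:

class DefaultLoaderStub:
    pass


class HybridFP8LoaderStub:
    def __init__(self, to_fp8: bool):
        self.to_fp8 = to_fp8


def pick_moe_layer(loader) -> str:
    # Flattened form: one branch per loader type, with a conditional
    # expression for the FP8 toggle.
    if isinstance(loader, DefaultLoaderStub):
        return "unquantized"
    elif isinstance(loader, HybridFP8LoaderStub):
        return "fp8" if loader.to_fp8 else "unquantized"
    return "other"


# Same outcomes as the pre-flatten nested form, provided the two loader
# classes are unrelated types.
assert pick_moe_layer(DefaultLoaderStub()) == "unquantized"
assert pick_moe_layer(HybridFP8LoaderStub(to_fp8=True)) == "fp8"
assert pick_moe_layer(HybridFP8LoaderStub(to_fp8=False)) == "unquantized"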
@@ -8,7 +8,7 @@ from text_generation_server.layers.fp8 import (
     Fp8Weight,
     fp8_quantize,
     quant_dtype,
-    normalize_e4m3fn_to_e4m3fnuz,
+    normalize_e4m3fn_to_native_float8,
 )
 from moe_kernels.fused_moe import fused_moe
 
@@ -112,7 +112,7 @@ def _load_expert_weights(
 
         if weight.weight.dtype in {torch.float8_e4m3fn, torch.float8_e4m3fnuz}:
             all_weight[i], all_weight_scales[i], current_input_scale = (
-                normalize_e4m3fn_to_e4m3fnuz(
+                normalize_e4m3fn_to_native_float8(
                     weight.weight, weight.weight_scale, weight.input_scale
                 )
             )
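As an illustration of the shape of this loop, a hedged sketch that converts each expert's FP8 weight with the renamed helper and stacks the results for a fused MoE kernel; the function and variable names below are placeholders, not the repository's _load_expert_weights.

import torch

from text_generation_server.layers.fp8 import normalize_e4m3fn_to_native_float8


def normalize_experts(expert_weights, expert_scales):
    # Convert each expert's weight/scale pair, then stack along a new expert
    # dimension as fused MoE kernels typically expect.
    weights, scales = [], []
    for w, s in zip(expert_weights, expert_scales):
        if w.dtype in {torch.float8_e4m3fn, torch.float8_e4m3fnuz}:
            w, s, _ = normalize_e4m3fn_to_native_float8(w, s)
        weights.append(w)
        scales.append(s)
    return torch.stack(weights), torch.stack(scales)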