From 5789139c68c236ee2e45658d6174af3fdaf798b4 Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Sat, 20 Jul 2024 09:16:42 +0200
Subject: [PATCH] fix auto conversion

---
 server/text_generation_server/layers/fp8.py        | 11 ++++++++---
 .../models/custom_modeling/flash_llama_modeling.py | 14 ++++++--------
 .../text_generation_server/utils/quantization.py   |  2 +-
 3 files changed, 15 insertions(+), 12 deletions(-)

diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py
index 4568f8a0..9ec05bba 100644
--- a/server/text_generation_server/layers/fp8.py
+++ b/server/text_generation_server/layers/fp8.py
@@ -67,8 +67,9 @@ def fp8_quantize(weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn):
 class HybridFP8UnquantLoader(WeightsLoader):
     """Weight loader that loads FP8 and unquantized Torch tensors."""
 
-    def __init__(self, activation_scale_ub: Optional[float]):
+    def __init__(self, activation_scale_ub: Optional[float], to_fp8: bool):
         self.activation_scale_ub = activation_scale_ub
+        self.to_fp8 = to_fp8
 
     def get_weights_col_packed(
         self,
@@ -91,6 +92,8 @@ class HybridFP8UnquantLoader(WeightsLoader):
                 activation_scale_ub=self.activation_scale_ub,
                 dtype=weights.dtype,
             )
+        if self.to_fp8:
+            return Fp8Weight(weight=w, dtype=weights.dtype)
 
         return UnquantizedWeight(w)
 
@@ -111,6 +114,8 @@ class HybridFP8UnquantLoader(WeightsLoader):
                 activation_scale_ub=self.activation_scale_ub,
                 dtype=weights.dtype,
             )
+        if self.to_fp8:
+            return Fp8Weight(weight=w, dtype=weights.dtype)
 
         return UnquantizedWeight(w)
 
@@ -125,6 +130,8 @@ class HybridFP8UnquantLoader(WeightsLoader):
                 activation_scale_ub=self.activation_scale_ub,
                 dtype=weights.dtype,
             )
+        if self.to_fp8:
+            return Fp8Weight(weight=w, dtype=weights.dtype)
 
         return UnquantizedWeight(w)
 
@@ -186,8 +193,6 @@ class Fp8Linear(torch.nn.Module):
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         if FBGEMM_MM_AVAILABLE:
-            log_once(logger.info, "Using FBGEMM fp8 kernels")
-
             qinput, scale = fp8_quantize(
                 input, scale_upper_bound=self.scale_upper_bound
             )
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index df635ff2..f7980d2d 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -33,7 +33,6 @@ from text_generation_server.layers.attention import (
     attention,
     reshape_and_cache,
 )
-from text_generation_server.models.globals import FLASH_DECODING
 from text_generation_server.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
@@ -42,16 +41,15 @@ from text_generation_server.layers import (
     TensorParallelMultiAdapterLinear,
     TensorParallelAdapterRowLinear,
 )
-from text_generation_server.layers.fp8 import Fp8Weight
 from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
 from text_generation_server.utils.weights import (
-    DefaultWeightsLoader,
     UnquantizedWeight,
     Weights,
 )
+from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
 
 if SYSTEM == "rocm":
     try:
@@ -113,12 +111,12 @@ def load_attention(config, prefix: str, weights, layer_id):
 
 @contextmanager
 def no_fp8(weights: Weights):
+    """De-activate fp8 auto conversion for the duration of this context manager"""
     weights_loader = weights.weights_loader
-    if (
-        isinstance(weights_loader, DefaultWeightsLoader)
-        and weights_loader.weight_class is Fp8Weight
-    ):
-        weights_loader = DefaultWeightsLoader(UnquantizedWeight)
+    if isinstance(weights_loader, HybridFP8UnquantLoader) and weights_loader.to_fp8:
+        weights_loader = HybridFP8UnquantLoader(
+            weights_loader.activation_scale_ub, to_fp8=False
+        )
 
     with weights.use_loader(weights_loader):
         yield
diff --git a/server/text_generation_server/utils/quantization.py b/server/text_generation_server/utils/quantization.py
index 8ff6ddf1..a6013361 100644
--- a/server/text_generation_server/utils/quantization.py
+++ b/server/text_generation_server/utils/quantization.py
@@ -172,6 +172,6 @@ def get_loader(
         if isinstance(quantizer_config, _FP8QuantizerConfig):
            activation_scale_ub = quantizer_config.activation_scale_ub
 
-        return HybridFP8UnquantLoader(activation_scale_ub)
+        return HybridFP8UnquantLoader(activation_scale_ub, to_fp8=quantize == "fp8")
     else:
         raise ValueError(f"Unknown quantization method: {quantize}")
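
Note (illustrative sketch, not part of the patch): the new to_fp8 flag lets the loader
distinguish "checkpoint is already FP8" from "unquantized checkpoint that should be
auto-converted because --quantize fp8 was requested". The function below mirrors the
branch order the patched get_weights_* methods follow. It is a simplified standalone
sketch: the function name load_col_sketch and the exact tensor-fetching helpers
(get_sharded / get_tensor and their arguments) are assumptions, while Fp8Weight,
UnquantizedWeight, to_fp8 and activation_scale_ub come from the patch itself.

    import torch

    from text_generation_server.layers.fp8 import Fp8Weight
    from text_generation_server.utils.weights import UnquantizedWeight, Weights


    def load_col_sketch(loader, weights: Weights, prefix: str):
        # Illustrative only: same decision flow as the patched
        # HybridFP8UnquantLoader.get_weights_* methods.
        w = weights.get_sharded(f"{prefix}.weight", dim=0)

        if w.dtype == torch.float8_e4m3fn:
            # Checkpoint already stores FP8 weights: keep them and load the
            # serialized per-tensor scale next to them (helper call assumed).
            scale = weights.get_tensor(f"{prefix}.weight_scale")
            return Fp8Weight(
                weight=w,
                weight_scale=scale,
                activation_scale_ub=loader.activation_scale_ub,
                dtype=weights.dtype,
            )
        if loader.to_fp8:
            # Unquantized checkpoint but FP8 was requested: return an Fp8Weight
            # without a stored scale so the tensor is quantized on the fly later.
            return Fp8Weight(weight=w, dtype=weights.dtype)

        # Neither pre-quantized nor marked for conversion: load as-is.
        return UnquantizedWeight(w)

With this flag in place, the reworked no_fp8 context manager in flash_llama_modeling.py
only needs to rebuild the loader with to_fp8=False (keeping the same activation_scale_ub)
for the few layers that should stay in the model dtype, instead of special-casing
DefaultWeightsLoader and its weight_class as before.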