From c9e8b68426288c6692ec843222aae501ee89b155 Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Sat, 20 Jul 2024 09:56:46 +0200
Subject: [PATCH] missing get_weights implementation

---
 server/text_generation_server/layers/fp8.py | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py
index 9ec05bba..cdf16d6b 100644
--- a/server/text_generation_server/layers/fp8.py
+++ b/server/text_generation_server/layers/fp8.py
@@ -71,6 +71,23 @@ class HybridFP8UnquantLoader(WeightsLoader):
         self.activation_scale_ub = activation_scale_ub
         self.to_fp8 = to_fp8
 
+    def get_weights(self, weights: "Weights", prefix: str):
+        w = weights.get_tensor(f"{prefix}.weight")
+
+        if w.dtype == torch.float8_e4m3fn:
+            # FP8 branch
+            scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+            return Fp8Weight(
+                weight=w,
+                weight_scale=scale,
+                activation_scale_ub=self.activation_scale_ub,
+                dtype=weights.dtype,
+            )
+        if self.to_fp8:
+            return Fp8Weight(weight=w, dtype=weights.dtype)
+
+        return UnquantizedWeight(w)
+
     def get_weights_col_packed(
         self,
         weights: Weights,
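
For anyone who wants to exercise the new dispatch outside a full model load, below is a minimal sketch, not part of the patch. It assumes the HybridFP8UnquantLoader(activation_scale_ub, to_fp8) constructor implied by the context lines above, assumes the import paths of the current repository layout, and uses a hypothetical FakeWeights stand-in that implements only the get_tensor method and dtype attribute the hunk reads.

import torch

# Import paths assumed from the repository layout; UnquantizedWeight lives in
# the shared weights utilities in current TGI, but verify against your checkout.
from text_generation_server.layers.fp8 import Fp8Weight, HybridFP8UnquantLoader
from text_generation_server.utils.weights import UnquantizedWeight


class FakeWeights:
    """Hypothetical in-memory stand-in for Weights; only the members the new
    get_weights touches (get_tensor and dtype) are implemented."""

    dtype = torch.bfloat16  # target dtype, read via weights.dtype in the hunk

    def __init__(self, tensors):
        self.tensors = tensors

    def get_tensor(self, name, to_dtype=True):
        # The real Weights reads safetensors shards; here we just index a dict.
        return self.tensors[name]


# FP8 checkpoint: the stored weight is already float8_e4m3fn (PyTorch >= 2.1),
# so the loader returns an Fp8Weight carrying the stored per-tensor scale.
fp8_weights = FakeWeights(
    {
        "model.layer.weight": torch.zeros(16, 16).to(torch.float8_e4m3fn),
        "model.layer.weight_scale": torch.tensor(0.5),
    }
)
loader = HybridFP8UnquantLoader(activation_scale_ub=1200.0, to_fp8=False)
assert isinstance(loader.get_weights(fp8_weights, "model.layer"), Fp8Weight)

# Unquantized checkpoint with to_fp8=True: the bf16 tensor is still wrapped in
# an Fp8Weight so it can be quantized downstream.
bf16_weights = FakeWeights(
    {"model.layer.weight": torch.zeros(16, 16, dtype=torch.bfloat16)}
)
loader = HybridFP8UnquantLoader(activation_scale_ub=None, to_fp8=True)
assert isinstance(loader.get_weights(bf16_weights, "model.layer"), Fp8Weight)

# Unquantized checkpoint with to_fp8=False: falls through to UnquantizedWeight.
loader = HybridFP8UnquantLoader(activation_scale_ub=None, to_fp8=False)
assert isinstance(loader.get_weights(bf16_weights, "model.layer"), UnquantizedWeight)

The three assertions mirror the three branches of get_weights: an FP8 checkpoint keeps its stored weight_scale, to_fp8=True defers quantization but still yields an Fp8Weight, and everything else falls back to UnquantizedWeight.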