From 32d50c2ea747aa626cd6df0f655e093658468798 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?=
Date: Tue, 24 Sep 2024 13:57:40 +0200
Subject: [PATCH] Add support for scalar FP8 weight scales (#2550)

* Add support for scalar FP8 weight scales

* Support LLM compressor FP8 checkpoints on H100

On H100, we use fbgemm-gpu, which requires bfloat16 as the input dtype.
However, we wouldn't pick up fp8 quantization for models quantized with
LLM compressor. This change adds enough parsing to detect if models
have FP8-quantized weights.

* Remove stray debug print
---
 server/text_generation_server/layers/fp8.py | 49 ++++++++++++++++-----
 1 file changed, 38 insertions(+), 11 deletions(-)

diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py
index 59b08b55..61dd5115 100644
--- a/server/text_generation_server/layers/fp8.py
+++ b/server/text_generation_server/layers/fp8.py
@@ -87,9 +87,11 @@ class HybridFP8UnquantLoader(WeightsLoader):
 
         if w.dtype == torch.float8_e4m3fn:
             # FP8 branch
-            scale = weights.get_tensor(
-                f"{prefix}.weight_scale", to_dtype=False
-            ).reshape(-1)
+            scale = (
+                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+                .reshape(-1)
+                .expand(w.shape[0])
+            )
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
@@ -113,9 +115,16 @@ class HybridFP8UnquantLoader(WeightsLoader):
 
         if w.dtype == torch.float8_e4m3fn:
             # FP8 branch
-            scale = weights.get_packed_sharded(
-                f"{prefix}.weight_scale", dim=0, block_sizes=block_sizes, to_dtype=False
-            ).reshape(-1)
+            scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+            if scale.numel() > 1:
+                scale = weights.get_packed_sharded(
+                    f"{prefix}.weight_scale",
+                    dim=0,
+                    block_sizes=block_sizes,
+                    to_dtype=False,
+                )
+            scale = scale.reshape(-1).expand(w.shape[0])
+
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
@@ -132,16 +141,19 @@ class HybridFP8UnquantLoader(WeightsLoader):
         w = [
             weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes
         ]
+        shapes = [x.shape for x in w]
+
         # Concat then send to the device
         w = torch.cat(w, dim=dim).to(weights.device)
 
         # FP8 branch
         if w.dtype == torch.float8_e4m3fn:
             scale = [
-                weights.get_sharded(f"{p}.weight_scale", dim=0, to_dtype=False)
-                for p in prefixes
+                _load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
+                for p, shape in zip(prefixes, shapes)
             ]
             scale = torch.cat(scale, dim=0).reshape(-1)
+
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
@@ -157,9 +169,11 @@ class HybridFP8UnquantLoader(WeightsLoader):
         w = weights.get_sharded(f"{prefix}.weight", dim=1)
         # FP8 branch
         if w.dtype == torch.float8_e4m3fn:
-            scale = weights.get_tensor(
-                f"{prefix}.weight_scale", to_dtype=False
-            ).reshape(-1)
+            scale = (
+                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+                .reshape(-1)
+                .expand(w.shape[0])
+            )
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
@@ -182,6 +196,9 @@ class Fp8Weight(Weight):
     def get_linear(self, bias: torch.Tensor):
         if self.weight_scale is None:
             return get_fp8_linear().from_unquant(self.weight, bias, self.dtype)
+        # This is not checked by the fbgemm kernels, but they require contiguous
+        # memory. Can be non-contiguous when we e.g. expand from scalars.
+        self.weight_scale = self.weight_scale.contiguous()
         return get_fp8_linear().from_fp8(
             self.weight, self.weight_scale, self.activation_scale_ub, bias, self.dtype
         )
@@ -222,6 +239,9 @@ class Fp8Linear(torch.nn.Module):
 
     @classmethod
     def from_fp8(cls, weight, scale, input_scale, bias, dtype):
+        if FBGEMM_DYN_AVAILABLE:
+            # fbgemm needs float32 scales.
+            scale = scale.float()
         return cls(
             qweight=weight,
             scale=scale,
@@ -256,3 +276,10 @@ class Fp8Linear(torch.nn.Module):
             bias=self.bias,
         )
         return output
+
+
+def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size):
+    scale = weights.get_tensor(prefix, to_dtype=False)
+    if scale.numel() > 1:
+        scale = weights.get_sharded(prefix, dim=0, to_dtype=False)
+    return scale.reshape(-1).expand(shape[0])
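Note (not part of the patch): the following is a minimal standalone sketch of the idea behind
_load_scalar_or_matrix_scale and the contiguous() call added in get_linear. It assumes only
plain PyTorch; the function name expand_weight_scale and the example sizes are illustrative,
not part of the text-generation-server API.

import torch

def expand_weight_scale(scale: torch.Tensor, out_features: int) -> torch.Tensor:
    """Accept either a scalar weight scale (LLM-compressor-style checkpoints) or a
    per-row scale, and return a contiguous float32 scale of shape (out_features,)."""
    scale = scale.reshape(-1)
    if scale.numel() == 1:
        # Scalar checkpoint scale: broadcast to one value per output row.
        scale = scale.expand(out_features)
    # expand() returns a stride-0 view; kernels such as fbgemm expect contiguous
    # float32 scales, so materialize the broadcast before handing it off.
    return scale.contiguous().float()

# Example: a scalar scale stored next to an FP8 weight with 4096 output rows.
out_features = 4096  # stands in for w.shape[0] in the patch
scalar_scale = torch.tensor(0.02)
per_row_scale = expand_weight_scale(scalar_scale, out_features)
assert per_row_scale.shape == (out_features,) and per_row_scale.is_contiguous()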