From e2454dba40e15c66e0493e2a22ab19493b644788 Mon Sep 17 00:00:00 2001
From: Mohit Sharma
Date: Tue, 3 Dec 2024 15:12:18 +0000
Subject: [PATCH] (feat) convert tscales to tensorwise

---
 server/text_generation_server/layers/fp8.py | 56 +++++++++++++++++----
 1 file changed, 46 insertions(+), 10 deletions(-)

diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py
index 1e5c8b3d..03e3660e 100644
--- a/server/text_generation_server/layers/fp8.py
+++ b/server/text_generation_server/layers/fp8.py
@@ -79,6 +79,32 @@ def normalize_e4m3fn_to_e4m3fnuz(
     return weight, weight_scale, input_scale
 
 
+def per_tensor_dequantize(
+    tensor: torch.Tensor, inv_scale: Union[float, torch.Tensor]
+) -> torch.Tensor:
+    fake_qweight = tensor.to(torch.float16)
+    dq_weight = fake_qweight * inv_scale
+    return dq_weight
+
+
+def requantize_with_max_scale(
+    weight: torch.Tensor, weight_scale: torch.Tensor, logical_widths: List[int]
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    # Max scale to be used for requantization.
+    max_w_scale = weight_scale.max()
+
+    start = 0
+    for idx, logical_width in enumerate(logical_widths):
+        end = start + logical_width
+        weight_dq = per_tensor_dequantize(weight[start:end, :], weight_scale[idx])
+        weight[start:end, :], max_w_scale_normalized = fp8_quantize(
+            weight_dq, max_w_scale
+        )
+        start = end
+
+    return weight, max_w_scale_normalized
+
+
 def fp8_quantize(
     weight: torch.Tensor,
     scale: Optional[torch.Tensor] = None,
@@ -145,13 +171,15 @@ class HybridFP8UnquantLoader(WeightsLoader):
                 weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
                 .reshape(-1)
                 .expand(w.shape[0])
-            )
+            ).max()
             input_scale = None
 
             if weights.has_tensor(f"{prefix}.input_scale"):
-                input_scale = weights.get_tensor(
-                    f"{prefix}.input_scale", to_dtype=False
-                ).reshape(-1)
+                input_scale = (
+                    weights.get_tensor(f"{prefix}.input_scale", to_dtype=False)
+                    .reshape(-1)
+                    .max()
+                )
 
             return Fp8Weight(
                 weight=w,
@@ -185,7 +213,7 @@ class HybridFP8UnquantLoader(WeightsLoader):
                 block_sizes=block_sizes,
                 to_dtype=False,
             )
-            scale = scale.reshape(-1).expand(w.shape[0])
+            scale = scale.reshape(-1).expand(w.shape[0]).max()
 
             input_scale = None
             if weights.has_tensor(f"{prefix}.input_scale"):
@@ -227,9 +255,15 @@ class HybridFP8UnquantLoader(WeightsLoader):
         if w.dtype == torch.float8_e4m3fn:
             scale = [
                 _load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
+                .max()
+                .unsqueeze(0)
                 for p, shape in zip(prefixes, shapes)
             ]
-            scale = torch.cat(scale, dim=0).reshape(-1)
+            scale = torch.cat(scale).to(weights.device)
+
+            logical_widths = [x[0] for x in shapes]
+
+            w, scale = requantize_with_max_scale(w, scale, logical_widths)
 
             input_scale = [
                 _load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
@@ -263,12 +297,14 @@ class HybridFP8UnquantLoader(WeightsLoader):
                 weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
                 .reshape(-1)
                 .expand(w.shape[0])
-            )
+            ).max()
             input_scale = None
             if weights.has_tensor(f"{prefix}.input_scale"):
-                input_scale = weights.get_tensor(
-                    f"{prefix}.input_scale", to_dtype=False
-                ).reshape(-1)
+                input_scale = (
+                    weights.get_tensor(f"{prefix}.input_scale", to_dtype=False)
+                    .reshape(-1)
+                    .max()
+                )
 
             return Fp8Weight(
                 weight=w,
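
The sketch below illustrates the idea behind requantize_with_max_scale in this patch: when several projections are packed into one weight tensor, each shard's per-tensor FP8 scale is collapsed into a single tensor-wise scale by dequantizing the shard with its own scale and requantizing it with the maximum scale. This is an illustrative, self-contained sketch with invented shapes and a simplified quantize_fp8 helper; the actual loader uses fp8_quantize and the Weights API from fp8.py.

    import torch

    def quantize_fp8(weight: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        # Quantize a floating-point tensor to float8_e4m3fn with one tensor-wise scale.
        finfo = torch.finfo(torch.float8_e4m3fn)
        q = (weight / scale).clamp(min=finfo.min, max=finfo.max)
        return q.to(torch.float8_e4m3fn)

    # Three projections packed along dim 0 (e.g. a fused q/k/v), each originally
    # quantized with its own per-tensor scale.
    logical_widths = [4, 4, 8]
    scales = torch.tensor([0.01, 0.02, 0.04])
    shards = [torch.randn(w, 16, dtype=torch.float16) for w in logical_widths]
    packed_q = torch.cat([quantize_fp8(s, sc) for s, sc in zip(shards, scales)])

    # Collapse the per-shard scales into one max scale: dequantize each shard with
    # its own scale, then requantize it with the shared max scale.
    max_scale = scales.max()
    requantized = torch.empty_like(packed_q)
    start = 0
    for idx, width in enumerate(logical_widths):
        end = start + width
        shard_fp16 = packed_q[start:end].to(torch.float16) * scales[idx]  # dequantize
        requantized[start:end] = quantize_fp8(shard_fp16, max_scale)      # requantize
        start = end

    # `requantized` now pairs with the single scalar `max_scale`, which is what a
    # tensor-wise FP8 GEMM path expects.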