From 69773767c50a19f6b288fe0ee63ca8f782dd1dd3 Mon Sep 17 00:00:00 2001
From: "Wang, Yi A"
Date: Mon, 24 Mar 2025 20:21:45 -0700
Subject: [PATCH] enable fp8

Signed-off-by: Wang, Yi A
---
 .../server/text_generation_server/cli.py      |   1 +
 .../text_generation_server/layers/fp8.py      | 103 ++++++++----
 2 files changed, 45 insertions(+), 59 deletions(-)

diff --git a/backends/gaudi/server/text_generation_server/cli.py b/backends/gaudi/server/text_generation_server/cli.py
index 569e2e5ba..53837ef71 100644
--- a/backends/gaudi/server/text_generation_server/cli.py
+++ b/backends/gaudi/server/text_generation_server/cli.py
@@ -101,6 +101,7 @@ def serve(
         "bitsandbytes-fp4",
         "gptq",
         "awq",
+        "fp8",
     }:
         raise RuntimeError(
             "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
diff --git a/backends/gaudi/server/text_generation_server/layers/fp8.py b/backends/gaudi/server/text_generation_server/layers/fp8.py
index e37c49839..6c8d637e5 100644
--- a/backends/gaudi/server/text_generation_server/layers/fp8.py
+++ b/backends/gaudi/server/text_generation_server/layers/fp8.py
@@ -2,7 +2,6 @@ from dataclasses import dataclass
 from typing import Optional, Tuple, Type, Union, List
 
 import torch
-from loguru import logger
 
 from text_generation_server.utils.weights import (
     Weight,
@@ -10,18 +9,16 @@ from text_generation_server.utils.weights import (
     UnquantizedWeight,
     Weights,
 )
 
-from text_generation_server.utils.log import log_once
 
-quantization = None
+from vllm_hpu_extension.ops import scaled_fp8_quant
+from vllm_hpu_extension.ops import get_hpu_gaudi2_scale_factor, is_hpu_gaudi2
+import habana_frameworks.torch.utils.experimental as htexp
+
 w8a8_block_fp8_matmul = None
 per_token_group_quant_fp8 = None
-
 quant_dtype: torch.dtype = torch.float8_e4m3fn
 
-CUTLASS_FP8_AVAILABLE = False
-
-
 def get_fp8_linear(force_w8a16: bool = False) -> Type[torch.nn.Module]:
     """
     Return an FP8 linear `Module` that is compatible with the current system.
@@ -43,7 +40,13 @@ def per_tensor_dequantize(
     inv_scale: Union[float, torch.Tensor],
     dtype: torch.dtype = torch.float16,
 ) -> torch.Tensor:
-    fake_qweight = tensor.to(dtype)
+    device = tensor.device
+    dtype = torch.bfloat16
+    if htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi2:
+        # dequant on cpu to avoid nan on gaudi2
+        tensor = tensor.to("cpu")
+
+    fake_qweight = tensor.to(dtype).to(device)
     dq_weight = fake_qweight * inv_scale
     return dq_weight
 
@@ -55,7 +58,10 @@ def requantize_with_max_scale(
     dtype: torch.dtype,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     # Max scale to be used for requantization.
-    max_w_scale = weight_scale.max().float()
+    max_w_scale = weight_scale.max()
+
+    if is_hpu_gaudi2():
+        max_w_scale = max_w_scale * get_hpu_gaudi2_scale_factor()
 
     start = 0
     for idx, logical_width in enumerate(logical_widths):
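Note: requantize_with_max_scale collapses the per-shard weight scales of a fused tensor into one shared max scale, and on Gaudi2 multiplies that max by get_hpu_gaudi2_scale_factor(), apparently to compensate for Gaudi2's narrower native FP8-E4M3 range. Below is a minimal sketch of the collapse step only, not the imported helper itself; naive_requantize is a hypothetical name, and the scales are treated as dequantization (inverse) scales, matching per_tensor_dequantize above.

import torch

QDTYPE = torch.float8_e4m3fn
FINFO = torch.finfo(QDTYPE)

def naive_requantize(weight, scales, logical_widths, dtype=torch.bfloat16):
    # One shared scale for every logical shard of the fused tensor.
    max_scale = scales.max()
    start = 0
    for scale, width in zip(scales, logical_widths):
        end = start + width
        dq = weight[start:end].to(dtype) * scale  # dequantize with the shard scale
        weight[start:end] = (
            (dq / max_scale).clamp(min=FINFO.min, max=FINFO.max).to(QDTYPE)
        )  # requantize with the shared max scale
        start = end
    return weight, max_scale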
@@ -84,37 +90,16 @@ def fp8_quantize(
     argument, it must also be a reciprocal (so that scales from an FP8
     checkpoint can be used without modification).
     """
-    if quantization is not None:
-        shape = weight.shape
-        qweight, scale = quantization.scaled_fp8_quant(
-            weight.reshape(-1, shape[-1]),
-            scale=scale,
-            scale_ub=scale_upper_bound,
-            # TODO: don't do this when we have to use the Torch kernel.
-            use_per_token_if_dynamic=not scalar,
-        )
-
-        return qweight.reshape(shape), scale
-
-    finfo = torch.finfo(qdtype)
-
-    if scale is None:
-        # Calculate the scale as dtype max divided by absmax
-        scale = finfo.max / weight.abs().max().clamp(min=1e-12, max=scale_upper_bound)
-        # scale and clamp the tensor to bring it to
-        # the representative range of float8 data type
-        # (as default cast is unsaturated)
-        qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max)
-        scale = scale.float().reciprocal()
-    else:
-        # Use reciprocal to avoid more expensive division.
-        qweight = (weight * scale.reciprocal()).clamp(min=finfo.min, max=finfo.max)
-
-    # Return both float8 data and the inverse scale (as float),
-    # as both required as inputs to torch._scaled_mm
-    qweight = qweight.to(qdtype)
-
-    return qweight, scale
+    shape = weight.shape
+    qweight, scale = scaled_fp8_quant(
+        weight.reshape(-1, shape[-1]),
+        scale=scale,
+        scale_ub=scale_upper_bound,
+        # TODO: don't do this when we have to use the Torch kernel.
+        use_per_token_if_dynamic=not scalar,
+    )
+
+    return qweight.reshape(shape), scale
 
 
 class HybridFP8UnquantLoader(WeightsLoader):
@@ -153,6 +138,10 @@ class HybridFP8UnquantLoader(WeightsLoader):
                 .reshape(-1)
                 .max()
             )
+            logical_widths = [w.shape[0]]
+            w, scale = requantize_with_max_scale(
+                w, scale.unsqueeze(0), logical_widths, weights.dtype
+            )
 
             return Fp8Weight(
                 weight=w,
@@ -201,6 +190,10 @@ class HybridFP8UnquantLoader(WeightsLoader):
                     to_dtype=False,
                 )
                 input_scale = input_scale.reshape(-1).max()
+            logical_widths = [w.shape[0]]
+            w, scale = requantize_with_max_scale(
+                w, scale.unsqueeze(0), logical_widths, weights.dtype
+            )
 
             return Fp8Weight(
                 weight=w,
@@ -259,6 +252,11 @@ class HybridFP8UnquantLoader(WeightsLoader):
                 else None
             )
 
+            logical_widths = [x[0] for x in shapes]
+            w, scale = requantize_with_max_scale(
+                w, scale.to(weights.device), logical_widths, weights.dtype
+            )
+
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
@@ -296,7 +294,10 @@ class HybridFP8UnquantLoader(WeightsLoader):
                 .reshape(-1)
                 .max()
             )
-
+            logical_widths = [w.shape[0]]
+            w, scale = requantize_with_max_scale(
+                w, scale.unsqueeze(0), logical_widths, weights.dtype
+            )
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
@@ -353,27 +354,19 @@ class Fp8Linear(torch.nn.Module):
         weight_block_size: Optional[List[int]] = None,
     ) -> None:
         super().__init__()
-        if CUTLASS_FP8_AVAILABLE:
-            log_once(logger.info, "Using cutlass w8a8 kernels")
 
         self.dtype = dtype
         self.qweight = qweight
         self.scale = scale.float()
         self.input_scale = input_scale.float() if input_scale is not None else None
         self.weight_block_size = weight_block_size
-
-        if CUTLASS_FP8_AVAILABLE and scale_upper_bound is not None:
-            self.scale_upper_bound = torch.tensor(
-                scale_upper_bound, dtype=torch.float32, device=qweight.device
-            )
-        else:
-            self.scale_upper_bound = scale_upper_bound
+        self.scale_upper_bound = scale_upper_bound
 
         self.bias = bias if bias is not None else None
 
     @classmethod
     def from_unquant(cls, weight, bias, dtype):
-        qweight, scale = fp8_quantize(weight, scalar=not CUTLASS_FP8_AVAILABLE)
+        qweight, scale = fp8_quantize(weight, scalar=True)
         return cls(
             qweight=qweight,
             scale=scale,
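Note: with the CUTLASS path removed, from_unquant always asks fp8_quantize for a scalar (per-tensor) scale, and forward (below) quantizes activations dynamically. A usage sketch, assuming an HPU device with vllm_hpu_extension installed; the shapes are made up.

import torch
from text_generation_server.layers.fp8 import Fp8Linear, fp8_quantize

weight = torch.randn(4096, 4096, dtype=torch.bfloat16, device="hpu")

# Standalone quantization: returns FP8 data plus an inverse (dequant) scale.
qweight, scale = fp8_quantize(weight, scalar=True)

# Or build the module straight from the unquantized weight.
linear = Fp8Linear.from_unquant(weight, bias=None, dtype=torch.bfloat16)
x = torch.randn(8, 4096, dtype=torch.bfloat16, device="hpu")
y = linear(x)  # input is FP8-quantized on the fly inside forward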
@@ -434,14 +427,6 @@ class Fp8Linear(torch.nn.Module):
             if self.bias is not None:
                 output = output + self.bias
             return output.to(dtype=input.dtype)
-        if CUTLASS_FP8_AVAILABLE:
-            # cutlass FP8 supports per-token scales, so get non-scalar scales.
-            qinput, scale = fp8_quantize(
-                input, scale_upper_bound=self.scale_upper_bound, scalar=False
-            )
-            return quantization.cutlass_scaled_mm(
-                qinput, self.qweight.t(), scale, self.scale, input.dtype, self.bias
-            )
 
         qinput, scale = fp8_quantize(
             input,
@@ -470,4 +455,4 @@ def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size):
     if scale.numel() > 1:
         scale = weights.get_sharded(prefix, dim=0, to_dtype=False)
 
-    return scale.reshape(-1).expand(shape[0])
+    return scale.reshape(-1)
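Note: dropping .expand(shape[0]) means loaded scales keep their natural length, one entry per logical shard (or a single entry for a per-tensor scale). That appears safe because the loaders above immediately feed these scales to requantize_with_max_scale, which only uses their max. A tiny illustration with made-up values:

import torch

scalar_scale = torch.tensor([0.02])          # per-tensor scale from a checkpoint
old = scalar_scale.reshape(-1).expand(4096)  # before: broadcast per output row
new = scalar_scale.reshape(-1)               # after: stays shape (1,)
assert old.max().item() == new.max().item()  # same max feeds requantization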