commit 69773767c5
parent 8d221b7b79
Author: Wang, Yi A <yi.a.wang@intel.com>
Date:   2025-03-24 20:21:45 -07:00

    enable fp8

    Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

2 changed files with 45 additions and 59 deletions

File 1 of 2:

@@ -101,6 +101,7 @@ def serve(
         "bitsandbytes-fp4",
         "gptq",
         "awq",
+        "fp8",
     }:
         raise RuntimeError(
             "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."

File 2 of 2:

@@ -2,7 +2,6 @@ from dataclasses import dataclass
 from typing import Optional, Tuple, Type, Union, List

 import torch
-from loguru import logger

 from text_generation_server.utils.weights import (
     Weight,
@ -10,18 +9,16 @@ from text_generation_server.utils.weights import (
UnquantizedWeight, UnquantizedWeight,
Weights, Weights,
) )
from text_generation_server.utils.log import log_once
quantization = None from vllm_hpu_extension.ops import scaled_fp8_quant
from vllm_hpu_extension.ops import get_hpu_gaudi2_scale_factor, is_hpu_gaudi2
import habana_frameworks.torch.utils.experimental as htexp
w8a8_block_fp8_matmul = None w8a8_block_fp8_matmul = None
per_token_group_quant_fp8 = None per_token_group_quant_fp8 = None
quant_dtype: torch.dtype = torch.float8_e4m3fn quant_dtype: torch.dtype = torch.float8_e4m3fn
CUTLASS_FP8_AVAILABLE = False
def get_fp8_linear(force_w8a16: bool = False) -> Type[torch.nn.Module]: def get_fp8_linear(force_w8a16: bool = False) -> Type[torch.nn.Module]:
""" """
Return an FP8 linear `Module` that is compatible with the current system. Return an FP8 linear `Module` that is compatible with the current system.
@@ -43,7 +40,13 @@ def per_tensor_dequantize(
     inv_scale: Union[float, torch.Tensor],
     dtype: torch.dtype = torch.float16,
 ) -> torch.Tensor:
-    fake_qweight = tensor.to(dtype)
+    device = tensor.device
+    dtype = torch.bfloat16
+    if htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi2:
+        # dequant on cpu to avoid nan on gaudi2
+        tensor = tensor.to("cpu")
+
+    fake_qweight = tensor.to(dtype).to(device)
     dq_weight = fake_qweight * inv_scale
     return dq_weight
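Note: putting the context and added lines together, the dequantization helper after this hunk reads roughly as follows (reconstructed from the hunk; the first parameter name `tensor` is inferred from the body):

def per_tensor_dequantize(
    tensor: torch.Tensor,
    inv_scale: Union[float, torch.Tensor],
    dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    device = tensor.device
    # The HPU path always dequantizes through bfloat16, overriding the dtype argument.
    dtype = torch.bfloat16
    if htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi2:
        # Casting fp8 directly on Gaudi2 can produce NaNs, so take a detour through the CPU.
        tensor = tensor.to("cpu")

    fake_qweight = tensor.to(dtype).to(device)
    dq_weight = fake_qweight * inv_scale
    return dq_weight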
@@ -55,7 +58,10 @@ def requantize_with_max_scale(
     dtype: torch.dtype,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     # Max scale to be used for requanitzation.
-    max_w_scale = weight_scale.max().float()
+    max_w_scale = weight_scale.max()
+
+    if is_hpu_gaudi2():
+        max_w_scale = max_w_scale * get_hpu_gaudi2_scale_factor()

     start = 0
     for idx, logical_width in enumerate(logical_widths):
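Note: the loop body lies outside the hunk. Conceptually, max-scale requantization dequantizes every logical slice of a fused tensor with its own scale and requantizes it with the shared maximum, so the whole tensor ends up behind a single per-tensor scale. A minimal sketch of the idea (illustrative only, not the file's actual loop; it reuses per_tensor_dequantize and the module-level quant_dtype):

def requantize_with_max_scale_sketch(weight, weight_scale, logical_widths):
    # Shared scale: the largest per-slice dequantization scale.
    max_w_scale = weight_scale.max()
    finfo = torch.finfo(quant_dtype)
    start = 0
    for idx, logical_width in enumerate(logical_widths):
        end = start + logical_width
        # Recover the slice in bfloat16 with its original per-slice scale ...
        dq = per_tensor_dequantize(weight[start:end, :], weight_scale[idx])
        # ... then requantize it so every slice shares max_w_scale.
        requant = (dq / max_w_scale).clamp(min=finfo.min, max=finfo.max)
        weight[start:end, :] = requant.to(quant_dtype)
        start = end
    return weight, max_w_scale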
@@ -84,9 +90,8 @@ def fp8_quantize(
     argument, it must also be a reciprocal (so that scales from an FP8 checkpoint can
     be used without modification).
     """
-    if quantization is not None:
-        shape = weight.shape
-        qweight, scale = quantization.scaled_fp8_quant(
-            weight.reshape(-1, shape[-1]),
-            scale=scale,
-            scale_ub=scale_upper_bound,
+    shape = weight.shape
+    qweight, scale = scaled_fp8_quant(
+        weight.reshape(-1, shape[-1]),
+        scale=scale,
+        scale_ub=scale_upper_bound,
@@ -96,26 +101,6 @@
-        return qweight.reshape(shape), scale
-
-    finfo = torch.finfo(qdtype)
-
-    if scale is None:
-        # Calculate the scale as dtype max divided by absmax
-        scale = finfo.max / weight.abs().max().clamp(min=1e-12, max=scale_upper_bound)
-
-        # scale and clamp the tensor to bring it to
-        # the representative range of float8 data type
-        # (as default cast is unsaturated)
-        qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max)
-        scale = scale.float().reciprocal()
-    else:
-        # Use reciprocal to avoid more expensive division.
-        qweight = (weight * scale.reciprocal()).clamp(min=finfo.min, max=finfo.max)
-
-    # Return both float8 data and the inverse scale (as float),
-    # as both required as inputs to torch._scaled_mm
-    qweight = qweight.to(qdtype)
-
-    return qweight, scale
+    return qweight.reshape(shape), scale


 class HybridFP8UnquantLoader(WeightsLoader):
     """Weight loader that loads FP8 and unquantized Torch tensors."""
@@ -153,6 +138,10 @@ class HybridFP8UnquantLoader(WeightsLoader):
                     .reshape(-1)
                     .max()
                 )
+            logical_widths = [w.shape[0]]
+            w, scale = requantize_with_max_scale(
+                w, scale.unsqueeze(0), logical_widths, weights.dtype
+            )

             return Fp8Weight(
                 weight=w,
@@ -201,6 +190,10 @@ class HybridFP8UnquantLoader(WeightsLoader):
                         to_dtype=False,
                     )
                 input_scale = input_scale.reshape(-1).max()
+            logical_widths = [w.shape[0]]
+            w, scale = requantize_with_max_scale(
+                w, scale.unsqueeze(0), logical_widths, weights.dtype
+            )

             return Fp8Weight(
                 weight=w,
@@ -259,6 +252,11 @@ class HybridFP8UnquantLoader(WeightsLoader):
                 else None
             )

+            logical_widths = [x[0] for x in shapes]
+            w, scale = requantize_with_max_scale(
+                w, scale.to(weights.device), logical_widths, weights.dtype
+            )
+
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
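Note: this is the fused multi-weight path, so logical_widths records how many output rows each original tensor (for example the q, k and v projections) contributes, and requantize_with_max_scale collapses their individual scales into one shared scale for the concatenated tensor. A toy illustration of the bookkeeping, with made-up shapes and scales:

import torch

# Hypothetical fused q/k/v projections, each with its own per-tensor dequant scale.
shapes = [(4096, 4096), (1024, 4096), (1024, 4096)]
logical_widths = [x[0] for x in shapes]        # -> [4096, 1024, 1024]
scale = torch.tensor([0.021, 0.013, 0.017])    # one dequant scale per projection
max_w_scale = scale.max()                      # single scale used after requantization
print(logical_widths, max_w_scale.item())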
@@ -296,7 +294,10 @@ class HybridFP8UnquantLoader(WeightsLoader):
                     .reshape(-1)
                     .max()
                 )
+            logical_widths = [w.shape[0]]
+            w, scale = requantize_with_max_scale(
+                w, scale.unsqueeze(0), logical_widths, weights.dtype
+            )

             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
@@ -353,27 +354,19 @@ class Fp8Linear(torch.nn.Module):
         weight_block_size: Optional[List[int]] = None,
     ) -> None:
         super().__init__()
-        if CUTLASS_FP8_AVAILABLE:
-            log_once(logger.info, "Using cutlass w8a8 kernels")
         self.dtype = dtype
         self.qweight = qweight
         self.scale = scale.float()
         self.input_scale = input_scale.float() if input_scale is not None else None
         self.weight_block_size = weight_block_size
-        if CUTLASS_FP8_AVAILABLE and scale_upper_bound is not None:
-            self.scale_upper_bound = torch.tensor(
-                scale_upper_bound, dtype=torch.float32, device=qweight.device
-            )
-        else:
-            self.scale_upper_bound = scale_upper_bound
+        self.scale_upper_bound = scale_upper_bound

         self.bias = bias if bias is not None else None

     @classmethod
     def from_unquant(cls, weight, bias, dtype):
-        qweight, scale = fp8_quantize(weight, scalar=not CUTLASS_FP8_AVAILABLE)
+        qweight, scale = fp8_quantize(weight, scalar=True)
         return cls(
             qweight=qweight,
             scale=scale,
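Note: with the cutlass path gone, from_unquant always asks fp8_quantize for a scalar per-tensor scale. A hypothetical usage sketch (the weight shape and dtype below are made up):

import torch

# Quantize a bf16 checkpoint weight on the fly into an FP8 linear layer.
weight = torch.randn(4096, 11008, dtype=torch.bfloat16)
layer = Fp8Linear.from_unquant(weight, bias=None, dtype=torch.bfloat16)
# layer.qweight is float8_e4m3fn; layer.scale is a single per-tensor dequant scale.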
@@ -434,14 +427,6 @@ class Fp8Linear(torch.nn.Module):
             if self.bias is not None:
                 output = output + self.bias

             return output.to(dtype=input.dtype)

-        if CUTLASS_FP8_AVAILABLE:
-            # cutlass FP8 supports per-token scales, so get non-scalar scales.
-            qinput, scale = fp8_quantize(
-                input, scale_upper_bound=self.scale_upper_bound, scalar=False
-            )
-            return quantization.cutlass_scaled_mm(
-                qinput, self.qweight.t(), scale, self.scale, input.dtype, self.bias
-            )
-
         qinput, scale = fp8_quantize(
             input,
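Note: with the cutlass branch removed, the remaining path quantizes the activations per tensor and runs an fp8 x fp8 matmul with the two dequantization scales. The per-tensor-scale semantics are equivalent to this plain-PyTorch reference (illustrative only, not the kernel the file actually calls):

import torch

def scaled_mm_reference(qinput, input_scale, qweight, weight_scale, bias=None):
    # Dequantize both operands with their per-tensor scales, then do a normal matmul.
    a = qinput.to(torch.bfloat16) * input_scale
    b = qweight.to(torch.bfloat16) * weight_scale
    out = a @ b.t()
    if bias is not None:
        out = out + bias
    return out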
@@ -470,4 +455,4 @@ def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size):
     if scale.numel() > 1:
         scale = weights.get_sharded(prefix, dim=0, to_dtype=False)
-    return scale.reshape(-1).expand(shape[0])
+    return scale.reshape(-1)