text-generation-inference/server/text_generation_server/layers/fp8.py
Daniël de Kok 571ac9b507
Use kernels from the kernel hub (#2988)
* Use Hub kernels for Marlin and cutlass quantization kernels

* Use hub kernels for MoE/GPTQ-Marlin MoE

* Use attention kernels from the Hub

* Cache the kernels in the Docker image

* Update moe kernels

* Support loading local kernels for development

* Support latest moe kernels

* Update to moe 0.1.1

* CI: download locked kernels for server tests

* Fixup some imports

* CI: activate venv

* Fix unused imports

* Nix: add attention/moe/quantization kernels

* Update hf-kernels to 0.1.5

* Update kernels

* Update tgi-nix flake for hf-kernels

* Fix EOF

* Take `load_kernel` out of a frequently-called function

* Hoist another case of kernel loading out of a somewhat hot function

* marlin-kernels -> quantization

* attention -> paged-attention

* EOF fix

* Update hf-kernels, fixup Docker

* ipex fix

* Remove outdated TODO
2025-02-10 19:19:25 +01:00

590 lines
21 KiB
Python

from dataclasses import dataclass
import os
from typing import Optional, Tuple, Type, Union, List
import torch
from loguru import logger
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.kernels import load_kernel
from text_generation_server.utils.weights import (
Weight,
WeightsLoader,
UnquantizedWeight,
Weights,
)
from text_generation_server.utils.log import log_once
# Load the quantization kernels through `load_kernel` on CUDA; on every other
# system no kernel module is loaded and `quantization` is None.
if SYSTEM == "cuda":
    quantization = load_kernel(
        module="quantization", repo_id="kernels-community/quantization"
    )
else:
    quantization = None

# Optional dependency: block-wise FP8 helpers from `moe_kernels`. When the
# import fails, both names are bound to None instead.
try:
    from moe_kernels.fp8_utils import w8a8_block_fp8_matmul, per_token_group_quant_fp8
except ImportError:
    w8a8_block_fp8_matmul = None
    per_token_group_quant_fp8 = None

# FP8 dtype selected per system: e4m3fnuz on ROCm, e4m3fn everywhere else.
quant_dtype: torch.dtype = (
    torch.float8_e4m3fnuz if SYSTEM == "rocm" else torch.float8_e4m3fn
)

# Ask the kernel module whether its cutlass scaled-mm supports FP8 for this
# device; the capability is passed as the single integer `major * 10 + minor`.
# NOTE: on CUDA, `major` and `minor` remain bound as module-level names after
# this block runs.
if SYSTEM == "cuda" and quantization is not None:
    major, minor = torch.cuda.get_device_capability()
    CUTLASS_FP8_AVAILABLE = quantization.cutlass_scaled_mm_supports_fp8(
        major * 10 + minor
    )
else:
    CUTLASS_FP8_AVAILABLE = False
def get_fp8_linear(force_w8a16: bool = False) -> Type[torch.nn.Module]:
    """
    Return an FP8 linear `Module` that is compatible with the current system.

    On CUDA, the Marlin FP8 (W8A16) linear is returned for capability 8.x, and
    for capability 9.x when `force_w8a16` is set, unless the environment
    variable `USE_CUTLASS_W8A8` is set to "1". In every other case `Fp8Linear`
    is returned.
    """
    if SYSTEM == "cuda":
        # Unpack both components locally so the 8.9 check below does not rely
        # on a `minor` name that happens to be defined outside this function.
        major, minor = torch.cuda.get_device_capability()
        # Marlin is W8A16, use it when:
        #
        # - On capability 8.x where x < 8: W8A8 FP8 GEMM is not supported.
        # - On capability 8.9: W8A8 FP8 GEMM is supported, but Marlin-FP8 is faster.
        # - On capability 9.x when force_w8a16: cutlass kernels do not support W8A16.
        if (major == 8 or (major == 9 and force_w8a16)) and os.getenv(
            "USE_CUTLASS_W8A8", "0"
        ) != "1":
            # NOTE: Capability 8.9 is supported by cutlass kernels, but FP8-Marlin
            #       gives better decoding throughput on L4 and L40.
            # Imported lazily, only when the Marlin path is actually selected.
            from text_generation_server.layers.marlin import GPTQMarlinFP8Linear

            if major == 8 and minor == 9:
                log_once(
                    logger.info,
                    "GPU supports FP8, but using Marlin FP8 kernel for better performance",
                )
            else:
                log_once(
                    logger.info, "GPU does not support FP8, using Marlin FP8 kernel"
                )

            return GPTQMarlinFP8Linear

    # On other systems let Torch decide if the hardware supports FP8.
    return Fp8Linear
def normalize_e4m3fn_to_native_float8(
weight: torch.Tensor,
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
if weight.dtype == torch.float8_e4m3fn and SYSTEM == "rocm":
# The bits pattern 10000000(-128) represents zero in e4m3fn
# but NaN in e4m3fnuz. So here we set it to 0.
# https://onnx.ai/onnx/technical/float8.html
weight_as_int8 = weight.view(torch.int8)
ROCM_FP8_NAN_AS_INT = -128
weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0
weight = weight_as_int8.view(torch.float8_e4m3fnuz)
# For the same bits representation, e4m3fnuz value is half of
# the e4m3fn value, so we should double the scaling factor to
# get the same dequantized value.
# https://onnx.ai/onnx/technical/float8.html
weight_scale = weight_scale * 2.0
if input_scale is not None:
input_scale = input_scale * 2.0
return weight, weight_scale, input_scale
def per_tensor_dequantize(
    tensor: torch.Tensor,
    inv_scale: Union[float, torch.Tensor],
    dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    """Cast `tensor` to `dtype` and multiply it by `inv_scale`."""
    return tensor.to(dtype) * inv_scale
def requantize_with_max_scale(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    logical_widths: List[int],
    dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Requantize row-stacked sub-matrices of `weight` with one shared scale.

    `weight` is treated as consecutive row slices whose heights are given by
    `logical_widths`; slice `idx` is dequantized with `weight_scale[idx]` and
    quantized again with the maximum of `weight_scale`. The slices are written
    back into `weight` in place.

    Returns `weight` together with the scale returned by the last
    `fp8_quantize` call, or the maximum scale itself when `logical_widths` is
    empty.
    """
    # Max scale to be used for requantization.
    max_w_scale = weight_scale.max().float()

    # Bound up front so the return below is well-defined even when
    # `logical_widths` is empty and the loop body never runs.
    max_w_scale_normalized = max_w_scale

    start = 0
    for idx, logical_width in enumerate(logical_widths):
        end = start + logical_width
        weight_dq = per_tensor_dequantize(
            weight[start:end, :], weight_scale[idx], dtype
        )
        weight[start:end, :], max_w_scale_normalized = fp8_quantize(
            weight_dq, max_w_scale
        )
        start = end

    return weight, max_w_scale_normalized
def fp8_quantize(
    weight: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    scale_upper_bound: Optional[torch.Tensor] = None,
    qdtype: torch.dtype = torch.float8_e4m3fn,
    scalar: bool = False,
):
    """
    Quantize `weight` to FP8 and return `(quantized, scale)`.

    This function returns a reciprocal of the scale, so that a tensor can be unscaled
    by multiplying it with the returned scale. If a scale is given through the `scale`
    argument, it must also be a reciprocal (so that scales from an FP8 checkpoint can
    be used without modification).

    When the `quantization` kernel module is loaded, the work is delegated to
    its `scaled_fp8_quant`; otherwise a pure-Torch fallback is used.
    """
    if quantization is not None:
        # The kernel is handed a 2D view; the original shape is restored on
        # the way out.
        original_shape = weight.shape
        flattened = weight.reshape(-1, original_shape[-1])
        quantized, scale = quantization.scaled_fp8_quant(
            flattened,
            scale=scale,
            scale_ub=scale_upper_bound,
            # TODO: don't do this when we have to use the Torch kernel.
            use_per_token_if_dynamic=not scalar,
        )
        return quantized.reshape(original_shape), scale

    on_rocm = SYSTEM == "rocm"
    limits = torch.finfo(qdtype)

    if scale is not None:
        if on_rocm:
            scale = scale / 2.0
        # Use reciprocal to avoid more expensive division.
        scaled = weight * scale.reciprocal()
    else:
        # Calculate the multiplier as dtype max divided by absmax
        absmax = weight.abs().max().clamp(min=1e-12, max=scale_upper_bound)
        multiplier = limits.max / absmax
        scaled = weight * multiplier
        # Callers receive the inverse of the multiplier, as float.
        scale = multiplier.float().reciprocal()

    # Clamp the tensor to bring it to the representative range of the float8
    # data type (as default cast is unsaturated), then cast. Both the float8
    # data and the inverse scale are required as inputs to torch._scaled_mm.
    quantized = scaled.clamp(min=limits.min, max=limits.max).to(qdtype)

    if on_rocm:
        quantized, scale, _ = normalize_e4m3fn_to_native_float8(quantized, scale)

    return quantized, scale
class HybridFP8UnquantLoader(WeightsLoader):
    """Weight loader that loads FP8 and unquantized Torch tensors."""

    def __init__(
        self,
        activation_scale_ub: Optional[float],
        to_fp8: bool,
        weight_block_size: Optional[List[int]] = None,
    ):
        """
        Store the loader configuration.

        `activation_scale_ub` is forwarded to every `Fp8Weight` built from an
        FP8 checkpoint tensor. `to_fp8` makes non-FP8 tensors get wrapped in
        `Fp8Weight` instead of `UnquantizedWeight`. A non-None
        `weight_block_size` selects the block-wise branch, which reads scales
        from `weight_scale_inv` tensors.
        """
        self.activation_scale_ub = activation_scale_ub
        self.to_fp8 = to_fp8
        self.weight_block_size = weight_block_size

    def get_weights(self, weights: "Weights", prefix: str):
        """Load `{prefix}.weight` with `get_tensor` and wrap it by dtype."""
        w = weights.get_tensor(f"{prefix}.weight")

        if w.dtype == torch.float8_e4m3fn:
            if self.weight_block_size is not None:
                scale = weights.get_tensor(f"{prefix}.weight_scale_inv")
                return Fp8Weight(
                    weight=w,
                    weight_scale=scale,
                    activation_scale_ub=self.activation_scale_ub,
                    dtype=weights.dtype,
                    weight_block_size=self.weight_block_size,
                )
            # FP8 branch
            scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)

            if SYSTEM == "cuda":
                # `expand` returns a new view rather than modifying `scale`,
                # so the result has to be assigned (as the other loaders in
                # this class do).
                scale = scale.reshape(-1).expand(w.shape[0])

            input_scale = None
            if weights.has_tensor(f"{prefix}.input_scale"):
                # Reduce the input scale to a single value: its maximum.
                input_scale = (
                    weights.get_tensor(f"{prefix}.input_scale", to_dtype=False)
                    .reshape(-1)
                    .max()
                )

            return Fp8Weight(
                weight=w,
                weight_scale=scale,
                input_scale=input_scale,
                activation_scale_ub=self.activation_scale_ub,
                dtype=weights.dtype,
            )
        if self.to_fp8:
            return Fp8Weight(weight=w, dtype=weights.dtype)

        return UnquantizedWeight(w)

    def get_weights_col_packed(
        self,
        weights: Weights,
        prefix: str,
        block_sizes: Union[int, List[int]],
    ):
        """Load `{prefix}.weight` with `get_packed_sharded` (dim 0) and wrap it by dtype."""
        w = weights.get_packed_sharded(
            f"{prefix}.weight", dim=0, block_sizes=block_sizes
        )

        if w.dtype == torch.float8_e4m3fn:
            # FP8 branch
            scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)

            # A scale with more than one element is reloaded the same way as
            # the weight itself.
            if scale.numel() > 1:
                scale = weights.get_packed_sharded(
                    f"{prefix}.weight_scale",
                    dim=0,
                    block_sizes=block_sizes,
                    to_dtype=False,
                )
            if SYSTEM == "cuda":
                scale = scale.reshape(-1).expand(w.shape[0])

            input_scale = None
            if weights.has_tensor(f"{prefix}.input_scale"):
                input_scale = weights.get_tensor(
                    f"{prefix}.input_scale", to_dtype=False
                )
                if input_scale.numel() > 1:
                    input_scale = weights.get_packed_sharded(
                        f"{prefix}.input_scale",
                        dim=0,
                        block_sizes=block_sizes,
                        to_dtype=False,
                    )
                # Reduce the input scale to a single value: its maximum.
                input_scale = input_scale.reshape(-1).max()

            return Fp8Weight(
                weight=w,
                weight_scale=scale,
                input_scale=input_scale,
                activation_scale_ub=self.activation_scale_ub,
                dtype=weights.dtype,
            )
        if self.to_fp8:
            return Fp8Weight(weight=w, dtype=weights.dtype)

        return UnquantizedWeight(w)

    def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
        """Load `{p}.weight` for every prefix, concatenate along `dim`, and wrap by dtype."""
        # FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet
        w = [
            weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes
        ]
        shapes = [x.shape for x in w]

        # Concat then send to the device
        w = torch.cat(w, dim=dim).to(weights.device)

        # FP8 branch
        if w.dtype == torch.float8_e4m3fn:
            if self.weight_block_size is not None:
                scale = [
                    weights.get_sharded(f"{p}.weight_scale_inv", dim=0, to_device=False)
                    for p in prefixes
                ]
                scale = torch.cat(scale, dim=dim)
                scale = scale.to(weights.device)
                return Fp8Weight(
                    weight=w,
                    weight_scale=scale,
                    activation_scale_ub=self.activation_scale_ub,
                    dtype=weights.dtype,
                    weight_block_size=self.weight_block_size,
                )

            scale = [
                _load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
                for p, shape in zip(prefixes, shapes)
            ]
            scale = torch.cat(scale, dim=0).reshape(-1)

            input_scale = [
                _load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
                for p, shape in zip(prefixes, shapes)
                if weights.has_tensor(f"{p}.input_scale")
            ]
            # Either every prefix provides an input scale or none does.
            assert len(input_scale) == 0 or len(input_scale) == len(prefixes)
            input_scale = (
                torch.cat(input_scale, dim=0).reshape(-1).max()
                if len(input_scale) != 0
                else None
            )

            if SYSTEM == "rocm":
                w, scale, input_scale = normalize_e4m3fn_to_native_float8(
                    w, scale, input_scale
                )

            # Exactly one scale per prefix: requantize the concatenated
            # weight so that a single shared scale applies to all of it.
            if scale.numel() == len(prefixes):
                logical_widths = [x[0] for x in shapes]
                w, scale = requantize_with_max_scale(
                    w, scale.to(weights.device), logical_widths, weights.dtype
                )

            return Fp8Weight(
                weight=w,
                weight_scale=scale,
                input_scale=input_scale,
                activation_scale_ub=self.activation_scale_ub,
                dtype=weights.dtype,
            )
        if self.to_fp8:
            return Fp8Weight(weight=w, dtype=weights.dtype)

        return UnquantizedWeight(w)

    def get_weights_row(self, weights: "Weights", prefix: str):
        """Load `{prefix}.weight` with `get_sharded` (dim 1) and wrap it by dtype."""
        w = weights.get_sharded(f"{prefix}.weight", dim=1)
        # FP8 branch
        if w.dtype == torch.float8_e4m3fn:
            if self.weight_block_size is not None:
                # XXX: Yes the weights is named scale_inv, but corresponds to scale it seems.
                scale = weights.get_sharded(f"{prefix}.weight_scale_inv", dim=1)

                return Fp8Weight(
                    weight=w,
                    weight_scale=scale,
                    activation_scale_ub=self.activation_scale_ub,
                    dtype=weights.dtype,
                    weight_block_size=self.weight_block_size,
                )

            scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)

            if SYSTEM == "cuda":
                scale = scale.reshape(-1).expand(w.shape[0])

            input_scale = None
            if weights.has_tensor(f"{prefix}.input_scale"):
                # Reduce the input scale to a single value: its maximum.
                input_scale = (
                    weights.get_tensor(f"{prefix}.input_scale", to_dtype=False)
                    .reshape(-1)
                    .max()
                )

            return Fp8Weight(
                weight=w,
                weight_scale=scale,
                input_scale=input_scale,
                activation_scale_ub=self.activation_scale_ub,
                dtype=weights.dtype,
            )
        if self.to_fp8:
            return Fp8Weight(weight=w, dtype=weights.dtype)

        return UnquantizedWeight(w)
@dataclass
class Fp8Weight(Weight):
    """A weight tensor plus the optional FP8 metadata needed to build a linear layer."""

    # Weight tensor; handed to `from_unquant` when no weight scale is set,
    # otherwise to `from_fp8`.
    weight: torch.Tensor
    # dtype forwarded to the linear layer constructors.
    dtype: torch.dtype
    # Scale of the weight; None selects the `from_unquant` path in `get_linear`.
    weight_scale: Optional[torch.Tensor] = None
    # Optional input scale, forwarded to `from_fp8`.
    input_scale: Optional[torch.Tensor] = None
    # Forwarded to `from_fp8` as `scale_upper_bound`.
    activation_scale_ub: Optional[float] = None
    # Forwarded to `get_fp8_linear` when choosing the linear implementation.
    force_w8a16: bool = False
    # Forwarded to `from_fp8` as `weight_block_size`.
    weight_block_size: Optional[List[int]] = None

    def get_linear(self, bias: torch.Tensor):
        """Build the FP8 linear layer for this weight with the given `bias`."""
        if self.weight_scale is not None:
            # This is not checked by the fbgemm kernels, but they require contiguous
            # memory. Can be non-contiguous when we e.g. expand from scalars.
            self.weight_scale = self.weight_scale.contiguous()

            linear_cls = get_fp8_linear(force_w8a16=self.force_w8a16)
            fp8_kwargs = {
                "weight": self.weight,
                "scale": self.weight_scale,
                "dtype": self.dtype,
                "bias": bias,
                "input_scale": self.input_scale,
                "scale_upper_bound": self.activation_scale_ub,
                "weight_block_size": self.weight_block_size,
            }
            return linear_cls.from_fp8(**fp8_kwargs)

        # No scale available: let the linear class quantize the weight itself.
        linear_cls = get_fp8_linear(force_w8a16=self.force_w8a16)
        return linear_cls.from_unquant(self.weight, bias, self.dtype)
class Fp8Linear(torch.nn.Module):
    """Linear layer over an FP8 weight, dispatching to one of several matmul backends."""

    # One cached 1-element ones tensor per device, shared by all instances
    # (see `get_shared_device_identity`).
    _device_identity_cache = {}

    def __init__(
        self,
        qweight: torch.Tensor,
        scale: torch.Tensor,
        dtype: torch.dtype,
        bias: Optional[torch.Tensor] = None,
        input_scale: Optional[torch.Tensor] = None,
        scale_upper_bound: Optional[float] = None,
        weight_block_size: Optional[List[int]] = None,
    ) -> None:
        """
        Store the quantized weight, its scales and the optional bias.

        On ROCm an e4m3fn weight is first converted with
        `normalize_e4m3fn_to_native_float8`. Scales are stored as float32, and
        `scale_upper_bound` is turned into a tensor on the weight's device
        when cutlass FP8 kernels are available.
        """
        super().__init__()
        if CUTLASS_FP8_AVAILABLE:
            log_once(logger.info, "Using cutlass w8a8 kernels")

        needs_native_fp8 = SYSTEM == "rocm" and qweight.dtype == torch.float8_e4m3fn
        if needs_native_fp8:
            qweight, scale, input_scale = normalize_e4m3fn_to_native_float8(
                weight=qweight, weight_scale=scale, input_scale=input_scale
            )

        self.dtype = dtype
        self.qweight = qweight
        self.scale = scale.float()
        if input_scale is None:
            self.input_scale = None
        else:
            self.input_scale = input_scale.float()
        self.weight_block_size = weight_block_size

        # The upper bound is materialized as a tensor only on the cutlass
        # path; otherwise it is stored as given.
        wrap_upper_bound = CUTLASS_FP8_AVAILABLE and scale_upper_bound is not None
        self.scale_upper_bound = (
            torch.tensor(scale_upper_bound, dtype=torch.float32, device=qweight.device)
            if wrap_upper_bound
            else scale_upper_bound
        )

        self.bias = bias

    @classmethod
    def from_unquant(cls, weight, bias, dtype):
        """Quantize an unquantized `weight` with `fp8_quantize` and build the layer."""
        quantized, inv_scale = fp8_quantize(weight, scalar=not CUTLASS_FP8_AVAILABLE)
        return cls(
            qweight=quantized,
            scale=inv_scale,
            dtype=dtype,
            bias=bias,
            input_scale=None,
            scale_upper_bound=None,
        )

    @classmethod
    def from_fp8(
        cls,
        weight: torch.Tensor,
        scale: torch.Tensor,
        dtype: torch.dtype,
        bias: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> "Fp8Linear":
        """
        Build the layer from an already-quantized weight.

        Recognized optional keyword arguments are `input_scale`,
        `scale_upper_bound` and `weight_block_size`; each defaults to None.
        """
        return cls(
            qweight=weight,
            scale=scale,
            input_scale=kwargs.get("input_scale"),
            scale_upper_bound=kwargs.get("scale_upper_bound"),
            bias=bias,
            dtype=dtype,
            weight_block_size=kwargs.get("weight_block_size"),
        )

    @classmethod
    def get_shared_device_identity(cls, device):
        """Return the cached 1-element ones tensor for `device`, creating it on first use."""
        # Input scaling factors are no longer optional in _scaled_mm starting
        # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
        cache = cls._device_identity_cache
        if device in cache:
            return cache[device]
        identity = torch.ones(1, device=device)
        cache[device] = identity
        return identity

    @staticmethod
    def _unwrap_scaled_mm(result):
        """Return the first element when `result` is a 2-tuple, else `result` itself."""
        if isinstance(result, tuple) and len(result) == 2:
            return result[0]
        return result

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Quantize `input` and multiply it with the stored FP8 weight."""
        if self.weight_block_size is not None:
            # https://arxiv.org/pdf/2412.19437
            # At a more granular level. As illustrated in Figure 7 (a), (1) for activations, we group and
            # scale elements on a 1x128 tile basis (i.e., per token per 128 channels); and (2) for weights, we
            # group and scale elements on a 128x128 block basis (i.e., per 128 input channels per 128 output
            # channels).
            qinput, act_scale = per_token_group_quant_fp8(
                input, self.weight_block_size[1]
            )
            result = w8a8_block_fp8_matmul(
                qinput,
                self.qweight,
                act_scale,
                self.scale,
                self.weight_block_size,
                output_dtype=input.dtype,
            )
            if self.bias is not None:
                result = result + self.bias
            return result.to(dtype=input.dtype)

        if CUTLASS_FP8_AVAILABLE:
            # cutlass FP8 supports per-token scales, so get non-scalar scales.
            qinput, act_scale = fp8_quantize(
                input, scale_upper_bound=self.scale_upper_bound, scalar=False
            )
            return quantization.cutlass_scaled_mm(
                qinput, self.qweight.t(), act_scale, self.scale, input.dtype, self.bias
            )

        qinput, act_scale = fp8_quantize(
            input,
            self.input_scale,
            scale_upper_bound=self.scale_upper_bound,
            scalar=True,
        )

        weight_is_per_tensor = self.scale.numel() == 1
        act_is_per_tensor = act_scale.numel() == 1

        if SYSTEM == "rocm" and not (weight_is_per_tensor and act_is_per_tensor):
            # Only reached on ROCm with a non-scalar scale: run the matmul
            # with identity scales in float32 and apply the real scales (and
            # the bias) afterwards.
            identity = self.get_shared_device_identity(self.qweight.device)
            raw = torch._scaled_mm(
                qinput,
                self.qweight.t(),
                scale_a=identity,
                scale_b=identity,
                out_dtype=torch.float32,
            )
            result = self._unwrap_scaled_mm(raw)
            result = result * act_scale * self.scale.t()
            if self.bias is not None:
                result = result + self.bias
            return result.to(dtype=self.dtype)

        raw = torch._scaled_mm(
            qinput,
            self.qweight.t(),
            out_dtype=self.dtype,
            scale_a=act_scale,
            scale_b=self.scale,
            bias=self.bias,
        )
        return self._unwrap_scaled_mm(raw)
def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size):
    """
    Load the scale tensor named `prefix` as a flat tensor.

    A scale with more than one element is reloaded with `get_sharded` along
    dim 0. A scale with at most one element is returned flattened as-is on
    ROCm. In all remaining cases the flattened scale is expanded to
    `shape[0]` entries.
    """
    loaded = weights.get_tensor(prefix, to_dtype=False)
    is_single_value = loaded.numel() <= 1

    if is_single_value and SYSTEM == "rocm":
        return loaded.reshape(-1)

    if not is_single_value:
        loaded = weights.get_sharded(prefix, dim=0, to_dtype=False)

    return loaded.reshape(-1).expand(shape[0])