Enable FP8 Per-Tensor Scales and Integrate Marlin/MoE Kernels Repo for ROCm (#2825)

* (feat) convert tscales to tensorwise

* (fix) fp8 scaling for cuda

* (kernel) add marlin-kernels

* add moe-kernels

* fix moe kernel comit

* fix scaling

* nm changes
This commit is contained in:
Mohit Sharma 2025-01-15 11:38:58 +05:30 committed by GitHub
parent 880ab9c2f3
commit e07acc7f68
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 134 additions and 49 deletions

View File

@ -268,6 +268,15 @@ COPY server/exllamav2_kernels/ .
RUN python setup.py build
FROM kernel-builder AS marlin-kernels
WORKDIR /usr/src
ENV MARLIN_KERNELS_BRANCH=v0.3.6
ENV VLLM_TARGET_DEVICE=rocm
RUN git clone https://github.com/danieldk/marlin-kernels.git && \
cd marlin-kernels && \
git checkout ${MARLIN_KERNELS_BRANCH} && \
python setup.py install
FROM kernel-builder AS moe-kernels
WORKDIR /usr/src
ENV MOE_KERNELS_BRANCH=a67b35841774b2056a73806c36661134b5054edd
@ -299,6 +308,9 @@ COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311
# Copy build artifacts from exllamav2 kernels builder
COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from marlin kernels
COPY --from=marlin-kernels /usr/src/marlin-kernels/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from moe kernels
COPY --from=moe-kernels /usr/src/moe-kernels/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages

View File

@ -163,17 +163,17 @@ def paged_attention(
else:
# Run PagedAttention V2.
assert _PARTITION_SIZE % block_size == 0
tmp_output = torch.empty(
tmp_output = torch.zeros(
size=(num_seqs, num_heads, max_num_partitions, head_size),
dtype=out.dtype,
device=out.device,
)
exp_sums = torch.empty(
exp_sums = torch.zeros(
size=(num_seqs, num_heads, max_num_partitions),
dtype=torch.float32,
device=out.device,
)
max_logits = torch.empty_like(exp_sums)
max_logits = torch.zeros_like(exp_sums)
if not use_custom:
ops.paged_attention_v2(

View File

@ -3,8 +3,14 @@ from typing import List, Optional, Union
import torch
from compressed_tensors.quantization import QuantizationArgs, QuantizationType
from text_generation_server.layers.fp8 import Fp8Weight, _load_scalar_or_matrix_scale
from text_generation_server.layers.fp8 import (
Fp8Weight,
_load_scalar_or_matrix_scale,
requantize_with_max_scale,
normalize_e4m3fn_to_e4m3fnuz,
)
from text_generation_server.utils.weights import Weights, WeightsLoader
from text_generation_server.utils.import_utils import SYSTEM
class W8ANFpLoader(WeightsLoader):
@ -47,11 +53,10 @@ class W8ANFpLoader(WeightsLoader):
weight_scale = None
if self.load_weight_scale:
weight_scale = (
weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
.reshape(-1)
.expand(w.shape[0])
)
weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
if SYSTEM == "cuda":
weight_scale = weight_scale.reshape(-1).expand(w.shape[0])
input_scale = None
if self.load_input_scale:
@ -87,7 +92,8 @@ class W8ANFpLoader(WeightsLoader):
block_sizes=block_sizes,
to_dtype=False,
)
weight_scale = weight_scale.reshape(-1).expand(w.shape[0])
if SYSTEM == "cuda":
weight_scale = weight_scale.reshape(-1).expand(w.shape[0])
input_scale = None
if self.load_input_scale:
@ -141,6 +147,17 @@ class W8ANFpLoader(WeightsLoader):
else None
)
if self.load_weight_scale or SYSTEM == "rocm":
w, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
w, weight_scale, input_scale
)
if weight_scale.numel() == len(prefixes):
logical_widths = [x[0] for x in shapes]
w, weight_scale = requantize_with_max_scale(
w, weight_scale.to(weights.device), logical_widths, weights.dtype
)
return Fp8Weight(
weight=w,
weight_scale=weight_scale,
@ -153,11 +170,10 @@ class W8ANFpLoader(WeightsLoader):
w = weights.get_sharded(f"{prefix}.weight", dim=1)
weight_scale = None
if self.load_weight_scale:
weight_scale = (
weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
.reshape(-1)
.expand(w.shape[0])
)
weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
if SYSTEM == "cuda":
weight_scale = weight_scale.reshape(-1).expand(w.shape[0])
input_scale = None
if self.load_input_scale:

View File

@ -19,6 +19,9 @@ try:
except ImportError:
marlin_kernels = None
quant_dtype: torch.dtype = (
torch.float8_e4m3fnuz if SYSTEM == "rocm" else torch.float8_e4m3fn
)
if SYSTEM == "cuda" and marlin_kernels is not None:
major, minor = torch.cuda.get_device_capability()
@ -60,25 +63,58 @@ def normalize_e4m3fn_to_e4m3fnuz(
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
assert weight.dtype == torch.float8_e4m3fn
# The bits pattern 10000000(-128) represents zero in e4m3fn
# but NaN in e4m3fnuz. So here we set it to 0.
# https://onnx.ai/onnx/technical/float8.html
weight_as_int8 = weight.view(torch.int8)
ROCM_FP8_NAN_AS_INT = -128
weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0
weight = weight_as_int8.view(torch.float8_e4m3fnuz)
if weight.dtype == torch.float8_e4m3fn:
# The bits pattern 10000000(-128) represents zero in e4m3fn
# but NaN in e4m3fnuz. So here we set it to 0.
# https://onnx.ai/onnx/technical/float8.html
weight_as_int8 = weight.view(torch.int8)
ROCM_FP8_NAN_AS_INT = -128
weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0
weight = weight_as_int8.view(torch.float8_e4m3fnuz)
# For the same bits representation, e4m3fnuz value is half of
# the e4m3fn value, so we should double the scaling factor to
# get the same dequantized value.
# https://onnx.ai/onnx/technical/float8.html
weight_scale = weight_scale * 2.0
if input_scale is not None:
input_scale = input_scale * 2.0
# For the same bits representation, e4m3fnuz value is half of
# the e4m3fn value, so we should double the scaling factor to
# get the same dequantized value.
# https://onnx.ai/onnx/technical/float8.html
weight_scale = weight_scale * 2.0
if input_scale is not None:
input_scale = input_scale * 2.0
return weight, weight_scale, input_scale
def per_tensor_dequantize(
tensor: torch.Tensor,
inv_scale: Union[float, torch.Tensor],
dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
fake_qweight = tensor.to(dtype)
dq_weight = fake_qweight * inv_scale
return dq_weight
def requantize_with_max_scale(
weight: torch.Tensor,
weight_scale: torch.Tensor,
logical_widths: int,
dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Max scale to be used for requanitzation.
max_w_scale = weight_scale.max().float()
start = 0
for idx, logical_width in enumerate(logical_widths):
end = start + logical_width
weight_dq = per_tensor_dequantize(
weight[start:end, :], weight_scale[idx], dtype
)
weight[start:end, :], max_w_scale_normalized = fp8_quantize(
weight_dq, max_w_scale
)
start = end
return weight, max_w_scale_normalized
def fp8_quantize(
weight: torch.Tensor,
scale: Optional[torch.Tensor] = None,
@ -96,7 +132,7 @@ def fp8_quantize(
shape = weight.shape
qweight, scale = marlin_kernels.scaled_fp8_quant(
weight.reshape(-1, shape[-1]),
dtype=qdtype,
dtype=quant_dtype,
scale=scale,
scale_ub=scale_upper_bound,
# TODO: don't do this when we have to use the Torch kernel.
@ -116,6 +152,8 @@ def fp8_quantize(
qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max)
scale = scale.float().reciprocal()
else:
if SYSTEM == "rocm":
scale = scale / 2.0
# Use reciprocal to avoid more expensive division.
qweight = (weight * scale.reciprocal()).clamp(min=finfo.min, max=finfo.max)
@ -141,17 +179,18 @@ class HybridFP8UnquantLoader(WeightsLoader):
if w.dtype == torch.float8_e4m3fn:
# FP8 branch
scale = (
weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
.reshape(-1)
.expand(w.shape[0])
)
scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
if SYSTEM == "cuda":
scale.reshape(-1).expand(w.shape[0])
input_scale = None
if weights.has_tensor(f"{prefix}.input_scale"):
input_scale = weights.get_tensor(
f"{prefix}.input_scale", to_dtype=False
).reshape(-1)
input_scale = (
weights.get_tensor(f"{prefix}.input_scale", to_dtype=False)
.reshape(-1)
.max()
)
return Fp8Weight(
weight=w,
@ -178,6 +217,7 @@ class HybridFP8UnquantLoader(WeightsLoader):
if w.dtype == torch.float8_e4m3fn:
# FP8 branch
scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
if scale.numel() > 1:
scale = weights.get_packed_sharded(
f"{prefix}.weight_scale",
@ -185,7 +225,8 @@ class HybridFP8UnquantLoader(WeightsLoader):
block_sizes=block_sizes,
to_dtype=False,
)
scale = scale.reshape(-1).expand(w.shape[0])
if SYSTEM == "cuda":
scale = scale.reshape(-1).expand(w.shape[0])
input_scale = None
if weights.has_tensor(f"{prefix}.input_scale"):
@ -243,6 +284,17 @@ class HybridFP8UnquantLoader(WeightsLoader):
else None
)
if SYSTEM == "rocm":
w, scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
w, scale, input_scale
)
if scale.numel() == len(prefixes):
logical_widths = [x[0] for x in shapes]
w, scale = requantize_with_max_scale(
w, scale.to(weights.device), logical_widths, weights.dtype
)
return Fp8Weight(
weight=w,
weight_scale=scale,
@ -259,16 +311,18 @@ class HybridFP8UnquantLoader(WeightsLoader):
w = weights.get_sharded(f"{prefix}.weight", dim=1)
# FP8 branch
if w.dtype == torch.float8_e4m3fn:
scale = (
weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
.reshape(-1)
.expand(w.shape[0])
)
scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
if SYSTEM == "cuda":
scale = scale.reshape(-1).expand(w.shape[0])
input_scale = None
if weights.has_tensor(f"{prefix}.input_scale"):
input_scale = weights.get_tensor(
f"{prefix}.input_scale", to_dtype=False
).reshape(-1)
input_scale = (
weights.get_tensor(f"{prefix}.input_scale", to_dtype=False)
.reshape(-1)
.max()
)
return Fp8Weight(
weight=w,
@ -326,7 +380,7 @@ class Fp8Linear(torch.nn.Module):
if CUTLASS_FP8_AVAILABLE:
log_once(logger.info, "Using cutlass w8a8 kernels")
if SYSTEM == "rocm" and qweight.dtype == torch.float8_e4m3fn:
qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(
qweight, scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
weight=qweight, weight_scale=scale
)
@ -443,6 +497,9 @@ class Fp8Linear(torch.nn.Module):
def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size):
scale = weights.get_tensor(prefix, to_dtype=False)
if scale.numel() > 1:
scale = weights.get_sharded(prefix, dim=0, to_dtype=False)
elif SYSTEM == "rocm":
return scale.reshape(-1)
return scale.reshape(-1).expand(shape[0])