From e07acc7f68c6271ce675010ec607cca9449635c0 Mon Sep 17 00:00:00 2001
From: Mohit Sharma
Date: Wed, 15 Jan 2025 11:38:58 +0530
Subject: [PATCH] Enable FP8 Per-Tensor Scales and Integrate Marlin/MoE Kernels Repo for ROCm (#2825)

* (feat) convert tscales to tensorwise
* (fix) fp8 scaling for cuda
* (kernel) add marlin-kernels
* add moe-kernels
* fix moe kernel commit
* fix scaling
* nm changes
---
 Dockerfile_amd                               |  12 ++
 .../layers/attention/rocm.py                 |   6 +-
 .../layers/compressed_tensors/w8an_fp.py     |  40 ++++--
 server/text_generation_server/layers/fp8.py  | 125 +++++++++++++----
 4 files changed, 134 insertions(+), 49 deletions(-)

diff --git a/Dockerfile_amd b/Dockerfile_amd
index dc748f490..1f34ffa30 100644
--- a/Dockerfile_amd
+++ b/Dockerfile_amd
@@ -268,6 +268,15 @@ COPY server/exllamav2_kernels/ .
 
 RUN python setup.py build
 
+FROM kernel-builder AS marlin-kernels
+WORKDIR /usr/src
+ENV MARLIN_KERNELS_BRANCH=v0.3.6
+ENV VLLM_TARGET_DEVICE=rocm
+RUN git clone https://github.com/danieldk/marlin-kernels.git && \
+    cd marlin-kernels && \
+    git checkout ${MARLIN_KERNELS_BRANCH} && \
+    python setup.py install
+
 FROM kernel-builder AS moe-kernels
 WORKDIR /usr/src
 ENV MOE_KERNELS_BRANCH=a67b35841774b2056a73806c36661134b5054edd
@@ -299,6 +308,9 @@ COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311
 # Copy build artifacts from exllamav2 kernels builder
 COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
 
+# Copy build artifacts from marlin kernels
+COPY --from=marlin-kernels /usr/src/marlin-kernels/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
+
 # Copy build artifacts from moe kernels
 COPY --from=moe-kernels /usr/src/moe-kernels/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
 
diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py
index 69a245ad5..b94b737dc 100644
--- a/server/text_generation_server/layers/attention/rocm.py
+++ b/server/text_generation_server/layers/attention/rocm.py
@@ -163,17 +163,17 @@ def paged_attention(
     else:
         # Run PagedAttention V2.
         assert _PARTITION_SIZE % block_size == 0
-        tmp_output = torch.empty(
+        tmp_output = torch.zeros(
             size=(num_seqs, num_heads, max_num_partitions, head_size),
             dtype=out.dtype,
             device=out.device,
         )
-        exp_sums = torch.empty(
+        exp_sums = torch.zeros(
             size=(num_seqs, num_heads, max_num_partitions),
             dtype=torch.float32,
             device=out.device,
         )
-        max_logits = torch.empty_like(exp_sums)
+        max_logits = torch.zeros_like(exp_sums)
 
         if not use_custom:
             ops.paged_attention_v2(
diff --git a/server/text_generation_server/layers/compressed_tensors/w8an_fp.py b/server/text_generation_server/layers/compressed_tensors/w8an_fp.py
index e63c52123..15bdce08e 100644
--- a/server/text_generation_server/layers/compressed_tensors/w8an_fp.py
+++ b/server/text_generation_server/layers/compressed_tensors/w8an_fp.py
@@ -3,8 +3,14 @@ from typing import List, Optional, Union
 
 import torch
 from compressed_tensors.quantization import QuantizationArgs, QuantizationType
-from text_generation_server.layers.fp8 import Fp8Weight, _load_scalar_or_matrix_scale
+from text_generation_server.layers.fp8 import (
+    Fp8Weight,
+    _load_scalar_or_matrix_scale,
+    requantize_with_max_scale,
+    normalize_e4m3fn_to_e4m3fnuz,
+)
 from text_generation_server.utils.weights import Weights, WeightsLoader
+from text_generation_server.utils.import_utils import SYSTEM
 
 
 class W8ANFpLoader(WeightsLoader):
@@ -47,11 +53,10 @@ class W8ANFpLoader(WeightsLoader):
 
         weight_scale = None
         if self.load_weight_scale:
-            weight_scale = (
-                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
-                .reshape(-1)
-                .expand(w.shape[0])
-            )
+            weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+
+            if SYSTEM == "cuda":
+                weight_scale = weight_scale.reshape(-1).expand(w.shape[0])
 
         input_scale = None
         if self.load_input_scale:
@@ -87,7 +92,8 @@ class W8ANFpLoader(WeightsLoader):
                 block_sizes=block_sizes,
                 to_dtype=False,
             )
-            weight_scale = weight_scale.reshape(-1).expand(w.shape[0])
+            if SYSTEM == "cuda":
+                weight_scale = weight_scale.reshape(-1).expand(w.shape[0])
 
         input_scale = None
         if self.load_input_scale:
@@ -141,6 +147,17 @@ class W8ANFpLoader(WeightsLoader):
             else None
         )
 
+        if self.load_weight_scale or SYSTEM == "rocm":
+            w, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
+                w, weight_scale, input_scale
+            )
+
+        if weight_scale.numel() == len(prefixes):
+            logical_widths = [x[0] for x in shapes]
+            w, weight_scale = requantize_with_max_scale(
+                w, weight_scale.to(weights.device), logical_widths, weights.dtype
+            )
+
         return Fp8Weight(
             weight=w,
             weight_scale=weight_scale,
@@ -153,11 +170,10 @@ class W8ANFpLoader(WeightsLoader):
         w = weights.get_sharded(f"{prefix}.weight", dim=1)
         weight_scale = None
         if self.load_weight_scale:
-            weight_scale = (
-                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
-                .reshape(-1)
-                .expand(w.shape[0])
-            )
+            weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+
+            if SYSTEM == "cuda":
+                weight_scale = weight_scale.reshape(-1).expand(w.shape[0])
 
         input_scale = None
         if self.load_input_scale:
diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py
index 1e5c8b3d6..4e83ec9d0 100644
--- a/server/text_generation_server/layers/fp8.py
+++ b/server/text_generation_server/layers/fp8.py
@@ -19,6 +19,9 @@ try:
 except ImportError:
     marlin_kernels = None
 
+quant_dtype: torch.dtype = (
+    torch.float8_e4m3fnuz if SYSTEM == "rocm" else torch.float8_e4m3fn
+)
 
 if SYSTEM == "cuda" and marlin_kernels is not None:
     major, minor = torch.cuda.get_device_capability()
@@ -60,25 +63,58 @@ def normalize_e4m3fn_to_e4m3fnuz(
     weight_scale: torch.Tensor,
     input_scale: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
-    assert weight.dtype == torch.float8_e4m3fn
-    # The bits pattern 10000000(-128) represents zero in e4m3fn
-    # but NaN in e4m3fnuz. So here we set it to 0.
-    # https://onnx.ai/onnx/technical/float8.html
-    weight_as_int8 = weight.view(torch.int8)
-    ROCM_FP8_NAN_AS_INT = -128
-    weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0
-    weight = weight_as_int8.view(torch.float8_e4m3fnuz)
+    if weight.dtype == torch.float8_e4m3fn:
+        # The bits pattern 10000000(-128) represents zero in e4m3fn
+        # but NaN in e4m3fnuz. So here we set it to 0.
+        # https://onnx.ai/onnx/technical/float8.html
+        weight_as_int8 = weight.view(torch.int8)
+        ROCM_FP8_NAN_AS_INT = -128
+        weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0
+        weight = weight_as_int8.view(torch.float8_e4m3fnuz)
 
-    # For the same bits representation, e4m3fnuz value is half of
-    # the e4m3fn value, so we should double the scaling factor to
-    # get the same dequantized value.
-    # https://onnx.ai/onnx/technical/float8.html
-    weight_scale = weight_scale * 2.0
-    if input_scale is not None:
-        input_scale = input_scale * 2.0
+        # For the same bits representation, e4m3fnuz value is half of
+        # the e4m3fn value, so we should double the scaling factor to
+        # get the same dequantized value.
+        # https://onnx.ai/onnx/technical/float8.html
+        weight_scale = weight_scale * 2.0
+        if input_scale is not None:
+            input_scale = input_scale * 2.0
 
     return weight, weight_scale, input_scale
 
 
+def per_tensor_dequantize(
+    tensor: torch.Tensor,
+    inv_scale: Union[float, torch.Tensor],
+    dtype: torch.dtype = torch.float16,
+) -> torch.Tensor:
+    fake_qweight = tensor.to(dtype)
+    dq_weight = fake_qweight * inv_scale
+    return dq_weight
+
+
+def requantize_with_max_scale(
+    weight: torch.Tensor,
+    weight_scale: torch.Tensor,
+    logical_widths: int,
+    dtype: torch.dtype,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    # Max scale to be used for requantization.
+    max_w_scale = weight_scale.max().float()
+
+    start = 0
+    for idx, logical_width in enumerate(logical_widths):
+        end = start + logical_width
+        weight_dq = per_tensor_dequantize(
+            weight[start:end, :], weight_scale[idx], dtype
+        )
+        weight[start:end, :], max_w_scale_normalized = fp8_quantize(
+            weight_dq, max_w_scale
+        )
+        start = end
+
+    return weight, max_w_scale_normalized
+
+
 def fp8_quantize(
     weight: torch.Tensor,
     scale: Optional[torch.Tensor] = None,
@@ -96,7 +132,7 @@ def fp8_quantize(
         shape = weight.shape
         qweight, scale = marlin_kernels.scaled_fp8_quant(
             weight.reshape(-1, shape[-1]),
-            dtype=qdtype,
+            dtype=quant_dtype,
             scale=scale,
             scale_ub=scale_upper_bound,
             # TODO: don't do this when we have to use the Torch kernel.
@@ -116,6 +152,8 @@ def fp8_quantize(
         qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max)
         scale = scale.float().reciprocal()
     else:
+        if SYSTEM == "rocm":
+            scale = scale / 2.0
         # Use reciprocal to avoid more expensive division.
         qweight = (weight * scale.reciprocal()).clamp(min=finfo.min, max=finfo.max)
@@ -141,17 +179,18 @@ class HybridFP8UnquantLoader(WeightsLoader):
 
         if w.dtype == torch.float8_e4m3fn:
             # FP8 branch
-            scale = (
-                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
-                .reshape(-1)
-                .expand(w.shape[0])
-            )
+            scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+
+            if SYSTEM == "cuda":
+                scale = scale.reshape(-1).expand(w.shape[0])
 
             input_scale = None
             if weights.has_tensor(f"{prefix}.input_scale"):
-                input_scale = weights.get_tensor(
-                    f"{prefix}.input_scale", to_dtype=False
-                ).reshape(-1)
+                input_scale = (
+                    weights.get_tensor(f"{prefix}.input_scale", to_dtype=False)
+                    .reshape(-1)
+                    .max()
+                )
 
             return Fp8Weight(
                 weight=w,
@@ -178,6 +217,7 @@ class HybridFP8UnquantLoader(WeightsLoader):
         if w.dtype == torch.float8_e4m3fn:
             # FP8 branch
             scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+
             if scale.numel() > 1:
                 scale = weights.get_packed_sharded(
                     f"{prefix}.weight_scale",
@@ -185,7 +225,8 @@ class HybridFP8UnquantLoader(WeightsLoader):
                     block_sizes=block_sizes,
                     to_dtype=False,
                 )
-            scale = scale.reshape(-1).expand(w.shape[0])
+            if SYSTEM == "cuda":
+                scale = scale.reshape(-1).expand(w.shape[0])
 
             input_scale = None
             if weights.has_tensor(f"{prefix}.input_scale"):
@@ -243,6 +284,17 @@ class HybridFP8UnquantLoader(WeightsLoader):
             else None
         )
 
+        if SYSTEM == "rocm":
+            w, scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
+                w, scale, input_scale
+            )
+
+        if scale.numel() == len(prefixes):
+            logical_widths = [x[0] for x in shapes]
+            w, scale = requantize_with_max_scale(
+                w, scale.to(weights.device), logical_widths, weights.dtype
+            )
+
         return Fp8Weight(
             weight=w,
             weight_scale=scale,
@@ -259,16 +311,18 @@ class HybridFP8UnquantLoader(WeightsLoader):
         w = weights.get_sharded(f"{prefix}.weight", dim=1)
         # FP8 branch
         if w.dtype == torch.float8_e4m3fn:
-            scale = (
-                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
-                .reshape(-1)
-                .expand(w.shape[0])
-            )
+            scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+
+            if SYSTEM == "cuda":
+                scale = scale.reshape(-1).expand(w.shape[0])
+
             input_scale = None
             if weights.has_tensor(f"{prefix}.input_scale"):
-                input_scale = weights.get_tensor(
-                    f"{prefix}.input_scale", to_dtype=False
-                ).reshape(-1)
+                input_scale = (
+                    weights.get_tensor(f"{prefix}.input_scale", to_dtype=False)
+                    .reshape(-1)
+                    .max()
+                )
 
             return Fp8Weight(
                 weight=w,
@@ -326,7 +380,7 @@ class Fp8Linear(torch.nn.Module):
         if CUTLASS_FP8_AVAILABLE:
             log_once(logger.info, "Using cutlass w8a8 kernels")
         if SYSTEM == "rocm" and qweight.dtype == torch.float8_e4m3fn:
-            qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+            qweight, scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
                 weight=qweight, weight_scale=scale
             )
@@ -443,6 +497,9 @@ class Fp8Linear(torch.nn.Module):
 
 def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size):
     scale = weights.get_tensor(prefix, to_dtype=False)
+
     if scale.numel() > 1:
         scale = weights.get_sharded(prefix, dim=0, to_dtype=False)
+    elif SYSTEM == "rocm":
+        return scale.reshape(-1)
     return scale.reshape(-1).expand(shape[0])
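
Two properties make the scale handling in this patch work. First, requantize_with_max_scale collapses one FP8 scale per logical shard (e.g. fused q/k/v projections stacked along dim 0) into a single per-tensor scale by dequantizing each shard with its own scale and requantizing everything with the maximum. A minimal sketch of that idea, using float32 as a stand-in for float8_e4m3fnuz so it runs anywhere; the function and variable names are illustrative, not taken from the repo:

    import torch

    def requantize_sketch(weight_q, shard_scales, logical_widths, fp8_max=240.0):
        # fp8_max = largest finite float8_e4m3fnuz value.
        # One dequantization scale per logical shard; keep only the largest.
        max_scale = shard_scales.max()
        out = torch.empty_like(weight_q)
        start = 0
        for idx, width in enumerate(logical_widths):
            end = start + width
            # Dequantize the shard with its own scale...
            dq = weight_q[start:end] * shard_scales[idx]
            # ...then requantize it with the shared (max) scale.
            out[start:end] = (dq / max_scale).clamp(-fp8_max, fp8_max)
            start = end
        return out, max_scale

    w_q = torch.randn(6, 4)                  # stands in for the FP8 weight
    scales = torch.tensor([0.5, 2.0])        # two shards, two scales
    new_q, s = requantize_sketch(w_q, scales, [3, 3])
    # Dequantizing every row with the single scale `s` reproduces the per-shard
    # dequantized values (exactly here; real FP8 adds rounding error).
    print(torch.allclose(new_q * s, torch.cat([w_q[:3] * 0.5, w_q[3:] * 2.0])))

Picking the largest shard scale means no shard can overflow the FP8 range after requantization; shards that had smaller scales merely give up a little precision.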
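
Second, normalize_e4m3fn_to_e4m3fnuz relies on the same bit pattern decoding to half the value under e4m3fnuz (exponent bias 8 instead of 7), so doubling weight_scale and input_scale keeps the dequantized values unchanged, while the 0x80 pattern (negative zero in e4m3fn, NaN in e4m3fnuz) must be zeroed first. A rough numeric check of that rule, assuming a PyTorch build that exposes both FP8 dtypes (values and names are illustrative):

    import torch

    w_fn = torch.tensor([1.0, -0.5, 0.25]).to(torch.float8_e4m3fn)
    scale = torch.tensor(0.125)
    before = w_fn.to(torch.float32) * scale    # dequantized under e4m3fn

    bits = w_fn.view(torch.int8)
    bits[bits == -128] = 0                     # 0x80: -0 in e4m3fn, NaN in e4m3fnuz
    w_fnuz = bits.view(torch.float8_e4m3fnuz)
    after = w_fnuz.to(torch.float32) * (scale * 2.0)

    print(torch.allclose(before, after))       # True: dequantized values match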