From e22cb47fe3b88eaecc024286812f5b432e021f4d Mon Sep 17 00:00:00 2001
From: Mohit Sharma
Date: Fri, 6 Dec 2024 14:16:39 +0000
Subject: [PATCH] (fix) fp8 scaling for cuda

---
 .../layers/attention/rocm.py                |  6 +--
 .../layers/compressed_tensors/w8an_fp.py    | 34 +++++++++++------
 server/text_generation_server/layers/fp8.py | 38 ++++++++++---------
 3 files changed, 46 insertions(+), 32 deletions(-)

diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py
index ea11c2c2..a500ca6a 100644
--- a/server/text_generation_server/layers/attention/rocm.py
+++ b/server/text_generation_server/layers/attention/rocm.py
@@ -116,17 +116,17 @@ def paged_attention(
     else:
         # Run PagedAttention V2.
         assert _PARTITION_SIZE % block_size == 0
-        tmp_output = torch.empty(
+        tmp_output = torch.zeros(
             size=(num_seqs, num_heads, max_num_partitions, head_size),
             dtype=out.dtype,
             device=out.device,
         )
-        exp_sums = torch.empty(
+        exp_sums = torch.zeros(
             size=(num_seqs, num_heads, max_num_partitions),
             dtype=torch.float32,
             device=out.device,
         )
-        max_logits = torch.empty_like(exp_sums)
+        max_logits = torch.zeros_like(exp_sums)
 
         if not use_custom:
             ops.paged_attention_v2(
diff --git a/server/text_generation_server/layers/compressed_tensors/w8an_fp.py b/server/text_generation_server/layers/compressed_tensors/w8an_fp.py
index e63c5212..959fd5b3 100644
--- a/server/text_generation_server/layers/compressed_tensors/w8an_fp.py
+++ b/server/text_generation_server/layers/compressed_tensors/w8an_fp.py
@@ -3,8 +3,13 @@ from typing import List, Optional, Union
 import torch
 from compressed_tensors.quantization import QuantizationArgs, QuantizationType
 
-from text_generation_server.layers.fp8 import Fp8Weight, _load_scalar_or_matrix_scale
+from text_generation_server.layers.fp8 import (
+    Fp8Weight,
+    _load_scalar_or_matrix_scale,
+    requantize_with_max_scale,
+)
 from text_generation_server.utils.weights import Weights, WeightsLoader
+from text_generation_server.utils.import_utils import SYSTEM
 
 
 class W8ANFpLoader(WeightsLoader):
@@ -47,11 +52,10 @@ class W8ANFpLoader(WeightsLoader):
 
         weight_scale = None
         if self.load_weight_scale:
-            weight_scale = (
-                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
-                .reshape(-1)
-                .expand(w.shape[0])
-            )
+            weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+
+            if SYSTEM == "cuda":
+                weight_scale = weight_scale.reshape(-1).expand(w.shape[0])
 
         input_scale = None
         if self.load_input_scale:
@@ -87,7 +91,8 @@ class W8ANFpLoader(WeightsLoader):
                     block_sizes=block_sizes,
                     to_dtype=False,
                 )
-            weight_scale = weight_scale.reshape(-1).expand(w.shape[0])
+            if SYSTEM == "cuda":
+                weight_scale = weight_scale.reshape(-1).expand(w.shape[0])
 
         input_scale = None
         if self.load_input_scale:
@@ -127,6 +132,12 @@ class W8ANFpLoader(WeightsLoader):
             ]
             weight_scale = torch.cat(weight_scale, dim=0).reshape(-1)
 
+            if weight_scale.numel() == len(prefixes):
+                logical_widths = [x[0] for x in shapes]
+                w, weight_scale = requantize_with_max_scale(
+                    w, weight_scale.to(weights.device), logical_widths
+                )
+
         input_scale = None
         if self.load_input_scale:
             input_scale = [
@@ -153,11 +164,10 @@ class W8ANFpLoader(WeightsLoader):
         w = weights.get_sharded(f"{prefix}.weight", dim=1)
         weight_scale = None
         if self.load_weight_scale:
-            weight_scale = (
-                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
-                .reshape(-1)
-                .expand(w.shape[0])
-            )
+            weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+
+            if SYSTEM == "cuda":
+                weight_scale = weight_scale.reshape(-1).expand(w.shape[0])
 
         input_scale = None
         if self.load_input_scale:
diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py
index 03e3660e..32b6cdd6 100644
--- a/server/text_generation_server/layers/fp8.py
+++ b/server/text_generation_server/layers/fp8.py
@@ -167,11 +167,10 @@ class HybridFP8UnquantLoader(WeightsLoader):
 
         if w.dtype == torch.float8_e4m3fn:
             # FP8 branch
-            scale = (
-                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
-                .reshape(-1)
-                .expand(w.shape[0])
-            ).max()
+            scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+
+            if SYSTEM == "cuda":
+                scale = scale.reshape(-1).expand(w.shape[0])
 
             input_scale = None
             if weights.has_tensor(f"{prefix}.input_scale"):
@@ -206,6 +205,7 @@ class HybridFP8UnquantLoader(WeightsLoader):
         if w.dtype == torch.float8_e4m3fn:
             # FP8 branch
             scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+
             if scale.numel() > 1:
                 scale = weights.get_packed_sharded(
                     f"{prefix}.weight_scale",
@@ -213,7 +213,8 @@
                     dim=0,
                     block_sizes=block_sizes,
                     to_dtype=False,
                 )
-            scale = scale.reshape(-1).expand(w.shape[0]).max()
+            if SYSTEM == "cuda":
+                scale = scale.reshape(-1).expand(w.shape[0])
 
             input_scale = None
             if weights.has_tensor(f"{prefix}.input_scale"):
@@ -255,15 +256,15 @@ class HybridFP8UnquantLoader(WeightsLoader):
         if w.dtype == torch.float8_e4m3fn:
             scale = [
                 _load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
-                .max()
-                .unsqueeze(0)
                 for p, shape in zip(prefixes, shapes)
             ]
-            scale = torch.cat(scale).to(weights.device)
+            scale = torch.cat(scale, dim=0).reshape(-1)
 
-            logical_widths = [x[0] for x in shapes]
-
-            w, scale = requantize_with_max_scale(w, scale, logical_widths)
+            if scale.numel() == len(prefixes):
+                logical_widths = [x[0] for x in shapes]
+                w, scale = requantize_with_max_scale(
+                    w, scale.to(weights.device), logical_widths
+                )
 
             input_scale = [
                 _load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
@@ -293,11 +294,11 @@ class HybridFP8UnquantLoader(WeightsLoader):
         w = weights.get_sharded(f"{prefix}.weight", dim=1)
         # FP8 branch
         if w.dtype == torch.float8_e4m3fn:
-            scale = (
-                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
-                .reshape(-1)
-                .expand(w.shape[0])
-            ).max()
+            scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+
+            if SYSTEM == "cuda":
+                scale = scale.reshape(-1).expand(w.shape[0])
+
             input_scale = None
             if weights.has_tensor(f"{prefix}.input_scale"):
                 input_scale = (
@@ -479,6 +480,9 @@ class Fp8Linear(torch.nn.Module):
 
 def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size):
     scale = weights.get_tensor(prefix, to_dtype=False)
+
     if scale.numel() > 1:
         scale = weights.get_sharded(prefix, dim=0, to_dtype=False)
+    elif SYSTEM == "rocm":
+        return scale.reshape(-1)
     return scale.reshape(-1).expand(shape[0])
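
The `requantize_with_max_scale` call used in the multi-prefix hunks runs only when every fused prefix contributed a single per-tensor scale (`scale.numel() == len(prefixes)`); it collapses those per-shard scales into one shared scale for the fused weight. The sketch below illustrates that idea, assuming one FP8 scale per logical shard and a PyTorch build with `torch.float8_e4m3fn`; the helper name is illustrative and this is not the implementation exported from `text_generation_server.layers.fp8`.

import torch


def requantize_with_max_scale_sketch(
    weight: torch.Tensor,        # fused FP8 weight, shape (sum(logical_widths), in_features)
    weight_scale: torch.Tensor,  # one float32 scale per logical shard
    logical_widths: list,        # output rows contributed by each fused projection
):
    # Illustrative sketch only: use the largest per-shard scale so the fused
    # weight can be handed to the FP8 GEMM with a single scale.
    max_scale = weight_scale.max()

    requantized = torch.empty_like(weight)
    start = 0
    for width, scale in zip(logical_widths, weight_scale):
        end = start + width
        # Dequantize the shard with its own scale, then re-encode it against
        # the shared maximum scale.
        dequant = weight[start:end].to(torch.float32) * scale
        requantized[start:end] = (dequant / max_scale).to(torch.float8_e4m3fn)
        start = end

    return requantized, max_scale.reshape(1)

With `logical_widths` taken from the per-prefix shapes, each shard keeps its original values while the whole fused tensor ends up encoded against one scale, which is why the per-row `.reshape(-1).expand(w.shape[0])` expansion is only needed on the CUDA path.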