diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py
index d7fb64ba..4e83ec9d 100644
--- a/server/text_generation_server/layers/fp8.py
+++ b/server/text_generation_server/layers/fp8.py
@@ -63,35 +63,40 @@ def normalize_e4m3fn_to_e4m3fnuz(
     weight_scale: torch.Tensor,
     input_scale: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
-    assert weight.dtype == torch.float8_e4m3fn
-    # The bits pattern 10000000(-128) represents zero in e4m3fn
-    # but NaN in e4m3fnuz. So here we set it to 0.
-    # https://onnx.ai/onnx/technical/float8.html
-    weight_as_int8 = weight.view(torch.int8)
-    ROCM_FP8_NAN_AS_INT = -128
-    weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0
-    weight = weight_as_int8.view(torch.float8_e4m3fnuz)
+    if weight.dtype == torch.float8_e4m3fn:
+        # The bits pattern 10000000(-128) represents zero in e4m3fn
+        # but NaN in e4m3fnuz. So here we set it to 0.
+        # https://onnx.ai/onnx/technical/float8.html
+        weight_as_int8 = weight.view(torch.int8)
+        ROCM_FP8_NAN_AS_INT = -128
+        weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0
+        weight = weight_as_int8.view(torch.float8_e4m3fnuz)
 
-    # For the same bits representation, e4m3fnuz value is half of
-    # the e4m3fn value, so we should double the scaling factor to
-    # get the same dequantized value.
-    # https://onnx.ai/onnx/technical/float8.html
-    weight_scale = weight_scale * 2.0
-    if input_scale is not None:
-        input_scale = input_scale * 2.0
+        # For the same bits representation, e4m3fnuz value is half of
+        # the e4m3fn value, so we should double the scaling factor to
+        # get the same dequantized value.
+        # https://onnx.ai/onnx/technical/float8.html
+        weight_scale = weight_scale * 2.0
+        if input_scale is not None:
+            input_scale = input_scale * 2.0
     return weight, weight_scale, input_scale
 
 
 def per_tensor_dequantize(
-    tensor: torch.Tensor, inv_scale: Union[float, torch.Tensor]
+    tensor: torch.Tensor,
+    inv_scale: Union[float, torch.Tensor],
+    dtype: torch.dtype = torch.float16,
 ) -> torch.Tensor:
-    fake_qweight = tensor.to(torch.float16)
+    fake_qweight = tensor.to(dtype)
     dq_weight = fake_qweight * inv_scale
     return dq_weight
 
 
 def requantize_with_max_scale(
-    weight: torch.Tensor, weight_scale: torch.Tensor, logical_widths: int
+    weight: torch.Tensor,
+    weight_scale: torch.Tensor,
+    logical_widths: int,
+    dtype: torch.dtype,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     # Max scale to be used for requanitzation.
     max_w_scale = weight_scale.max().float()
@@ -99,7 +104,9 @@ def requantize_with_max_scale(
     start = 0
     for idx, logical_width in enumerate(logical_widths):
         end = start + logical_width
-        weight_dq = per_tensor_dequantize(weight[start:end, :], weight_scale[idx])
+        weight_dq = per_tensor_dequantize(
+            weight[start:end, :], weight_scale[idx], dtype
+        )
         weight[start:end, :], max_w_scale_normalized = fp8_quantize(
             weight_dq, max_w_scale
         )
@@ -112,7 +119,7 @@ def fp8_quantize(
     weight: torch.Tensor,
     scale: Optional[torch.Tensor] = None,
     scale_upper_bound: Optional[torch.Tensor] = None,
-    qdtype: torch.dtype = quant_dtype,
+    qdtype: torch.dtype = torch.float8_e4m3fn,
     scalar: bool = False,
 ):
     """
@@ -125,7 +132,7 @@ def fp8_quantize(
         shape = weight.shape
         qweight, scale = marlin_kernels.scaled_fp8_quant(
             weight.reshape(-1, shape[-1]),
-            dtype=qdtype,
+            dtype=quant_dtype,
             scale=scale,
             scale_ub=scale_upper_bound,
             # TODO: don't do this when we have to use the Torch kernel.
@@ -145,6 +152,8 @@ def fp8_quantize(
         qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max)
         scale = scale.float().reciprocal()
     else:
+        if SYSTEM == "rocm":
+            scale = scale / 2.0
         # Use reciprocal to avoid more expensive division.
         qweight = (weight * scale.reciprocal()).clamp(min=finfo.min, max=finfo.max)
 
@@ -263,12 +272,6 @@ class HybridFP8UnquantLoader(WeightsLoader):
             ]
             scale = torch.cat(scale, dim=0).reshape(-1)
 
-            if scale.numel() == len(prefixes):
-                logical_widths = [x[0] for x in shapes]
-                w, scale = requantize_with_max_scale(
-                    w, scale.to(weights.device), logical_widths
-                )
-
             input_scale = [
                 _load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
                 for p, shape in zip(prefixes, shapes)
@@ -281,6 +284,17 @@ class HybridFP8UnquantLoader(WeightsLoader):
                 else None
             )
 
+            if SYSTEM == "rocm":
+                w, scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
+                    w, scale, input_scale
+                )
+
+            if scale.numel() == len(prefixes):
+                logical_widths = [x[0] for x in shapes]
+                w, scale = requantize_with_max_scale(
+                    w, scale.to(weights.device), logical_widths, weights.dtype
+                )
+
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
@@ -366,7 +380,7 @@ class Fp8Linear(torch.nn.Module):
         if CUTLASS_FP8_AVAILABLE:
             log_once(logger.info, "Using cutlass w8a8 kernels")
         if SYSTEM == "rocm" and qweight.dtype == torch.float8_e4m3fn:
-            qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+            qweight, scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
                 weight=qweight, weight_scale=scale
             )
 
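Background note (not part of the patch): the scale doubling performed by normalize_e4m3fn_to_e4m3fnuz works because e4m3fn and e4m3fnuz use different exponent biases (7 vs. 8), so the same 8-bit pattern decodes to half the value when reinterpreted as e4m3fnuz. A minimal sketch of that relationship, assuming a PyTorch build that exposes both float8 dtypes and supports casting them to float32:

import torch

# 0x38 encodes 1.0 in e4m3fn (bias 7); the same bits decode to 0.5 in
# e4m3fnuz (bias 8), i.e. half the value.
bits = torch.tensor([0x38], dtype=torch.uint8)
as_fn = bits.view(torch.float8_e4m3fn).to(torch.float32)      # tensor([1.0])
as_fnuz = bits.view(torch.float8_e4m3fnuz).to(torch.float32)  # tensor([0.5])

# Doubling the dequantization scale compensates for the halved value,
# so the dequantized results match after the reinterpretation.
scale = torch.tensor(0.25)
assert torch.allclose(as_fn * scale, as_fnuz * (scale * 2.0))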