mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
fix scaling
This commit is contained in:
parent
988c1dc622
commit
bffccdd640
@ -63,35 +63,40 @@ def normalize_e4m3fn_to_e4m3fnuz(
|
||||
weight_scale: torch.Tensor,
|
||||
input_scale: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
||||
assert weight.dtype == torch.float8_e4m3fn
|
||||
# The bits pattern 10000000(-128) represents zero in e4m3fn
|
||||
# but NaN in e4m3fnuz. So here we set it to 0.
|
||||
# https://onnx.ai/onnx/technical/float8.html
|
||||
weight_as_int8 = weight.view(torch.int8)
|
||||
ROCM_FP8_NAN_AS_INT = -128
|
||||
weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0
|
||||
weight = weight_as_int8.view(torch.float8_e4m3fnuz)
|
||||
if weight.dtype == torch.float8_e4m3fn:
|
||||
# The bits pattern 10000000(-128) represents zero in e4m3fn
|
||||
# but NaN in e4m3fnuz. So here we set it to 0.
|
||||
# https://onnx.ai/onnx/technical/float8.html
|
||||
weight_as_int8 = weight.view(torch.int8)
|
||||
ROCM_FP8_NAN_AS_INT = -128
|
||||
weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0
|
||||
weight = weight_as_int8.view(torch.float8_e4m3fnuz)
|
||||
|
||||
# For the same bits representation, e4m3fnuz value is half of
|
||||
# the e4m3fn value, so we should double the scaling factor to
|
||||
# get the same dequantized value.
|
||||
# https://onnx.ai/onnx/technical/float8.html
|
||||
weight_scale = weight_scale * 2.0
|
||||
if input_scale is not None:
|
||||
input_scale = input_scale * 2.0
|
||||
# For the same bits representation, e4m3fnuz value is half of
|
||||
# the e4m3fn value, so we should double the scaling factor to
|
||||
# get the same dequantized value.
|
||||
# https://onnx.ai/onnx/technical/float8.html
|
||||
weight_scale = weight_scale * 2.0
|
||||
if input_scale is not None:
|
||||
input_scale = input_scale * 2.0
|
||||
return weight, weight_scale, input_scale
|
||||
|
||||
|
||||
def per_tensor_dequantize(
|
||||
tensor: torch.Tensor, inv_scale: Union[float, torch.Tensor]
|
||||
tensor: torch.Tensor,
|
||||
inv_scale: Union[float, torch.Tensor],
|
||||
dtype: torch.dtype = torch.float16,
|
||||
) -> torch.Tensor:
|
||||
fake_qweight = tensor.to(torch.float16)
|
||||
fake_qweight = tensor.to(dtype)
|
||||
dq_weight = fake_qweight * inv_scale
|
||||
return dq_weight
|
||||
|
||||
|
||||
def requantize_with_max_scale(
|
||||
weight: torch.Tensor, weight_scale: torch.Tensor, logical_widths: int
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
logical_widths: int,
|
||||
dtype: torch.dtype,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Max scale to be used for requanitzation.
|
||||
max_w_scale = weight_scale.max().float()
|
||||
@ -99,7 +104,9 @@ def requantize_with_max_scale(
|
||||
start = 0
|
||||
for idx, logical_width in enumerate(logical_widths):
|
||||
end = start + logical_width
|
||||
weight_dq = per_tensor_dequantize(weight[start:end, :], weight_scale[idx])
|
||||
weight_dq = per_tensor_dequantize(
|
||||
weight[start:end, :], weight_scale[idx], dtype
|
||||
)
|
||||
weight[start:end, :], max_w_scale_normalized = fp8_quantize(
|
||||
weight_dq, max_w_scale
|
||||
)
|
||||
@ -112,7 +119,7 @@ def fp8_quantize(
|
||||
weight: torch.Tensor,
|
||||
scale: Optional[torch.Tensor] = None,
|
||||
scale_upper_bound: Optional[torch.Tensor] = None,
|
||||
qdtype: torch.dtype = quant_dtype,
|
||||
qdtype: torch.dtype = torch.float8_e4m3fn,
|
||||
scalar: bool = False,
|
||||
):
|
||||
"""
|
||||
@ -125,7 +132,7 @@ def fp8_quantize(
|
||||
shape = weight.shape
|
||||
qweight, scale = marlin_kernels.scaled_fp8_quant(
|
||||
weight.reshape(-1, shape[-1]),
|
||||
dtype=qdtype,
|
||||
dtype=quant_dtype,
|
||||
scale=scale,
|
||||
scale_ub=scale_upper_bound,
|
||||
# TODO: don't do this when we have to use the Torch kernel.
|
||||
@ -145,6 +152,8 @@ def fp8_quantize(
|
||||
qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max)
|
||||
scale = scale.float().reciprocal()
|
||||
else:
|
||||
if SYSTEM == "rocm":
|
||||
scale = scale / 2.0
|
||||
# Use reciprocal to avoid more expensive division.
|
||||
qweight = (weight * scale.reciprocal()).clamp(min=finfo.min, max=finfo.max)
|
||||
|
||||
@ -263,12 +272,6 @@ class HybridFP8UnquantLoader(WeightsLoader):
|
||||
]
|
||||
scale = torch.cat(scale, dim=0).reshape(-1)
|
||||
|
||||
if scale.numel() == len(prefixes):
|
||||
logical_widths = [x[0] for x in shapes]
|
||||
w, scale = requantize_with_max_scale(
|
||||
w, scale.to(weights.device), logical_widths
|
||||
)
|
||||
|
||||
input_scale = [
|
||||
_load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
|
||||
for p, shape in zip(prefixes, shapes)
|
||||
@ -281,6 +284,17 @@ class HybridFP8UnquantLoader(WeightsLoader):
|
||||
else None
|
||||
)
|
||||
|
||||
if SYSTEM == "rocm":
|
||||
w, scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
|
||||
w, scale, input_scale
|
||||
)
|
||||
|
||||
if scale.numel() == len(prefixes):
|
||||
logical_widths = [x[0] for x in shapes]
|
||||
w, scale = requantize_with_max_scale(
|
||||
w, scale.to(weights.device), logical_widths, weights.dtype
|
||||
)
|
||||
|
||||
return Fp8Weight(
|
||||
weight=w,
|
||||
weight_scale=scale,
|
||||
@ -366,7 +380,7 @@ class Fp8Linear(torch.nn.Module):
|
||||
if CUTLASS_FP8_AVAILABLE:
|
||||
log_once(logger.info, "Using cutlass w8a8 kernels")
|
||||
if SYSTEM == "rocm" and qweight.dtype == torch.float8_e4m3fn:
|
||||
qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(
|
||||
qweight, scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
|
||||
weight=qweight, weight_scale=scale
|
||||
)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user