Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-11 12:24:53 +00:00

fix scaling

parent 988c1dc622
commit bffccdd640
@@ -63,35 +63,40 @@ def normalize_e4m3fn_to_e4m3fnuz(
     weight_scale: torch.Tensor,
     input_scale: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
-    assert weight.dtype == torch.float8_e4m3fn
-    # The bits pattern 10000000(-128) represents zero in e4m3fn
-    # but NaN in e4m3fnuz. So here we set it to 0.
-    # https://onnx.ai/onnx/technical/float8.html
-    weight_as_int8 = weight.view(torch.int8)
-    ROCM_FP8_NAN_AS_INT = -128
-    weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0
-    weight = weight_as_int8.view(torch.float8_e4m3fnuz)
+    if weight.dtype == torch.float8_e4m3fn:
+        # The bits pattern 10000000(-128) represents zero in e4m3fn
+        # but NaN in e4m3fnuz. So here we set it to 0.
+        # https://onnx.ai/onnx/technical/float8.html
+        weight_as_int8 = weight.view(torch.int8)
+        ROCM_FP8_NAN_AS_INT = -128
+        weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0
+        weight = weight_as_int8.view(torch.float8_e4m3fnuz)
 
-    # For the same bits representation, e4m3fnuz value is half of
-    # the e4m3fn value, so we should double the scaling factor to
-    # get the same dequantized value.
-    # https://onnx.ai/onnx/technical/float8.html
-    weight_scale = weight_scale * 2.0
-    if input_scale is not None:
-        input_scale = input_scale * 2.0
+        # For the same bits representation, e4m3fnuz value is half of
+        # the e4m3fn value, so we should double the scaling factor to
+        # get the same dequantized value.
+        # https://onnx.ai/onnx/technical/float8.html
+        weight_scale = weight_scale * 2.0
+        if input_scale is not None:
+            input_scale = input_scale * 2.0
     return weight, weight_scale, input_scale
 
 
 def per_tensor_dequantize(
-    tensor: torch.Tensor, inv_scale: Union[float, torch.Tensor]
+    tensor: torch.Tensor,
+    inv_scale: Union[float, torch.Tensor],
+    dtype: torch.dtype = torch.float16,
 ) -> torch.Tensor:
-    fake_qweight = tensor.to(torch.float16)
+    fake_qweight = tensor.to(dtype)
     dq_weight = fake_qweight * inv_scale
     return dq_weight
 
 
 def requantize_with_max_scale(
-    weight: torch.Tensor, weight_scale: torch.Tensor, logical_widths: int
+    weight: torch.Tensor,
+    weight_scale: torch.Tensor,
+    logical_widths: int,
+    dtype: torch.dtype,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     # Max scale to be used for requanitzation.
     max_w_scale = weight_scale.max().float()
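Reviewer note on the comment block above: the same-bit-pattern relationship can be checked directly in PyTorch. A minimal sketch, assuming a PyTorch build that exposes both float8 dtypes (not part of this change):

    import torch

    # Encode 1.0 as float8_e4m3fn, then reinterpret the same bits as float8_e4m3fnuz.
    x_fn = torch.tensor([1.0]).to(torch.float8_e4m3fn)
    x_fnuz = x_fn.view(torch.float8_e4m3fnuz)

    print(x_fn.to(torch.float32))    # tensor([1.])
    print(x_fnuz.to(torch.float32))  # tensor([0.5000]) -- same bits, half the value

    # Doubling the scale therefore preserves the dequantized product.
    scale = torch.tensor(0.25)
    assert torch.allclose(
        x_fn.to(torch.float32) * scale,
        x_fnuz.to(torch.float32) * (scale * 2.0),
    )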
@@ -99,7 +104,9 @@ def requantize_with_max_scale(
     start = 0
     for idx, logical_width in enumerate(logical_widths):
         end = start + logical_width
-        weight_dq = per_tensor_dequantize(weight[start:end, :], weight_scale[idx])
+        weight_dq = per_tensor_dequantize(
+            weight[start:end, :], weight_scale[idx], dtype
+        )
         weight[start:end, :], max_w_scale_normalized = fp8_quantize(
             weight_dq, max_w_scale
         )
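For context on the loop above: each logical shard was quantized with its own scale, so it is first dequantized with that per-shard scale and then re-quantized with the shared max scale. A minimal self-contained sketch of the idea, using a plain per-tensor FP8 helper instead of TGI's fp8_quantize (illustrative only):

    import torch

    def quantize_per_tensor(t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        # Quantize to float8_e4m3fn given a dequantization scale (t ~= q * scale).
        finfo = torch.finfo(torch.float8_e4m3fn)
        return (t / scale).clamp(min=finfo.min, max=finfo.max).to(torch.float8_e4m3fn)

    def dequantize_per_tensor(q: torch.Tensor, scale: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
        return q.to(dtype) * scale

    # Two logical shards, each quantized with its own scale.
    shard_scales = torch.tensor([0.01, 0.03])
    shards = [quantize_per_tensor(torch.randn(4, 8), s) for s in shard_scales]

    # Re-quantize every shard with the single max scale so they can share one scale.
    max_scale = shard_scales.max()
    requantized = [
        quantize_per_tensor(dequantize_per_tensor(q, s, torch.float16), max_scale)
        for q, s in zip(shards, shard_scales)
    ]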
@@ -112,7 +119,7 @@ def fp8_quantize(
     weight: torch.Tensor,
     scale: Optional[torch.Tensor] = None,
     scale_upper_bound: Optional[torch.Tensor] = None,
-    qdtype: torch.dtype = quant_dtype,
+    qdtype: torch.dtype = torch.float8_e4m3fn,
     scalar: bool = False,
 ):
     """
@@ -125,7 +132,7 @@ def fp8_quantize(
         shape = weight.shape
         qweight, scale = marlin_kernels.scaled_fp8_quant(
             weight.reshape(-1, shape[-1]),
-            dtype=qdtype,
+            dtype=quant_dtype,
             scale=scale,
             scale_ub=scale_upper_bound,
             # TODO: don't do this when we have to use the Torch kernel.
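Note for readers outside this file: quant_dtype is not shown in this diff. In this module it is presumably the platform-native FP8 dtype, along the lines of the sketch below (an assumption for illustration, not part of this commit), which would explain why the kernel call is pinned to dtype=quant_dtype while the signature default above becomes torch.float8_e4m3fn:

    import torch
    from text_generation_server.utils.import_utils import SYSTEM  # assumed import path

    # Assumed definition: ROCm uses the fnuz FP8 variant natively, CUDA uses e4m3fn.
    quant_dtype = torch.float8_e4m3fnuz if SYSTEM == "rocm" else torch.float8_e4m3fn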
@@ -145,6 +152,8 @@ def fp8_quantize(
         qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max)
         scale = scale.float().reciprocal()
     else:
+        if SYSTEM == "rocm":
+            scale = scale / 2.0
         # Use reciprocal to avoid more expensive division.
         qweight = (weight * scale.reciprocal()).clamp(min=finfo.min, max=finfo.max)
 
@@ -263,12 +272,6 @@ class HybridFP8UnquantLoader(WeightsLoader):
             ]
             scale = torch.cat(scale, dim=0).reshape(-1)
 
-            if scale.numel() == len(prefixes):
-                logical_widths = [x[0] for x in shapes]
-                w, scale = requantize_with_max_scale(
-                    w, scale.to(weights.device), logical_widths
-                )
-
             input_scale = [
                 _load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
                 for p, shape in zip(prefixes, shapes)
@@ -281,6 +284,17 @@ class HybridFP8UnquantLoader(WeightsLoader):
                 else None
             )
 
+            if SYSTEM == "rocm":
+                w, scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
+                    w, scale, input_scale
+                )
+
+            if scale.numel() == len(prefixes):
+                logical_widths = [x[0] for x in shapes]
+                w, scale = requantize_with_max_scale(
+                    w, scale.to(weights.device), logical_widths, weights.dtype
+                )
+
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
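The dtype threaded through requantize_with_max_scale above ends up in per_tensor_dequantize, so the intermediate dequantized weights use the model dtype (weights.dtype) instead of always float16. A small standalone usage sketch of the updated helper, restated from this diff with an illustrative bfloat16 call:

    from typing import Union

    import torch

    def per_tensor_dequantize(
        tensor: torch.Tensor,
        inv_scale: Union[float, torch.Tensor],
        dtype: torch.dtype = torch.float16,
    ) -> torch.Tensor:
        fake_qweight = tensor.to(dtype)
        dq_weight = fake_qweight * inv_scale
        return dq_weight

    q = torch.randn(4, 8).to(torch.float8_e4m3fn)
    # For a bfloat16 model, weights.dtype would be torch.bfloat16.
    w = per_tensor_dequantize(q, inv_scale=0.02, dtype=torch.bfloat16)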
@@ -366,7 +380,7 @@ class Fp8Linear(torch.nn.Module):
         if CUTLASS_FP8_AVAILABLE:
             log_once(logger.info, "Using cutlass w8a8 kernels")
         if SYSTEM == "rocm" and qweight.dtype == torch.float8_e4m3fn:
-            qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+            qweight, scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
                 weight=qweight, weight_scale=scale
             )
 