Always use dynamic input quantization for w8a8 int

Dynamic input quantization is far more robust than static input quantization, which causes large accuracy regressions, and it gives better output.
Daniël de Kok 2024-11-18 10:54:51 +00:00
parent b2dc10aea5
commit f76c0ff17f

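Background for the change: with dynamic input quantization, the int8 scale is computed from each activation batch at run time instead of being loaded from a calibration checkpoint. A minimal pure-PyTorch sketch of symmetric per-token dynamic quantization, the scheme the fused marlin_kernels.scaled_int8_quant(..., scale=None) call below stands in for (absmax scaling is an assumption about the kernel's exact recipe):

import torch

def dynamic_int8_quant(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Derive the scale from the current batch instead of a checkpoint value:
    # per-token absolute maximum mapped onto the int8 range.
    absmax = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-10)
    scale = absmax / 127.0
    q = torch.round(x / scale).clamp(-128, 127).to(torch.int8)
    return q, scale.to(torch.float32)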

@@ -37,22 +37,14 @@ class W8A8IntLoader(WeightsLoader):
         self.load_weight_scale = not weight_args.dynamic
 
         if input_args is not None:
-            static = not input_args.dynamic
-            symmetric = input_args.symmetric
-            self.load_input_scale = static
-            self.load_input_zero_point = static and not symmetric
             self.input_symmetric = input_args.symmetric
 
-            if static:
-                # People shouldn't really use static input quantization,
-                # the output is pretty bad.
+            if not input_args.dynamic:
                 log_once(
                     logger.warning,
-                    "Using W8A8 int with static input quantization results in large regressions in accuracy. Consider dynamic input quantization instead.",
+                    "Forcing dynamic input quantization for compressed_tensors w8a8 int checkpoint (for better accuracy).",
                 )
         else:
-            self.load_input_scale = False
-            self.load_input_zero_point = False
             self.input_symmetric = True
 
     def __str__(self) -> str:
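For illustration: a compressed-tensors checkpoint selects static input quantization through its quantization config, and with this change such a config only triggers the warning above while its calibrated input scales and zero points are ignored. A hypothetical excerpt (field names follow the compressed-tensors schema; the exact checkpoint layout may differ):

# Hypothetical input_activations section of a compressed-tensors config.
input_activations = {
    "num_bits": 8,
    "type": "int",
    "symmetric": True,
    "dynamic": False,  # static -> warning is logged, dynamic is forced
}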
@@ -62,7 +54,7 @@ class W8A8IntLoader(WeightsLoader):
         def symmetric_to_sting(symmetric):
             return "symmetric" if symmetric else "asymmetric"
 
-        return f"{self.__class__.__name__} (w8a8 int, input: {symmetric_to_sting(self.input_symmetric)}, {scale_to_str(self.load_input_scale)})"
+        return f"{self.__class__.__name__} (w8a8 int, input: dynamic/{symmetric_to_sting(self.input_symmetric)}, weight: {scale_to_str(self.load_weight_scale)}/symmetric)"
 
     def get_weights(self, weights: "Weights", prefix: str):
         w = weights.get_tensor(f"{prefix}.weight", to_dtype=False)
@@ -73,23 +65,7 @@ class W8A8IntLoader(WeightsLoader):
                 f"{prefix}.weight_scale", to_dtype=False
             ).reshape(-1)
 
-        input_scale = None
-        if self.load_input_scale:
-            input_scale = weights.get_tensor(
-                f"{prefix}.input_scale", to_dtype=False
-            ).reshape(-1)
-
-        input_zero_point = None
-        if self.load_input_zero_point:
-            input_zero_point = _get_tensor_or_else(
-                weights,
-                f"{prefix}.input_zero_point",
-                torch.zeros((1,), device=w.device, dtype=torch.int8),
-            ).reshape(-1)
-
         return Int8Weight(
-            input_scale=input_scale,
-            input_zero_point=input_zero_point,
             input_symmetric=self.input_symmetric,
             weight=w,
             weight_scale=weight_scale,
@@ -117,23 +93,7 @@ class W8A8IntLoader(WeightsLoader):
             )
         weight_scale = weight_scale.reshape(-1)
 
-        input_scale = None
-        if self.load_input_scale:
-            input_scale = weights.get_tensor(
-                f"{prefix}.input_scale", to_dtype=False
-            ).reshape(-1)
-
-        input_zero_point = None
-        if self.load_input_zero_point:
-            input_zero_point = _get_tensor_or_else(
-                weights,
-                f"{prefix}.input_zero_point",
-                torch.zeros((1,), device=w.device, dtype=torch.int8),
-            ).reshape(-1)
-
         return Int8Weight(
-            input_scale=input_scale,
-            input_zero_point=input_zero_point,
             input_symmetric=self.input_symmetric,
             weight=w,
             weight_scale=weight_scale,
@@ -155,28 +115,7 @@ class W8A8IntLoader(WeightsLoader):
         ]
         weight_scale = torch.cat(weight_scale, dim=0).reshape(-1, 1)
 
-        input_scale = None
-        if self.load_input_scale:
-            input_scale = [
-                weights.get_tensor(f"{p}.input_scale", to_dtype=False) for p in prefixes
-            ]
-            input_scale = torch.cat(input_scale, dim=0)
-
-        input_zero_point = None
-        if self.load_input_zero_point:
-            input_zero_point = [
-                _get_tensor_or_else(
-                    weights,
-                    f"{prefix}.input_zero_point",
-                    torch.zeros((1,), device=w.device, dtype=torch.int8),
-                )
-                for prefix in prefixes
-            ]
-            input_zero_point = torch.cat(input_zero_point, dim=0)
-
         return Int8Weight(
-            input_scale=input_scale,
-            input_zero_point=input_zero_point,
             input_symmetric=self.input_symmetric,
             weight=w,
             weight_scale=weight_scale,
@@ -191,23 +130,7 @@ class W8A8IntLoader(WeightsLoader):
                 f"{prefix}.weight_scale", to_dtype=False
             ).reshape(-1)
 
-        input_scale = None
-        if self.load_input_scale:
-            input_scale = weights.get_tensor(
-                f"{prefix}.input_scale", to_dtype=False
-            ).reshape(-1)
-
-        input_zero_point = None
-        if self.load_input_zero_point:
-            input_zero_point = _get_tensor_or_else(
-                weights,
-                f"{prefix}.input_zero_point",
-                torch.zeros((1,), device=w.device, dtype=torch.int8),
-            ).reshape(-1)
-
         return Int8Weight(
-            input_scale=input_scale,
-            input_zero_point=input_zero_point,
             input_symmetric=self.input_symmetric,
             weight=w,
             weight_scale=weight_scale,
@@ -230,8 +153,6 @@ def _get_tensor_or_else(
 
 @dataclass
 class Int8Weight(Weight):
-    input_scale: Optional[torch.Tensor]
-    input_zero_point: Optional[torch.Tensor]
     input_symmetric: bool
     weight: torch.Tensor
     weight_scale: Optional[torch.Tensor]
@@ -242,8 +163,6 @@ class Int8Weight(Weight):
             qweight, weight_scale, _ = marlin_kernels.scaled_int8_quant(self.weight)
             return W8A8IntLinear(
                 bias=bias,
-                input_scale=self.input_scale,
-                input_zero_point=self.input_zero_point,
                 input_symmetric=self.input_symmetric,
                 weight=qweight,
                 weight_scale=weight_scale,
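Note on the branch above: when the checkpoint carries unquantized weights, they are quantized once at load time by scaled_int8_quant. A rough pure-PyTorch equivalent, assuming symmetric absmax quantization per output row (the real kernel's granularity may differ):

import torch

def quantize_weight_int8(w: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Symmetric absmax quantization per output row; the fused kernel call
    # above plays this role at load time.
    scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-10) / 127.0
    q = torch.round(w / scale).clamp(-128, 127).to(torch.int8)
    return q, scale.to(torch.float32).reshape(-1)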
@@ -251,8 +170,6 @@ class Int8Weight(Weight):
         else:
             return W8A8IntLinear(
                 bias=bias,
-                input_scale=self.input_scale,
-                input_zero_point=self.input_zero_point,
                 input_symmetric=self.input_symmetric,
                 weight=self.weight,
                 weight_scale=self.weight_scale,
@@ -264,17 +181,12 @@ class W8A8IntLinear(torch.nn.Module):
         self,
         *,
         bias: Optional[torch.Tensor],
-        input_scale: Optional[torch.Tensor],
-        input_zero_point: Optional[torch.Tensor],
         input_symmetric: bool,
         weight: torch.Tensor,
         weight_scale: torch.Tensor,
     ):
         super().__init__()
 
-        input_scale = (
-            input_scale.to(torch.float32) if input_scale is not None else input_scale
-        )
         weight_scale = weight_scale.to(torch.float32)
 
         self.bias = bias
@@ -283,35 +195,6 @@ class W8A8IntLinear(torch.nn.Module):
         self.weight = weight.t()
         self.weight_scale = weight_scale
 
-        if input_scale is not None:
-            if input_zero_point is None:
-                # Symmetric: simply use the largest scale to cover fused layers.
-                input_scale = input_scale.max()
-            else:
-                # Asymmetric: find the range that contains all individual ranges.
-                input_zero_point = input_zero_point.to(torch.int32)
-                int8_info = torch.iinfo(torch.int8)
-
-                # Find the most extreme values of all zero point/input scale
-                # pairs.
-                range_min = (input_scale * (int8_info.min - input_zero_point)).min()
-                range_max = (input_scale * (int8_info.max - input_zero_point)).max()
-
-                # Calculate new scale and zero point.
-                input_scale = (range_max - range_min) / (int8_info.max - int8_info.min)
-                input_zero_point = int8_info.min - (range_min / input_scale)
-                input_zero_point = input_zero_point.to(torch.int32)
-
-                self.range_min = (
-                    input_scale * (int8_info.min - input_zero_point)
-                ).min()
-                self.range_max = (
-                    input_scale * (int8_info.max - input_zero_point)
-                ).max()
-
-        self.input_scale = input_scale
-        self.input_zero_point = input_zero_point
-
         if input_symmetric:
             self.zero_point_adj = None
         else:
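For reference, the asymmetric branch removed above fused the per-layer (scale, zero point) pairs of merged layers into one int8 range. The same math, restated compactly:

import torch

def fuse_asymmetric_ranges(scales: torch.Tensor, zero_points: torch.Tensor):
    # Cover all per-layer real ranges [s_i*(qmin - z_i), s_i*(qmax - z_i)]
    # with a single (scale, zero_point) pair, as the deleted code did.
    info = torch.iinfo(torch.int8)
    zp = zero_points.to(torch.int32)
    range_min = (scales * (info.min - zp)).min()
    range_max = (scales * (info.max - zp)).max()
    scale = (range_max - range_min) / (info.max - info.min)
    zero_point = (info.min - range_min / scale).to(torch.int32)
    return scale, zero_point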
@@ -320,16 +203,13 @@ class W8A8IntLinear(torch.nn.Module):
             self.zero_point_adj = self.weight.sum(
                 dim=0, keepdim=True, dtype=torch.int32
             )
-            if input_zero_point is not None:
-                self.zero_point_adj *= input_zero_point
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         assert marlin_kernels is not None
 
         qinput, input_scale, input_zero_point = marlin_kernels.scaled_int8_quant(
             input=input,
-            scale=self.input_scale,
-            azp=self.input_zero_point,
+            scale=None,
+            azp=None,
             symmetric=self.input_symmetric,
         )
@@ -343,7 +223,11 @@ class W8A8IntLinear(torch.nn.Module):
                 bias=self.bias,
             )
         else:
-            assert self.zero_point_adj is not None and input_scale is not None
+            assert (
+                self.zero_point_adj is not None
+                and input_scale is not None
+                and (self.input_symmetric or input_zero_point is not None)
+            )
 
             return marlin_kernels.cutlass_scaled_mm_azp(
                 a=qinput,
@@ -352,8 +236,6 @@ class W8A8IntLinear(torch.nn.Module):
                 scale_b=self.weight_scale,
                 out_dtype=input.dtype,
                 azp_adj=self.zero_point_adj,
-                # Zero point is already in the adjustment when using static
-                # input quantization.
-                azp=input_zero_point if self.input_zero_point is None else None,
+                azp=input_zero_point,
                 bias=self.bias,
             )
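Why column sums suffice for the asymmetric case: with a per-token zero point azp, (q_a - azp) @ q_w == q_a @ q_w - azp * q_w.sum(dim=0), so cutlass_scaled_mm_azp only needs the precomputed zero_point_adj plus the per-token azp returned by the dynamic quantization. A small numeric check of the identity (shapes chosen arbitrarily):

import torch

# Identity behind zero_point_adj:
#   (q_a - azp) @ q_w == q_a @ q_w - azp * q_w.sum(dim=0)
q_a = torch.randint(-128, 128, (4, 8))   # quantized activations (int)
azp = torch.randint(-8, 8, (4, 1))       # per-token zero points
q_w = torch.randint(-128, 128, (8, 3))   # quantized weight (int)

direct = (q_a - azp) @ q_w
adjusted = q_a @ q_w - azp * q_w.sum(dim=0, keepdim=True)
assert torch.equal(direct, adjusted)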