mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
Always use dynamic input quantization for w8a8 int
It's far less flaky and gives better output.
This commit is contained in:
parent
b2dc10aea5
commit
f76c0ff17f
@ -37,22 +37,14 @@ class W8A8IntLoader(WeightsLoader):
|
||||
self.load_weight_scale = not weight_args.dynamic
|
||||
|
||||
if input_args is not None:
|
||||
static = not input_args.dynamic
|
||||
symmetric = input_args.symmetric
|
||||
self.load_input_scale = static
|
||||
self.load_input_zero_point = static and not symmetric
|
||||
self.input_symmetric = input_args.symmetric
|
||||
|
||||
if static:
|
||||
# People shouldn't really use static input quantization,
|
||||
# the output is pretty bad.
|
||||
if not input_args.dynamic:
|
||||
log_once(
|
||||
logger.warning,
|
||||
"Using W8A8 int with static input quantization results in large regressions in accuracy. Consider dynamic input quantization instead.",
|
||||
"Forcing dynamic input quantization for compressed_tensors w8a8 int checkpoint (for better accuracy).",
|
||||
)
|
||||
else:
|
||||
self.load_input_scale = False
|
||||
self.load_input_zero_point = False
|
||||
self.input_symmetric = True
|
||||
|
||||
def __str__(self) -> str:
|
||||
@ -62,7 +54,7 @@ class W8A8IntLoader(WeightsLoader):
|
||||
def symmetric_to_sting(symmetric):
|
||||
return "symmetric" if symmetric else "asymmetric"
|
||||
|
||||
return f"{self.__class__.__name__} (w8a8 int, input: {symmetric_to_sting(self.input_symmetric)}, {scale_to_str(self.load_input_scale)})"
|
||||
return f"{self.__class__.__name__} (w8a8 int, input: dynamic/{symmetric_to_sting(self.input_symmetric)}, weight: {scale_to_str(self.load_weight_scale)}/symmetric))"
|
||||
|
||||
def get_weights(self, weights: "Weights", prefix: str):
|
||||
w = weights.get_tensor(f"{prefix}.weight", to_dtype=False)
|
||||
@ -73,23 +65,7 @@ class W8A8IntLoader(WeightsLoader):
|
||||
f"{prefix}.weight_scale", to_dtype=False
|
||||
).reshape(-1)
|
||||
|
||||
input_scale = None
|
||||
if self.load_input_scale:
|
||||
input_scale = weights.get_tensor(
|
||||
f"{prefix}.input_scale", to_dtype=False
|
||||
).reshape(-1)
|
||||
|
||||
input_zero_point = None
|
||||
if self.load_input_zero_point:
|
||||
input_zero_point = _get_tensor_or_else(
|
||||
weights,
|
||||
f"{prefix}.input_zero_point",
|
||||
torch.zeros((1,), device=w.device, dtype=torch.int8),
|
||||
).reshape(-1)
|
||||
|
||||
return Int8Weight(
|
||||
input_scale=input_scale,
|
||||
input_zero_point=input_zero_point,
|
||||
input_symmetric=self.input_symmetric,
|
||||
weight=w,
|
||||
weight_scale=weight_scale,
|
||||
@ -117,23 +93,7 @@ class W8A8IntLoader(WeightsLoader):
|
||||
)
|
||||
weight_scale = weight_scale.reshape(-1)
|
||||
|
||||
input_scale = None
|
||||
if self.load_input_scale:
|
||||
input_scale = weights.get_tensor(
|
||||
f"{prefix}.input_scale", to_dtype=False
|
||||
).reshape(-1)
|
||||
|
||||
input_zero_point = None
|
||||
if self.load_input_zero_point:
|
||||
input_zero_point = _get_tensor_or_else(
|
||||
weights,
|
||||
f"{prefix}.input_zero_point",
|
||||
torch.zeros((1,), device=w.device, dtype=torch.int8),
|
||||
).reshape(-1)
|
||||
|
||||
return Int8Weight(
|
||||
input_scale=input_scale,
|
||||
input_zero_point=input_zero_point,
|
||||
input_symmetric=self.input_symmetric,
|
||||
weight=w,
|
||||
weight_scale=weight_scale,
|
||||
@ -155,28 +115,7 @@ class W8A8IntLoader(WeightsLoader):
|
||||
]
|
||||
weight_scale = torch.cat(weight_scale, dim=0).reshape(-1, 1)
|
||||
|
||||
input_scale = None
|
||||
if self.load_input_scale:
|
||||
input_scale = [
|
||||
weights.get_tensor(f"{p}.input_scale", to_dtype=False) for p in prefixes
|
||||
]
|
||||
input_scale = torch.cat(input_scale, dim=0)
|
||||
|
||||
input_zero_point = None
|
||||
if self.load_input_zero_point:
|
||||
input_zero_point = [
|
||||
_get_tensor_or_else(
|
||||
weights,
|
||||
f"{prefix}.input_zero_point",
|
||||
torch.zeros((1,), device=w.device, dtype=torch.int8),
|
||||
)
|
||||
for prefix in prefixes
|
||||
]
|
||||
input_zero_point = torch.cat(input_zero_point, dim=0)
|
||||
|
||||
return Int8Weight(
|
||||
input_scale=input_scale,
|
||||
input_zero_point=input_zero_point,
|
||||
input_symmetric=self.input_symmetric,
|
||||
weight=w,
|
||||
weight_scale=weight_scale,
|
||||
@ -191,23 +130,7 @@ class W8A8IntLoader(WeightsLoader):
|
||||
f"{prefix}.weight_scale", to_dtype=False
|
||||
).reshape(-1)
|
||||
|
||||
input_scale = None
|
||||
if self.load_input_scale:
|
||||
input_scale = weights.get_tensor(
|
||||
f"{prefix}.input_scale", to_dtype=False
|
||||
).reshape(-1)
|
||||
|
||||
input_zero_point = None
|
||||
if self.load_input_zero_point:
|
||||
input_zero_point = _get_tensor_or_else(
|
||||
weights,
|
||||
f"{prefix}.input_zero_point",
|
||||
torch.zeros((1,), device=w.device, dtype=torch.int8),
|
||||
).reshape(-1)
|
||||
|
||||
return Int8Weight(
|
||||
input_scale=input_scale,
|
||||
input_zero_point=input_zero_point,
|
||||
input_symmetric=self.input_symmetric,
|
||||
weight=w,
|
||||
weight_scale=weight_scale,
|
||||
@ -230,8 +153,6 @@ def _get_tensor_or_else(
|
||||
|
||||
@dataclass
|
||||
class Int8Weight(Weight):
|
||||
input_scale: Optional[torch.Tensor]
|
||||
input_zero_point: Optional[torch.Tensor]
|
||||
input_symmetric: bool
|
||||
weight: torch.Tensor
|
||||
weight_scale: Optional[torch.Tensor]
|
||||
@ -242,8 +163,6 @@ class Int8Weight(Weight):
|
||||
qweight, weight_scale, _ = marlin_kernels.scaled_int8_quant(self.weight)
|
||||
return W8A8IntLinear(
|
||||
bias=bias,
|
||||
input_scale=self.input_scale,
|
||||
input_zero_point=self.input_zero_point,
|
||||
input_symmetric=self.input_symmetric,
|
||||
weight=qweight,
|
||||
weight_scale=weight_scale,
|
||||
@ -251,8 +170,6 @@ class Int8Weight(Weight):
|
||||
else:
|
||||
return W8A8IntLinear(
|
||||
bias=bias,
|
||||
input_scale=self.input_scale,
|
||||
input_zero_point=self.input_zero_point,
|
||||
input_symmetric=self.input_symmetric,
|
||||
weight=self.weight,
|
||||
weight_scale=self.weight_scale,
|
||||
@ -264,17 +181,12 @@ class W8A8IntLinear(torch.nn.Module):
|
||||
self,
|
||||
*,
|
||||
bias: Optional[torch.Tensor],
|
||||
input_scale: Optional[torch.Tensor],
|
||||
input_zero_point: Optional[torch.Tensor],
|
||||
input_symmetric: bool,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
input_scale = (
|
||||
input_scale.to(torch.float32) if input_scale is not None else input_scale
|
||||
)
|
||||
weight_scale = weight_scale.to(torch.float32)
|
||||
|
||||
self.bias = bias
|
||||
@ -283,35 +195,6 @@ class W8A8IntLinear(torch.nn.Module):
|
||||
self.weight = weight.t()
|
||||
self.weight_scale = weight_scale
|
||||
|
||||
if input_scale is not None:
|
||||
if input_zero_point is None:
|
||||
# Symmetric: simply use the largest scale to cover fused layers.
|
||||
input_scale = input_scale.max()
|
||||
else:
|
||||
# Asymmetric: find the range that contains all individual ranges.
|
||||
input_zero_point = input_zero_point.to(torch.int32)
|
||||
int8_info = torch.iinfo(torch.int8)
|
||||
|
||||
# Find the most extreme values of all zero point/input scale
|
||||
# pairs.
|
||||
range_min = (input_scale * (int8_info.min - input_zero_point)).min()
|
||||
range_max = (input_scale * (int8_info.max - input_zero_point)).max()
|
||||
|
||||
# Calculate new scale and zero point.
|
||||
input_scale = (range_max - range_min) / (int8_info.max - int8_info.min)
|
||||
input_zero_point = int8_info.min - (range_min / input_scale)
|
||||
input_zero_point = input_zero_point.to(torch.int32)
|
||||
|
||||
self.range_min = (
|
||||
input_scale * (int8_info.min - input_zero_point)
|
||||
).min()
|
||||
self.range_max = (
|
||||
input_scale * (int8_info.max - input_zero_point)
|
||||
).max()
|
||||
|
||||
self.input_scale = input_scale
|
||||
self.input_zero_point = input_zero_point
|
||||
|
||||
if input_symmetric:
|
||||
self.zero_point_adj = None
|
||||
else:
|
||||
@ -320,16 +203,13 @@ class W8A8IntLinear(torch.nn.Module):
|
||||
dim=0, keepdim=True, dtype=torch.int32
|
||||
)
|
||||
|
||||
if input_zero_point is not None:
|
||||
self.zero_point_adj *= input_zero_point
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
assert marlin_kernels is not None
|
||||
|
||||
qinput, input_scale, input_zero_point = marlin_kernels.scaled_int8_quant(
|
||||
input=input,
|
||||
scale=self.input_scale,
|
||||
azp=self.input_zero_point,
|
||||
scale=None,
|
||||
azp=None,
|
||||
symmetric=self.input_symmetric,
|
||||
)
|
||||
|
||||
@ -343,7 +223,11 @@ class W8A8IntLinear(torch.nn.Module):
|
||||
bias=self.bias,
|
||||
)
|
||||
else:
|
||||
assert self.zero_point_adj is not None and input_scale is not None
|
||||
assert (
|
||||
self.zero_point_adj is not None
|
||||
and input_scale is not None
|
||||
and (self.input_symmetric or input_zero_point is not None)
|
||||
)
|
||||
|
||||
return marlin_kernels.cutlass_scaled_mm_azp(
|
||||
a=qinput,
|
||||
@ -352,8 +236,6 @@ class W8A8IntLinear(torch.nn.Module):
|
||||
scale_b=self.weight_scale,
|
||||
out_dtype=input.dtype,
|
||||
azp_adj=self.zero_point_adj,
|
||||
# Zero point is already in the adjustment when using static
|
||||
# input quantization.
|
||||
azp=input_zero_point if self.input_zero_point is None else None,
|
||||
azp=input_zero_point,
|
||||
bias=self.bias,
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user