mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
Always use dynamic input quantization for w8a8 int
It's far less flaky and gives better output.
This commit is contained in:
parent
b2dc10aea5
commit
f76c0ff17f
@ -37,22 +37,14 @@ class W8A8IntLoader(WeightsLoader):
|
|||||||
self.load_weight_scale = not weight_args.dynamic
|
self.load_weight_scale = not weight_args.dynamic
|
||||||
|
|
||||||
if input_args is not None:
|
if input_args is not None:
|
||||||
static = not input_args.dynamic
|
|
||||||
symmetric = input_args.symmetric
|
|
||||||
self.load_input_scale = static
|
|
||||||
self.load_input_zero_point = static and not symmetric
|
|
||||||
self.input_symmetric = input_args.symmetric
|
self.input_symmetric = input_args.symmetric
|
||||||
|
|
||||||
if static:
|
if not input_args.dynamic:
|
||||||
# People shouldn't really use static input quantization,
|
|
||||||
# the output is pretty bad.
|
|
||||||
log_once(
|
log_once(
|
||||||
logger.warning,
|
logger.warning,
|
||||||
"Using W8A8 int with static input quantization results in large regressions in accuracy. Consider dynamic input quantization instead.",
|
"Forcing dynamic input quantization for compressed_tensors w8a8 int checkpoint (for better accuracy).",
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.load_input_scale = False
|
|
||||||
self.load_input_zero_point = False
|
|
||||||
self.input_symmetric = True
|
self.input_symmetric = True
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
@ -62,7 +54,7 @@ class W8A8IntLoader(WeightsLoader):
|
|||||||
def symmetric_to_sting(symmetric):
|
def symmetric_to_sting(symmetric):
|
||||||
return "symmetric" if symmetric else "asymmetric"
|
return "symmetric" if symmetric else "asymmetric"
|
||||||
|
|
||||||
return f"{self.__class__.__name__} (w8a8 int, input: {symmetric_to_sting(self.input_symmetric)}, {scale_to_str(self.load_input_scale)})"
|
return f"{self.__class__.__name__} (w8a8 int, input: dynamic/{symmetric_to_sting(self.input_symmetric)}, weight: {scale_to_str(self.load_weight_scale)}/symmetric))"
|
||||||
|
|
||||||
def get_weights(self, weights: "Weights", prefix: str):
|
def get_weights(self, weights: "Weights", prefix: str):
|
||||||
w = weights.get_tensor(f"{prefix}.weight", to_dtype=False)
|
w = weights.get_tensor(f"{prefix}.weight", to_dtype=False)
|
||||||
@ -73,23 +65,7 @@ class W8A8IntLoader(WeightsLoader):
|
|||||||
f"{prefix}.weight_scale", to_dtype=False
|
f"{prefix}.weight_scale", to_dtype=False
|
||||||
).reshape(-1)
|
).reshape(-1)
|
||||||
|
|
||||||
input_scale = None
|
|
||||||
if self.load_input_scale:
|
|
||||||
input_scale = weights.get_tensor(
|
|
||||||
f"{prefix}.input_scale", to_dtype=False
|
|
||||||
).reshape(-1)
|
|
||||||
|
|
||||||
input_zero_point = None
|
|
||||||
if self.load_input_zero_point:
|
|
||||||
input_zero_point = _get_tensor_or_else(
|
|
||||||
weights,
|
|
||||||
f"{prefix}.input_zero_point",
|
|
||||||
torch.zeros((1,), device=w.device, dtype=torch.int8),
|
|
||||||
).reshape(-1)
|
|
||||||
|
|
||||||
return Int8Weight(
|
return Int8Weight(
|
||||||
input_scale=input_scale,
|
|
||||||
input_zero_point=input_zero_point,
|
|
||||||
input_symmetric=self.input_symmetric,
|
input_symmetric=self.input_symmetric,
|
||||||
weight=w,
|
weight=w,
|
||||||
weight_scale=weight_scale,
|
weight_scale=weight_scale,
|
||||||
@ -117,23 +93,7 @@ class W8A8IntLoader(WeightsLoader):
|
|||||||
)
|
)
|
||||||
weight_scale = weight_scale.reshape(-1)
|
weight_scale = weight_scale.reshape(-1)
|
||||||
|
|
||||||
input_scale = None
|
|
||||||
if self.load_input_scale:
|
|
||||||
input_scale = weights.get_tensor(
|
|
||||||
f"{prefix}.input_scale", to_dtype=False
|
|
||||||
).reshape(-1)
|
|
||||||
|
|
||||||
input_zero_point = None
|
|
||||||
if self.load_input_zero_point:
|
|
||||||
input_zero_point = _get_tensor_or_else(
|
|
||||||
weights,
|
|
||||||
f"{prefix}.input_zero_point",
|
|
||||||
torch.zeros((1,), device=w.device, dtype=torch.int8),
|
|
||||||
).reshape(-1)
|
|
||||||
|
|
||||||
return Int8Weight(
|
return Int8Weight(
|
||||||
input_scale=input_scale,
|
|
||||||
input_zero_point=input_zero_point,
|
|
||||||
input_symmetric=self.input_symmetric,
|
input_symmetric=self.input_symmetric,
|
||||||
weight=w,
|
weight=w,
|
||||||
weight_scale=weight_scale,
|
weight_scale=weight_scale,
|
||||||
@ -155,28 +115,7 @@ class W8A8IntLoader(WeightsLoader):
|
|||||||
]
|
]
|
||||||
weight_scale = torch.cat(weight_scale, dim=0).reshape(-1, 1)
|
weight_scale = torch.cat(weight_scale, dim=0).reshape(-1, 1)
|
||||||
|
|
||||||
input_scale = None
|
|
||||||
if self.load_input_scale:
|
|
||||||
input_scale = [
|
|
||||||
weights.get_tensor(f"{p}.input_scale", to_dtype=False) for p in prefixes
|
|
||||||
]
|
|
||||||
input_scale = torch.cat(input_scale, dim=0)
|
|
||||||
|
|
||||||
input_zero_point = None
|
|
||||||
if self.load_input_zero_point:
|
|
||||||
input_zero_point = [
|
|
||||||
_get_tensor_or_else(
|
|
||||||
weights,
|
|
||||||
f"{prefix}.input_zero_point",
|
|
||||||
torch.zeros((1,), device=w.device, dtype=torch.int8),
|
|
||||||
)
|
|
||||||
for prefix in prefixes
|
|
||||||
]
|
|
||||||
input_zero_point = torch.cat(input_zero_point, dim=0)
|
|
||||||
|
|
||||||
return Int8Weight(
|
return Int8Weight(
|
||||||
input_scale=input_scale,
|
|
||||||
input_zero_point=input_zero_point,
|
|
||||||
input_symmetric=self.input_symmetric,
|
input_symmetric=self.input_symmetric,
|
||||||
weight=w,
|
weight=w,
|
||||||
weight_scale=weight_scale,
|
weight_scale=weight_scale,
|
||||||
@ -191,23 +130,7 @@ class W8A8IntLoader(WeightsLoader):
|
|||||||
f"{prefix}.weight_scale", to_dtype=False
|
f"{prefix}.weight_scale", to_dtype=False
|
||||||
).reshape(-1)
|
).reshape(-1)
|
||||||
|
|
||||||
input_scale = None
|
|
||||||
if self.load_input_scale:
|
|
||||||
input_scale = weights.get_tensor(
|
|
||||||
f"{prefix}.input_scale", to_dtype=False
|
|
||||||
).reshape(-1)
|
|
||||||
|
|
||||||
input_zero_point = None
|
|
||||||
if self.load_input_zero_point:
|
|
||||||
input_zero_point = _get_tensor_or_else(
|
|
||||||
weights,
|
|
||||||
f"{prefix}.input_zero_point",
|
|
||||||
torch.zeros((1,), device=w.device, dtype=torch.int8),
|
|
||||||
).reshape(-1)
|
|
||||||
|
|
||||||
return Int8Weight(
|
return Int8Weight(
|
||||||
input_scale=input_scale,
|
|
||||||
input_zero_point=input_zero_point,
|
|
||||||
input_symmetric=self.input_symmetric,
|
input_symmetric=self.input_symmetric,
|
||||||
weight=w,
|
weight=w,
|
||||||
weight_scale=weight_scale,
|
weight_scale=weight_scale,
|
||||||
@ -230,8 +153,6 @@ def _get_tensor_or_else(
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Int8Weight(Weight):
|
class Int8Weight(Weight):
|
||||||
input_scale: Optional[torch.Tensor]
|
|
||||||
input_zero_point: Optional[torch.Tensor]
|
|
||||||
input_symmetric: bool
|
input_symmetric: bool
|
||||||
weight: torch.Tensor
|
weight: torch.Tensor
|
||||||
weight_scale: Optional[torch.Tensor]
|
weight_scale: Optional[torch.Tensor]
|
||||||
@ -242,8 +163,6 @@ class Int8Weight(Weight):
|
|||||||
qweight, weight_scale, _ = marlin_kernels.scaled_int8_quant(self.weight)
|
qweight, weight_scale, _ = marlin_kernels.scaled_int8_quant(self.weight)
|
||||||
return W8A8IntLinear(
|
return W8A8IntLinear(
|
||||||
bias=bias,
|
bias=bias,
|
||||||
input_scale=self.input_scale,
|
|
||||||
input_zero_point=self.input_zero_point,
|
|
||||||
input_symmetric=self.input_symmetric,
|
input_symmetric=self.input_symmetric,
|
||||||
weight=qweight,
|
weight=qweight,
|
||||||
weight_scale=weight_scale,
|
weight_scale=weight_scale,
|
||||||
@ -251,8 +170,6 @@ class Int8Weight(Weight):
|
|||||||
else:
|
else:
|
||||||
return W8A8IntLinear(
|
return W8A8IntLinear(
|
||||||
bias=bias,
|
bias=bias,
|
||||||
input_scale=self.input_scale,
|
|
||||||
input_zero_point=self.input_zero_point,
|
|
||||||
input_symmetric=self.input_symmetric,
|
input_symmetric=self.input_symmetric,
|
||||||
weight=self.weight,
|
weight=self.weight,
|
||||||
weight_scale=self.weight_scale,
|
weight_scale=self.weight_scale,
|
||||||
@ -264,17 +181,12 @@ class W8A8IntLinear(torch.nn.Module):
|
|||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
bias: Optional[torch.Tensor],
|
bias: Optional[torch.Tensor],
|
||||||
input_scale: Optional[torch.Tensor],
|
|
||||||
input_zero_point: Optional[torch.Tensor],
|
|
||||||
input_symmetric: bool,
|
input_symmetric: bool,
|
||||||
weight: torch.Tensor,
|
weight: torch.Tensor,
|
||||||
weight_scale: torch.Tensor,
|
weight_scale: torch.Tensor,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
input_scale = (
|
|
||||||
input_scale.to(torch.float32) if input_scale is not None else input_scale
|
|
||||||
)
|
|
||||||
weight_scale = weight_scale.to(torch.float32)
|
weight_scale = weight_scale.to(torch.float32)
|
||||||
|
|
||||||
self.bias = bias
|
self.bias = bias
|
||||||
@ -283,35 +195,6 @@ class W8A8IntLinear(torch.nn.Module):
|
|||||||
self.weight = weight.t()
|
self.weight = weight.t()
|
||||||
self.weight_scale = weight_scale
|
self.weight_scale = weight_scale
|
||||||
|
|
||||||
if input_scale is not None:
|
|
||||||
if input_zero_point is None:
|
|
||||||
# Symmetric: simply use the largest scale to cover fused layers.
|
|
||||||
input_scale = input_scale.max()
|
|
||||||
else:
|
|
||||||
# Asymmetric: find the range that contains all individual ranges.
|
|
||||||
input_zero_point = input_zero_point.to(torch.int32)
|
|
||||||
int8_info = torch.iinfo(torch.int8)
|
|
||||||
|
|
||||||
# Find the most extreme values of all zero point/input scale
|
|
||||||
# pairs.
|
|
||||||
range_min = (input_scale * (int8_info.min - input_zero_point)).min()
|
|
||||||
range_max = (input_scale * (int8_info.max - input_zero_point)).max()
|
|
||||||
|
|
||||||
# Calculate new scale and zero point.
|
|
||||||
input_scale = (range_max - range_min) / (int8_info.max - int8_info.min)
|
|
||||||
input_zero_point = int8_info.min - (range_min / input_scale)
|
|
||||||
input_zero_point = input_zero_point.to(torch.int32)
|
|
||||||
|
|
||||||
self.range_min = (
|
|
||||||
input_scale * (int8_info.min - input_zero_point)
|
|
||||||
).min()
|
|
||||||
self.range_max = (
|
|
||||||
input_scale * (int8_info.max - input_zero_point)
|
|
||||||
).max()
|
|
||||||
|
|
||||||
self.input_scale = input_scale
|
|
||||||
self.input_zero_point = input_zero_point
|
|
||||||
|
|
||||||
if input_symmetric:
|
if input_symmetric:
|
||||||
self.zero_point_adj = None
|
self.zero_point_adj = None
|
||||||
else:
|
else:
|
||||||
@ -320,16 +203,13 @@ class W8A8IntLinear(torch.nn.Module):
|
|||||||
dim=0, keepdim=True, dtype=torch.int32
|
dim=0, keepdim=True, dtype=torch.int32
|
||||||
)
|
)
|
||||||
|
|
||||||
if input_zero_point is not None:
|
|
||||||
self.zero_point_adj *= input_zero_point
|
|
||||||
|
|
||||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||||
assert marlin_kernels is not None
|
assert marlin_kernels is not None
|
||||||
|
|
||||||
qinput, input_scale, input_zero_point = marlin_kernels.scaled_int8_quant(
|
qinput, input_scale, input_zero_point = marlin_kernels.scaled_int8_quant(
|
||||||
input=input,
|
input=input,
|
||||||
scale=self.input_scale,
|
scale=None,
|
||||||
azp=self.input_zero_point,
|
azp=None,
|
||||||
symmetric=self.input_symmetric,
|
symmetric=self.input_symmetric,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -343,7 +223,11 @@ class W8A8IntLinear(torch.nn.Module):
|
|||||||
bias=self.bias,
|
bias=self.bias,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
assert self.zero_point_adj is not None and input_scale is not None
|
assert (
|
||||||
|
self.zero_point_adj is not None
|
||||||
|
and input_scale is not None
|
||||||
|
and (self.input_symmetric or input_zero_point is not None)
|
||||||
|
)
|
||||||
|
|
||||||
return marlin_kernels.cutlass_scaled_mm_azp(
|
return marlin_kernels.cutlass_scaled_mm_azp(
|
||||||
a=qinput,
|
a=qinput,
|
||||||
@ -352,8 +236,6 @@ class W8A8IntLinear(torch.nn.Module):
|
|||||||
scale_b=self.weight_scale,
|
scale_b=self.weight_scale,
|
||||||
out_dtype=input.dtype,
|
out_dtype=input.dtype,
|
||||||
azp_adj=self.zero_point_adj,
|
azp_adj=self.zero_point_adj,
|
||||||
# Zero point is already in the adjustment when using static
|
azp=input_zero_point,
|
||||||
# input quantization.
|
|
||||||
azp=input_zero_point if self.input_zero_point is None else None,
|
|
||||||
bias=self.bias,
|
bias=self.bias,
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user