From f76c0ff17fc180f4d550cc1fe8dc504e319e7f32 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?=
Date: Mon, 18 Nov 2024 10:54:51 +0000
Subject: [PATCH] Always use dynamic input quantization for w8a8 int

It's far less flaky and gives better output.
---
 .../layers/compressed_tensors/w8a8_int.py     | 140 ++----------------
 1 file changed, 11 insertions(+), 129 deletions(-)

diff --git a/server/text_generation_server/layers/compressed_tensors/w8a8_int.py b/server/text_generation_server/layers/compressed_tensors/w8a8_int.py
index 4a711a09..9eeb6b83 100644
--- a/server/text_generation_server/layers/compressed_tensors/w8a8_int.py
+++ b/server/text_generation_server/layers/compressed_tensors/w8a8_int.py
@@ -37,22 +37,14 @@ class W8A8IntLoader(WeightsLoader):
         self.load_weight_scale = not weight_args.dynamic
 
         if input_args is not None:
-            static = not input_args.dynamic
-            symmetric = input_args.symmetric
-            self.load_input_scale = static
-            self.load_input_zero_point = static and not symmetric
             self.input_symmetric = input_args.symmetric
 
-            if static:
-                # People shouldn't really use static input quantization,
-                # the output is pretty bad.
+            if not input_args.dynamic:
                 log_once(
                     logger.warning,
-                    "Using W8A8 int with static input quantization results in large regressions in accuracy. Consider dynamic input quantization instead.",
+                    "Forcing dynamic input quantization for compressed_tensors w8a8 int checkpoint (for better accuracy).",
                 )
         else:
-            self.load_input_scale = False
-            self.load_input_zero_point = False
             self.input_symmetric = True
 
     def __str__(self) -> str:
@@ -62,7 +54,7 @@ class W8A8IntLoader(WeightsLoader):
         def symmetric_to_sting(symmetric):
             return "symmetric" if symmetric else "asymmetric"
 
-        return f"{self.__class__.__name__} (w8a8 int, input: {symmetric_to_sting(self.input_symmetric)}, {scale_to_str(self.load_input_scale)})"
+        return f"{self.__class__.__name__} (w8a8 int, input: dynamic/{symmetric_to_sting(self.input_symmetric)}, weight: {scale_to_str(self.load_weight_scale)}/symmetric)"
 
     def get_weights(self, weights: "Weights", prefix: str):
         w = weights.get_tensor(f"{prefix}.weight", to_dtype=False)
@@ -73,23 +65,7 @@ class W8A8IntLoader(WeightsLoader):
                 f"{prefix}.weight_scale", to_dtype=False
             ).reshape(-1)
 
-        input_scale = None
-        if self.load_input_scale:
-            input_scale = weights.get_tensor(
-                f"{prefix}.input_scale", to_dtype=False
-            ).reshape(-1)
-
-        input_zero_point = None
-        if self.load_input_zero_point:
-            input_zero_point = _get_tensor_or_else(
-                weights,
-                f"{prefix}.input_zero_point",
-                torch.zeros((1,), device=w.device, dtype=torch.int8),
-            ).reshape(-1)
-
         return Int8Weight(
-            input_scale=input_scale,
-            input_zero_point=input_zero_point,
             input_symmetric=self.input_symmetric,
             weight=w,
             weight_scale=weight_scale,
@@ -117,23 +93,7 @@ class W8A8IntLoader(WeightsLoader):
             )
             weight_scale = weight_scale.reshape(-1)
 
-        input_scale = None
-        if self.load_input_scale:
-            input_scale = weights.get_tensor(
-                f"{prefix}.input_scale", to_dtype=False
-            ).reshape(-1)
-
-        input_zero_point = None
-        if self.load_input_zero_point:
-            input_zero_point = _get_tensor_or_else(
-                weights,
-                f"{prefix}.input_zero_point",
-                torch.zeros((1,), device=w.device, dtype=torch.int8),
-            ).reshape(-1)
-
         return Int8Weight(
-            input_scale=input_scale,
-            input_zero_point=input_zero_point,
             input_symmetric=self.input_symmetric,
             weight=w,
             weight_scale=weight_scale,
@@ -155,28 +115,7 @@ class W8A8IntLoader(WeightsLoader):
         ]
         weight_scale = torch.cat(weight_scale, dim=0).reshape(-1, 1)
 
-        input_scale = None
-        if self.load_input_scale:
-            input_scale = [
-                weights.get_tensor(f"{p}.input_scale", to_dtype=False) for p in prefixes
-            ]
-            input_scale = torch.cat(input_scale, dim=0)
-
-        input_zero_point = None
-        if self.load_input_zero_point:
-            input_zero_point = [
-                _get_tensor_or_else(
-                    weights,
-                    f"{prefix}.input_zero_point",
-                    torch.zeros((1,), device=w.device, dtype=torch.int8),
-                )
-                for prefix in prefixes
-            ]
-            input_zero_point = torch.cat(input_zero_point, dim=0)
-
         return Int8Weight(
-            input_scale=input_scale,
-            input_zero_point=input_zero_point,
             input_symmetric=self.input_symmetric,
             weight=w,
             weight_scale=weight_scale,
@@ -191,23 +130,7 @@ class W8A8IntLoader(WeightsLoader):
                 f"{prefix}.weight_scale", to_dtype=False
             ).reshape(-1)
 
-        input_scale = None
-        if self.load_input_scale:
-            input_scale = weights.get_tensor(
-                f"{prefix}.input_scale", to_dtype=False
-            ).reshape(-1)
-
-        input_zero_point = None
-        if self.load_input_zero_point:
-            input_zero_point = _get_tensor_or_else(
-                weights,
-                f"{prefix}.input_zero_point",
-                torch.zeros((1,), device=w.device, dtype=torch.int8),
-            ).reshape(-1)
-
         return Int8Weight(
-            input_scale=input_scale,
-            input_zero_point=input_zero_point,
             input_symmetric=self.input_symmetric,
             weight=w,
             weight_scale=weight_scale,
@@ -230,8 +153,6 @@ def _get_tensor_or_else(
 
 @dataclass
 class Int8Weight(Weight):
-    input_scale: Optional[torch.Tensor]
-    input_zero_point: Optional[torch.Tensor]
     input_symmetric: bool
     weight: torch.Tensor
     weight_scale: Optional[torch.Tensor]
@@ -242,8 +163,6 @@ class Int8Weight(Weight):
             qweight, weight_scale, _ = marlin_kernels.scaled_int8_quant(self.weight)
             return W8A8IntLinear(
                 bias=bias,
-                input_scale=self.input_scale,
-                input_zero_point=self.input_zero_point,
                 input_symmetric=self.input_symmetric,
                 weight=qweight,
                 weight_scale=weight_scale,
@@ -251,8 +170,6 @@ class Int8Weight(Weight):
         else:
             return W8A8IntLinear(
                 bias=bias,
-                input_scale=self.input_scale,
-                input_zero_point=self.input_zero_point,
                 input_symmetric=self.input_symmetric,
                 weight=self.weight,
                 weight_scale=self.weight_scale,
@@ -264,17 +181,12 @@ class W8A8IntLinear(torch.nn.Module):
     def __init__(
         self,
         *,
         bias: Optional[torch.Tensor],
-        input_scale: Optional[torch.Tensor],
-        input_zero_point: Optional[torch.Tensor],
         input_symmetric: bool,
         weight: torch.Tensor,
         weight_scale: torch.Tensor,
     ):
         super().__init__()
 
-        input_scale = (
-            input_scale.to(torch.float32) if input_scale is not None else input_scale
-        )
         weight_scale = weight_scale.to(torch.float32)
         self.bias = bias
@@ -283,35 +195,6 @@ class W8A8IntLinear(torch.nn.Module):
         self.weight = weight.t()
         self.weight_scale = weight_scale
 
-        if input_scale is not None:
-            if input_zero_point is None:
-                # Symmetric: simply use the largest scale to cover fused layers.
-                input_scale = input_scale.max()
-            else:
-                # Asymmetric: find the range that contains all individual ranges.
-                input_zero_point = input_zero_point.to(torch.int32)
-                int8_info = torch.iinfo(torch.int8)
-
-                # Find the most extreme values of all zero point/input scale
-                # pairs.
-                range_min = (input_scale * (int8_info.min - input_zero_point)).min()
-                range_max = (input_scale * (int8_info.max - input_zero_point)).max()
-
-                # Calculate new scale and zero point.
-                input_scale = (range_max - range_min) / (int8_info.max - int8_info.min)
-                input_zero_point = int8_info.min - (range_min / input_scale)
-                input_zero_point = input_zero_point.to(torch.int32)
-
-                self.range_min = (
-                    input_scale * (int8_info.min - input_zero_point)
-                ).min()
-                self.range_max = (
-                    input_scale * (int8_info.max - input_zero_point)
-                ).max()
-
-        self.input_scale = input_scale
-        self.input_zero_point = input_zero_point
-
         if input_symmetric:
             self.zero_point_adj = None
         else:
@@ -320,16 +203,13 @@ class W8A8IntLinear(torch.nn.Module):
                 dim=0, keepdim=True, dtype=torch.int32
             )
 
-            if input_zero_point is not None:
-                self.zero_point_adj *= input_zero_point
-
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         assert marlin_kernels is not None
 
        qinput, input_scale, input_zero_point = marlin_kernels.scaled_int8_quant(
            input=input,
-            scale=self.input_scale,
-            azp=self.input_zero_point,
+            scale=None,
+            azp=None,
            symmetric=self.input_symmetric,
        )
 
@@ -343,7 +223,11 @@ class W8A8IntLinear(torch.nn.Module):
                 bias=self.bias,
             )
         else:
-            assert self.zero_point_adj is not None and input_scale is not None
+            assert (
+                self.zero_point_adj is not None
+                and input_scale is not None
+                and (self.input_symmetric or input_zero_point is not None)
+            )
 
             return marlin_kernels.cutlass_scaled_mm_azp(
                 a=qinput,
@@ -352,8 +236,6 @@ class W8A8IntLinear(torch.nn.Module):
                 scale_b=self.weight_scale,
                 out_dtype=input.dtype,
                 azp_adj=self.zero_point_adj,
-                # Zero point is already in the adjustment when using static
-                # input quantization.
-                azp=input_zero_point if self.input_zero_point is None else None,
+                azp=input_zero_point,
                 bias=self.bias,
             )
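
Note for reviewers: with scale=None and azp=None, scaled_int8_quant derives
the input scale (and, in the asymmetric case, the zero point) from each
activation batch on the fly, which is why every input_scale /
input_zero_point loading path above can simply be deleted. Below is a rough
pure-PyTorch sketch of that dynamic quantization step. It is an illustration,
not the marlin_kernels implementation: the helper name dynamic_int8_quant,
the per-row granularity, and the 1e-8 clamp are assumptions made for the
example.

    import torch

    def dynamic_int8_quant(x: torch.Tensor, symmetric: bool):
        # Quantize activations to int8 with scales computed from this batch.
        int8 = torch.iinfo(torch.int8)
        if symmetric:
            # Scale covers the largest magnitude in each row; zero point is 0.
            scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / int8.max
            q = torch.clamp(torch.round(x / scale), int8.min, int8.max)
            return q.to(torch.int8), scale, None
        # Asymmetric: map each row's [min, max] range onto [-128, 127].
        x_min = x.amin(dim=-1, keepdim=True)
        x_max = x.amax(dim=-1, keepdim=True)
        scale = ((x_max - x_min) / (int8.max - int8.min)).clamp(min=1e-8)
        azp = torch.round(int8.min - x_min / scale).to(torch.int32)
        q = torch.clamp(torch.round(x / scale) + azp, int8.min, int8.max)
        return q.to(torch.int8), scale, azp

    # Dequantization check: x ~= scale * (q - azp).
    x = torch.randn(4, 64)
    q, scale, azp = dynamic_int8_quant(x, symmetric=False)
    x_hat = scale * (q.to(torch.float32) - azp.to(torch.float32))
    assert (x - x_hat).abs().max() <= scale.max()

The asymmetric branch of forward() then leans on an exact integer identity:
for a per-row zero point, (q - azp) @ Wq == q @ Wq - azp * Wq.sum(dim=0).
The per-column weight sums are input-independent, so __init__ precomputes
them once as zero_point_adj, and the only per-batch quantity
cutlass_scaled_mm_azp still needs is the azp produced by the quantization
step above.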