diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py
index dca7fa95..f7710145 100644
--- a/server/text_generation_server/layers/fp8.py
+++ b/server/text_generation_server/layers/fp8.py
@@ -123,12 +123,12 @@ class HybridFP8UnquantLoader(WeightsLoader):
                 .reshape(-1)
                 .expand(w.shape[0])
             )
-            try:
+
+            input_scale = None
+            if weights.has_tensor(f"{prefix}.input_scale"):
                 input_scale = weights.get_tensor(
                     f"{prefix}.input_scale", to_dtype=False
                 ).reshape(-1)
-            except Exception:
-                input_scale = None
 
             return Fp8Weight(
                 weight=w,
@@ -163,7 +163,9 @@ class HybridFP8UnquantLoader(WeightsLoader):
                 to_dtype=False,
             )
             scale = scale.reshape(-1).expand(w.shape[0])
-            try:
+
+            input_scale = None
+            if weights.has_tensor(f"{prefix}.input_scale"):
                 input_scale = weights.get_tensor(
                     f"{prefix}.input_scale", to_dtype=False
                 )
@@ -175,8 +177,6 @@ class HybridFP8UnquantLoader(WeightsLoader):
                         to_dtype=False,
                     )
                 input_scale = input_scale.reshape(-1).max()
-            except Exception:
-                input_scale = None
 
             return Fp8Weight(
                 weight=w,
@@ -207,14 +207,17 @@ class HybridFP8UnquantLoader(WeightsLoader):
                 for p, shape in zip(prefixes, shapes)
             ]
             scale = torch.cat(scale, dim=0).reshape(-1)
-            try:
-                input_scale = [
-                    _load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
-                    for p, shape in zip(prefixes, shapes)
-                ]
-                input_scale = torch.cat(input_scale, dim=0).reshape(-1).max()
-            except Exception:
-                input_scale = None
+
+            input_scale = [
+                _load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
+                for p, shape in zip(prefixes, shapes)
+                if weights.has_tensor(f"{p}.input_scale")
+            ]
+            input_scale = (
+                torch.cat(input_scale, dim=0).reshape(-1).max()
+                if len(input_scale) != 0
+                else None
+            )
 
             return Fp8Weight(
                 weight=w,
@@ -237,12 +240,11 @@ class HybridFP8UnquantLoader(WeightsLoader):
                 .reshape(-1)
                 .expand(w.shape[0])
             )
-            try:
+            input_scale = None
+            if weights.has_tensor(f"{prefix}.input_scale"):
                 input_scale = weights.get_tensor(
                     f"{prefix}.input_scale", to_dtype=False
                 ).reshape(-1)
-            except Exception:
-                input_scale = None
 
             return Fp8Weight(
                 weight=w,
@@ -272,12 +274,12 @@ class Fp8Weight(Weight):
         # memory. Can be non-contiguous when we e.g. expand from scalars.
         self.weight_scale = self.weight_scale.contiguous()
         return get_fp8_linear().from_fp8(
-            self.weight,
-            self.weight_scale,
-            self.input_scale,
-            self.activation_scale_ub,
-            bias,
-            self.dtype,
+            weight=self.weight,
+            scale=self.weight_scale,
+            dtype=self.dtype,
+            bias=bias,
+            input_scale=self.input_scale,
+            scale_upper_bound=self.activation_scale_ub,
         )
 
 
@@ -286,12 +288,12 @@ class Fp8Linear(torch.nn.Module):
 
     def __init__(
         self,
-        qweight,
-        scale,
-        input_scale,
-        scale_upper_bound,
-        bias,
-        dtype,
+        qweight: torch.Tensor,
+        scale: torch.Tensor,
+        dtype: torch.dtype,
+        bias: Optional[torch.Tensor] = None,
+        input_scale: Optional[torch.Tensor] = None,
+        scale_upper_bound: Optional[float] = None,
     ) -> None:
         super().__init__()
         if FBGEMM_MM_AVAILABLE:
@@ -327,14 +329,24 @@ class Fp8Linear(torch.nn.Module):
         return cls(
             qweight=qweight,
             scale=scale,
+            dtype=dtype,
+            bias=bias,
             input_scale=None,
             scale_upper_bound=None,
-            bias=bias,
-            dtype=dtype,
         )
 
     @classmethod
-    def from_fp8(cls, weight, scale, input_scale, scale_upper_bound, bias, dtype):
+    def from_fp8(
+        cls,
+        weight: torch.Tensor,
+        scale: torch.Tensor,
+        dtype: torch.dtype,
+        bias: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> "Fp8Linear":
+        input_scale = kwargs.get("input_scale", None)
+        scale_upper_bound = kwargs.get("scale_upper_bound", None)
+
         if FBGEMM_DYN_AVAILABLE:
             # fbgemm needs float32 scales.
             scale = scale.float()
@@ -391,7 +403,7 @@ class Fp8Linear(torch.nn.Module):
                 bias=self.bias,
             )
 
-            if type(output) is tuple and len(output) == 2:
+            if isinstance(output, tuple) and len(output) == 2:
                 output = output[0]
         else:
             device_identity = None
@@ -405,7 +417,7 @@ class Fp8Linear(torch.nn.Module):
                 scale_b=device_identity,
                 out_dtype=torch.float32,
             )
-            if type(output) is tuple and len(output) == 2:
+            if isinstance(output, tuple) and len(output) == 2:
                 output = output[0]
 
             output = output * scale * self.scale.t()
diff --git a/server/text_generation_server/layers/marlin/fp8.py b/server/text_generation_server/layers/marlin/fp8.py
index dac109cf..49f5c480 100644
--- a/server/text_generation_server/layers/marlin/fp8.py
+++ b/server/text_generation_server/layers/marlin/fp8.py
@@ -62,7 +62,14 @@ class GPTQMarlinFP8Linear(nn.Module):
         return cls(qweight=qweight, scales=scales.to(dtype), bias=bias)
 
     @classmethod
-    def from_fp8(cls, weight, scale, _input_scale, _scale_upper_bound, bias, dtype):
+    def from_fp8(
+        cls,
+        weight: torch.Tensor,
+        scale: torch.Tensor,
+        bias: torch.Tensor,
+        dtype: torch.dtype,
+        **kwargs,
+    ):
         return cls(qweight=weight, scales=scale.to(dtype), bias=bias)
 
     def forward(self, A: torch.Tensor) -> torch.Tensor:
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 570ea853..c9886092 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -342,22 +342,19 @@ def get_model(
     model_type = config_dict.get("model_type", None)
 
     quantization_config = config_dict.get("quantization_config", None)
-    compression_config = config_dict.get("compression_config", None)
     if quantization_config is not None and quantize is None:
         method = quantization_config.get("quant_method", None)
+        config_groups = quantization_config.get("config_groups", None)
         if method in {"gptq", "awq", "exl2"}:
             log_master(logger.info, f"Auto selecting quantization method {method}")
             quantize = method
         elif method == "fbgemm_fp8" or method == "fp8":
             log_master(logger.info, "Auto selecting quantization method fp8")
             quantize = "fp8"
-        else:
-            log_master(logger.warning, f"Unknown quantization method {method}")
-    elif compression_config is not None:
-        # TODO: at some point we should probably fully parse the compression
-        # configuration to know which parameters are compressed.
-        config_groups = compression_config.get("config_groups")
-        if config_groups is not None:
+        elif config_groups is not None:
+            # Compression config renamed to quantization_config
+            # TODO: at some point we should probably fully parse the compression
+            # configuration to know which parameters are compressed.
             for _, group in config_groups.items():
                 weights_config = group.get("weights")
                 if weights_config is not None:
@@ -370,6 +367,8 @@ def get_model(
                         )
                         quantize = "fp8"
                         break
+        else:
+            log_master(logger.warning, f"Unknown quantization method {method}")
 
     if dtype is None:
         if quantize in ["awq", "exl2", "gptq", "marlin"]:
diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index 75e01f7c..548591e5 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -197,7 +197,7 @@ class Weights:
         slice_ = f.get_slice(tensor_name)
         return slice_
 
-    def _has_tensor(self, tensor_name: str):
+    def has_tensor(self, tensor_name: str):
         try:
             self.get_filename(tensor_name)
         except Exception:
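
Reviewer note: the sketch below is not part of the diff. It is a minimal, self-contained illustration of the pattern the fp8 loaders converge on after this change; `ToyWeights`, `load_input_scale`, and the tensor names are hypothetical stand-ins, and only `has_tensor`/`get_tensor` mirror the methods touched above. The point is that an explicit existence check replaces the broad `try/except Exception`, so real loading errors are no longer silently swallowed, while `from_fp8` implementations that do not use `input_scale`/`scale_upper_bound` (e.g. `GPTQMarlinFP8Linear`) can absorb them via `**kwargs` instead of underscore-prefixed positional placeholders.

from typing import Dict, Optional

import torch


class ToyWeights:
    """Hypothetical stand-in for the Weights class, kept only to show the API shape."""

    def __init__(self, tensors: Dict[str, torch.Tensor]):
        self._tensors = tensors

    def has_tensor(self, tensor_name: str) -> bool:
        # Explicit membership check: nothing is raised and swallowed when the
        # tensor is absent, unlike the removed try/except Exception pattern.
        return tensor_name in self._tensors

    def get_tensor(self, tensor_name: str) -> torch.Tensor:
        return self._tensors[tensor_name]


def load_input_scale(weights: ToyWeights, prefix: str) -> Optional[torch.Tensor]:
    # Same guard the diff introduces: default to None, load only when present.
    input_scale = None
    if weights.has_tensor(f"{prefix}.input_scale"):
        input_scale = weights.get_tensor(f"{prefix}.input_scale").reshape(-1)
    return input_scale


if __name__ == "__main__":
    weights = ToyWeights({"model.layers.0.q_proj.input_scale": torch.tensor([0.5])})
    print(load_input_scale(weights, "model.layers.0.q_proj"))  # tensor([0.5000])
    print(load_input_scale(weights, "model.layers.0.k_proj"))  # None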