diff --git a/server/text_generation_server/utils/quantization.py b/server/text_generation_server/utils/quantization.py index a6013361..c3c038fe 100644 --- a/server/text_generation_server/utils/quantization.py +++ b/server/text_generation_server/utils/quantization.py @@ -45,6 +45,13 @@ def _get_quantizer_config(model_id, revision): filename = hf_hub_download(model_id, filename=filename, revision=revision) with open(filename, "r") as f: data = json.load(f) + + # FP8 config + if data["quantization_config"]["quant_method"] == "fbgemm_fp8": + return _FP8QuantizerConfig( + activation_scale_ub=data["quantization_config"]["activation_scale_ub"] + ) + bits = data["quantization_config"]["bits"] groupsize = data["quantization_config"]["group_size"] # Order is important here, desc_act is missing on some real models @@ -69,17 +76,6 @@ def _get_quantizer_config(model_id, revision): desc_act = data["desc_act"] if "version" in data and data["version"] == "GEMM": quant_method = "awq" - # FP8 config - except KeyError: - try: - filename = os.path.join(model_id, filename) - with open(filename, "r") as f: - data = json.load(f) - return _FP8QuantizerConfig( - activation_scale_ub=data["activation_scale_ub"] - ) - except: - pass except Exception: filename = "quant_config.json" try: