diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py index 8328ad30..9bfdae5e 100644 --- a/server/text_generation_server/layers/fp8.py +++ b/server/text_generation_server/layers/fp8.py @@ -1,11 +1,11 @@ from dataclasses import dataclass import os from typing import Optional, Tuple, Type, Union, List -from moe_kernels.fp8_utils import w8a8_block_fp8_matmul, per_token_group_quant_fp8 import torch from loguru import logger +from moe_kernels import w8a8_block_fp8_matmul, per_token_group_quant_fp8 from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.weights import ( Weight, @@ -187,8 +187,6 @@ class HybridFP8UnquantLoader(WeightsLoader): if w.dtype == torch.float8_e4m3fn: if self.weight_block_size is not None: scale = weights.get_tensor(f"{prefix}.weight_scale_inv") - if scale.device == torch.device("cpu"): - scale = scale.to(weights.device) return Fp8Weight( weight=w, weight_scale=scale, @@ -289,9 +287,7 @@ class HybridFP8UnquantLoader(WeightsLoader): weights.get_sharded(f"{p}.weight_scale_inv", dim=0, to_device=False) for p in prefixes ] - scale = torch.cat(scale, dim=dim) - if scale.device == torch.device("cpu"): - scale = scale.to(weights.device) + scale = torch.cat(scale, dim=dim).to(weights.device) return Fp8Weight( weight=w, weight_scale=scale, @@ -347,8 +343,6 @@ class HybridFP8UnquantLoader(WeightsLoader): if w.dtype == torch.float8_e4m3fn: if self.weight_block_size is not None: scale = weights.get_sharded(f"{prefix}.weight_scale_inv", dim=1) - if scale.device == torch.device("cpu"): - scale = scale.to(weights.device) return Fp8Weight( weight=w, diff --git a/server/text_generation_server/utils/quantization.py b/server/text_generation_server/utils/quantization.py index 764ffab4..e44cf64f 100644 --- a/server/text_generation_server/utils/quantization.py +++ b/server/text_generation_server/utils/quantization.py @@ -60,6 +60,7 @@ def _get_quantizer_config(model_id, revision): return _FP8QuantizerConfig( activation_scale_ub=data["quantization_config"]["activation_scale_ub"] ) + weight_block_size = data["quantization_config"].get("weight_block_size", None) if "zero_point" in data["quantization_config"]: sym = not data["quantization_config"]["zero_point"] @@ -67,16 +68,12 @@ def _get_quantizer_config(model_id, revision): elif "sym" in data["quantization_config"]: sym = data["quantization_config"]["sym"] - if "bits" in data["quantization_config"]: - bits = data["quantization_config"]["bits"] - if "group_size" in data["quantization_config"]: - groupsize = data["quantization_config"]["group_size"] + bits = data["quantization_config"]["bits"] + groupsize = data["quantization_config"]["group_size"] # Order is important here, desc_act is missing on some real models quant_method = data["quantization_config"]["quant_method"] - checkpoint_format = data["quantization_config"].get("checkpoint_format", None) - if desc_act in data["quantization_config"]: - desc_act = data["quantization_config"]["desc_act"] - weight_block_size = data["quantization_config"].get("weight_block_size", None) + checkpoint_format = data["quantization_config"].get("checkpoint_format") + desc_act = data["quantization_config"]["desc_act"] except Exception: filename = "quantize_config.json" try: diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index e1c9ab33..c03dd2b0 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -149,10 +149,6 @@ class Weights: ): routing = {} for filename in filenames: - # if filename.as_posix().endswith("l.safetensors"): - # from loguru import logger - # logger.info(f"Skipping {filename}") - # continue with safe_open(filename, framework="pytorch") as f: for k in f.keys(): if k in routing: