From f25a7aad89a3929e4137c180e3533b71868dee19 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Tue, 4 Feb 2025 13:22:24 +0000 Subject: [PATCH] Fixup some imports --- .../layers/compressed_tensors/w8a8_int.py | 5 +++-- server/text_generation_server/layers/fp8.py | 5 +++-- server/text_generation_server/layers/marlin/fp8.py | 5 +++-- server/text_generation_server/layers/marlin/gptq.py | 5 +++-- server/text_generation_server/layers/marlin/marlin.py | 5 +++-- server/text_generation_server/layers/marlin/util.py | 4 ++-- 6 files changed, 17 insertions(+), 12 deletions(-) diff --git a/server/text_generation_server/layers/compressed_tensors/w8a8_int.py b/server/text_generation_server/layers/compressed_tensors/w8a8_int.py index dc0783b7..e9e3e975 100644 --- a/server/text_generation_server/layers/compressed_tensors/w8a8_int.py +++ b/server/text_generation_server/layers/compressed_tensors/w8a8_int.py @@ -6,15 +6,16 @@ import torch from compressed_tensors.quantization import QuantizationArgs, QuantizationType from text_generation_server.layers.fp8 import _load_scalar_or_matrix_scale +from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.kernels import load_kernel from text_generation_server.utils.log import log_once from text_generation_server.utils.weights import Weight, Weights, WeightsLoader -try: +if SYSTEM == "cuda": marlin_kernels = load_kernel( module="quantization", repo_id="kernels-community/quantization" ) -except ImportError: +else: marlin_kernels = None diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py index c0ef78c3..d412c5a4 100644 --- a/server/text_generation_server/layers/fp8.py +++ b/server/text_generation_server/layers/fp8.py @@ -15,14 +15,15 @@ from text_generation_server.utils.weights import ( ) from text_generation_server.utils.log import log_once -try: +if SYSTEM == "cuda": marlin_kernels = load_kernel( module="quantization", repo_id="kernels-community/quantization" ) -except ImportError: +else: marlin_kernels = None try: + # TODO: needs to be ported over to MoE and used on CUDA. from moe_kernels.fp8_utils import w8a8_block_fp8_matmul, per_token_group_quant_fp8 except ImportError: w8a8_block_fp8_matmul = None diff --git a/server/text_generation_server/layers/marlin/fp8.py b/server/text_generation_server/layers/marlin/fp8.py index 0c689aa8..48f5289f 100644 --- a/server/text_generation_server/layers/marlin/fp8.py +++ b/server/text_generation_server/layers/marlin/fp8.py @@ -8,13 +8,14 @@ from text_generation_server.layers.marlin.util import ( _check_marlin_kernels, permute_scales, ) +from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.kernels import load_kernel -try: +if SYSTEM == "cuda": marlin_kernels = load_kernel( module="quantization", repo_id="kernels-community/quantization" ) -except ImportError: +else: marlin_kernels = None diff --git a/server/text_generation_server/layers/marlin/gptq.py b/server/text_generation_server/layers/marlin/gptq.py index 24ea69e7..066724e2 100644 --- a/server/text_generation_server/layers/marlin/gptq.py +++ b/server/text_generation_server/layers/marlin/gptq.py @@ -16,13 +16,14 @@ from text_generation_server.utils.kernels import load_kernel from text_generation_server.utils.log import log_once from text_generation_server.utils.weights import Weight, Weights, WeightsLoader -try: +if SYSTEM == "cuda": marlin_kernels = load_kernel( module="quantization", repo_id="kernels-community/quantization" ) -except ImportError: +else: marlin_kernels = None + try: major, _minor = torch.cuda.get_device_capability() has_sm_8_0 = major >= 8 diff --git a/server/text_generation_server/layers/marlin/marlin.py b/server/text_generation_server/layers/marlin/marlin.py index 06991981..f01f6af2 100644 --- a/server/text_generation_server/layers/marlin/marlin.py +++ b/server/text_generation_server/layers/marlin/marlin.py @@ -5,14 +5,15 @@ import torch import torch.nn as nn from text_generation_server.layers.marlin.util import _check_marlin_kernels +from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.kernels import load_kernel from text_generation_server.utils.weights import Weight, Weights, WeightsLoader -try: +if SYSTEM == "cuda": marlin_kernels = load_kernel( module="quantization", repo_id="kernels-community/quantization" ) -except ImportError: +else: marlin_kernels = None diff --git a/server/text_generation_server/layers/marlin/util.py b/server/text_generation_server/layers/marlin/util.py index f22e9c9c..e73b5397 100644 --- a/server/text_generation_server/layers/marlin/util.py +++ b/server/text_generation_server/layers/marlin/util.py @@ -6,11 +6,11 @@ import torch from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.kernels import load_kernel -try: +if SYSTEM == "cuda": marlin_kernels = load_kernel( module="quantization", repo_id="kernels-community/quantization" ) -except ImportError: +else: marlin_kernels = None try: