diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py
index 6fece12b..f16004cd 100644
--- a/server/text_generation_server/layers/attention/cuda.py
+++ b/server/text_generation_server/layers/attention/cuda.py
@@ -1,7 +1,7 @@
-from hf_kernels import load_kernel
 import torch
 from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.kernels import load_kernel
 from text_generation_server.models.globals import (
     ATTENTION,
     BLOCK_SIZE,
@@ -108,7 +108,9 @@ def paged_attention(
     if softcap is not None:
         raise RuntimeError("Paged attention doesn't support softcapping")
     input_lengths = seqlen.input_lengths + seqlen.cache_lengths
-    attention_kernels = load_kernel("kernels-community/attention")
+    attention_kernels = load_kernel(
+        module="attention", repo_id="kernels-community/attention"
+    )
 
     out = torch.empty_like(query)
diff --git a/server/text_generation_server/layers/attention/kv_cache.py b/server/text_generation_server/layers/attention/kv_cache.py
index 6f1fa9a7..522b10a7 100644
--- a/server/text_generation_server/layers/attention/kv_cache.py
+++ b/server/text_generation_server/layers/attention/kv_cache.py
@@ -1,13 +1,13 @@
 from typing import Tuple
 from dataclasses import dataclass, field
 
-from hf_kernels import load_kernel
 from loguru import logger
 import torch
 
 from text_generation_server.layers.fp8 import fp8_quantize
 from text_generation_server.models.globals import ATTENTION, BLOCK_SIZE
 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.kernels import load_kernel
 from text_generation_server.utils.log import log_once
 from text_generation_server.utils.weights import Weights
@@ -222,7 +222,9 @@ def paged_reshape_and_cache(
 
     if SYSTEM == "cuda":
         try:
-            attention_kernels = load_kernel("kernels-community/attention")
+            attention_kernels = load_kernel(
+                module="attention", repo_id="kernels-community/attention"
+            )
         except Exception as e:
             raise ImportError(
                 f"Could not import attention_kernels. Make sure your installation is correct. Complete error: {e}"
diff --git a/server/text_generation_server/layers/compressed_tensors/w8a8_int.py b/server/text_generation_server/layers/compressed_tensors/w8a8_int.py
index 8da8c8a0..dc0783b7 100644
--- a/server/text_generation_server/layers/compressed_tensors/w8a8_int.py
+++ b/server/text_generation_server/layers/compressed_tensors/w8a8_int.py
@@ -1,17 +1,19 @@
 from typing import List, Optional, Union, TypeVar
 from dataclasses import dataclass
 
-from hf_kernels import load_kernel
 from loguru import logger
 import torch
 
 from compressed_tensors.quantization import QuantizationArgs, QuantizationType
 
 from text_generation_server.layers.fp8 import _load_scalar_or_matrix_scale
+from text_generation_server.utils.kernels import load_kernel
 from text_generation_server.utils.log import log_once
 from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
 
 try:
-    marlin_kernels = load_kernel("kernels-community/quantization")
+    marlin_kernels = load_kernel(
+        module="quantization", repo_id="kernels-community/quantization"
+    )
 except ImportError:
     marlin_kernels = None
diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py
index 60cafc1d..c0ef78c3 100644
--- a/server/text_generation_server/layers/fp8.py
+++ b/server/text_generation_server/layers/fp8.py
@@ -2,11 +2,11 @@ from dataclasses import dataclass
 import os
 from typing import Optional, Tuple, Type, Union, List
 
-from hf_kernels import load_kernel
 import torch
 from loguru import logger
 
 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.kernels import load_kernel
 from text_generation_server.utils.weights import (
     Weight,
     WeightsLoader,
@@ -16,7 +16,9 @@ from text_generation_server.utils.weights import (
 from text_generation_server.utils.log import log_once
 
 try:
-    marlin_kernels = load_kernel("kernels-community/quantization")
+    marlin_kernels = load_kernel(
+        module="quantization", repo_id="kernels-community/quantization"
+    )
 except ImportError:
     marlin_kernels = None
diff --git a/server/text_generation_server/layers/marlin/fp8.py b/server/text_generation_server/layers/marlin/fp8.py
index 851b5edb..0c689aa8 100644
--- a/server/text_generation_server/layers/marlin/fp8.py
+++ b/server/text_generation_server/layers/marlin/fp8.py
@@ -2,16 +2,18 @@ from typing import Optional
 
 import torch
 import torch.nn as nn
 
-from hf_kernels import load_kernel
 from text_generation_server.layers.fp8 import fp8_quantize
 from text_generation_server.layers.marlin.gptq import _check_valid_shape
 from text_generation_server.layers.marlin.util import (
     _check_marlin_kernels,
     permute_scales,
 )
+from text_generation_server.utils.kernels import load_kernel
 
 try:
-    marlin_kernels = load_kernel("kernels-community/quantization")
+    marlin_kernels = load_kernel(
+        module="quantization", repo_id="kernels-community/quantization"
+    )
 except ImportError:
     marlin_kernels = None
diff --git a/server/text_generation_server/layers/marlin/gptq.py b/server/text_generation_server/layers/marlin/gptq.py
index 88f12bde..24ea69e7 100644
--- a/server/text_generation_server/layers/marlin/gptq.py
+++ b/server/text_generation_server/layers/marlin/gptq.py
@@ -4,7 +4,6 @@ from typing import List, Optional, Union
 import numpy
 import torch
 import torch.nn as nn
-from hf_kernels import load_kernel
 from loguru import logger
 from text_generation_server.layers.marlin.util import (
     _check_marlin_kernels,
@@ -13,11 +12,14 @@ from text_generation_server.layers.marlin.util import (
     unpack_cols,
 )
 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.kernels import load_kernel
 from text_generation_server.utils.log import log_once
 from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
 
 try:
-    marlin_kernels = load_kernel("kernels-community/quantization")
+    marlin_kernels = load_kernel(
+        module="quantization", repo_id="kernels-community/quantization"
+    )
 except ImportError:
     marlin_kernels = None
diff --git a/server/text_generation_server/layers/marlin/marlin.py b/server/text_generation_server/layers/marlin/marlin.py
index 5e828fce..06991981 100644
--- a/server/text_generation_server/layers/marlin/marlin.py
+++ b/server/text_generation_server/layers/marlin/marlin.py
@@ -1,14 +1,17 @@
 from dataclasses import dataclass
 from typing import List, Optional, Union
 
-from hf_kernels import load_kernel
 import torch
 import torch.nn as nn
+
 from text_generation_server.layers.marlin.util import _check_marlin_kernels
+from text_generation_server.utils.kernels import load_kernel
 from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
 
 try:
-    marlin_kernels = load_kernel("kernels-community/quantization")
+    marlin_kernels = load_kernel(
+        module="quantization", repo_id="kernels-community/quantization"
+    )
 except ImportError:
     marlin_kernels = None
diff --git a/server/text_generation_server/layers/marlin/util.py b/server/text_generation_server/layers/marlin/util.py
index 380a3246..f22e9c9c 100644
--- a/server/text_generation_server/layers/marlin/util.py
+++ b/server/text_generation_server/layers/marlin/util.py
@@ -1,13 +1,15 @@
 import functools
 from typing import List, Tuple
 
-from hf_kernels import load_kernel
 import numpy
 import torch
 
 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.kernels import load_kernel
 
 try:
-    marlin_kernels = load_kernel("kernels-community/quantization")
+    marlin_kernels = load_kernel(
+        module="quantization", repo_id="kernels-community/quantization"
+    )
 except ImportError:
     marlin_kernels = None
diff --git a/server/text_generation_server/layers/moe/__init__.py b/server/text_generation_server/layers/moe/__init__.py
index c19a2847..abf7ba6b 100644
--- a/server/text_generation_server/layers/moe/__init__.py
+++ b/server/text_generation_server/layers/moe/__init__.py
@@ -1,6 +1,5 @@
 from typing import Optional, Protocol, runtime_checkable
 
-from hf_kernels import load_kernel
 import torch
 import torch.nn as nn
 from loguru import logger
@@ -19,6 +18,7 @@ from text_generation_server.layers.moe.gptq_marlin import (
 from text_generation_server.layers.moe.unquantized import UnquantizedSparseMoELayer
 from text_generation_server.layers.moe.fp8 import FP8SparseMoELayer
 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.kernels import load_kernel
 from text_generation_server.utils.log import log_once
 from text_generation_server.utils.weights import (
     DefaultWeightsLoader,
@@ -29,7 +29,7 @@ from text_generation_server.utils.weights import (
 if SYSTEM == "ipex":
     from .fused_moe_ipex import fused_topk, grouped_topk
 if SYSTEM == "cuda":
-    moe_kernels = load_kernel("kernels-community/moe")
+    moe_kernels = load_kernel(module="moe", repo_id="kernels-community/moe")
     fused_topk = moe_kernels.fused_topk
     grouped_topk = moe_kernels.grouped_topk
 else:
diff --git a/server/text_generation_server/layers/moe/gptq_marlin.py b/server/text_generation_server/layers/moe/gptq_marlin.py
index 258cf76e..0819c2f5 100644
--- a/server/text_generation_server/layers/moe/gptq_marlin.py
+++ b/server/text_generation_server/layers/moe/gptq_marlin.py
@@ -1,12 +1,12 @@
 from dataclasses import dataclass
 from typing import Any, Callable, Dict, List, Optional
 
-from hf_kernels import load_kernel
 import torch
 import torch.nn as nn
 
 from text_generation_server.layers import moe
 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.kernels import load_kernel
 from text_generation_server.utils.weights import Weights
 from text_generation_server.layers.marlin.gptq import (
     GPTQMarlinWeight,
@@ -14,7 +14,7 @@ from text_generation_server.layers.marlin.gptq import (
 )
 
 if SYSTEM == "cuda":
-    moe_kernels = load_kernel("kernels-community/moe")
+    moe_kernels = load_kernel(module="moe", repo_id="kernels-community/moe")
 else:
     moe_kernels = None
diff --git a/server/text_generation_server/layers/moe/unquantized.py b/server/text_generation_server/layers/moe/unquantized.py
index b162ec5e..bdef06c6 100644
--- a/server/text_generation_server/layers/moe/unquantized.py
+++ b/server/text_generation_server/layers/moe/unquantized.py
@@ -1,16 +1,16 @@
 from typing import Optional
 
-from hf_kernels import load_kernel
 import torch
 import torch.nn as nn
 
 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.kernels import load_kernel
 from text_generation_server.utils.weights import UnquantizedWeight, Weights
 
 if SYSTEM == "ipex":
     from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
 elif SYSTEM == "cuda":
-    moe_kernels = load_kernel("kernels-community/moe")
+    moe_kernels = load_kernel(module="moe", repo_id="kernels-community/moe")
 else:
     import moe_kernels
diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
index 90138361..225fae77 100644
--- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from hf_kernels import load_kernel
 import torch
 import torch.distributed
 
@@ -23,11 +22,12 @@ from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple, Any
 
 from text_generation_server.layers.attention.kv_cache import get_kv_scales
 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.kernels import load_kernel
 
 if SYSTEM == "ipex":
     from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
 elif SYSTEM == "cuda":
-    moe_kernels = load_kernel("kernels-community/moe")
+    moe_kernels = load_kernel(module="moe", repo_id="kernels-community/moe")
 else:
     import moe_kernels
diff --git a/server/text_generation_server/utils/kernels.py b/server/text_generation_server/utils/kernels.py
new file mode 100644
index 00000000..42745c71
--- /dev/null
+++ b/server/text_generation_server/utils/kernels.py
@@ -0,0 +1,22 @@
+import importlib
+
+from loguru import logger
+from hf_kernels import load_kernel as hf_load_kernel
+
+from text_generation_server.utils.log import log_once
+
+
+def load_kernel(*, module: str, repo_id: str):
+    """
+    Load a kernel. First try to load it as the given module (e.g. for
+    local development), falling back to a locked Hub kernel.
+    """
+    try:
+        m = importlib.import_module(module)
+        log_once(logger.info, f"Using local module for `{module}`")
+        return m
+    except ModuleNotFoundError:
+        return hf_load_kernel(repo_id=repo_id)
+
+
+__all__ = ["load_kernel"]