Support loading local kernels for development

Daniël de Kok 2025-02-04 11:20:56 +00:00
parent b35ab54fd4
commit d39f896c5c
13 changed files with 63 additions and 24 deletions
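
All call sites move from the hf_kernels loader to a new wrapper,
text_generation_server.utils.kernels.load_kernel, which takes keyword
arguments so that a locally importable module can take precedence over
the locked Hub kernel. The before/after pattern, as it appears in the
diffs below:

    # Before: always resolves the locked kernel from the Hub.
    attention_kernels = load_kernel("kernels-community/attention")

    # After: prefer a locally importable `attention` module, falling
    # back to the locked Hub kernel.
    attention_kernels = load_kernel(
        module="attention", repo_id="kernels-community/attention"
    )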

View File

@@ -1,7 +1,7 @@
-from hf_kernels import load_kernel
 import torch
 from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.kernels import load_kernel
 from text_generation_server.models.globals import (
     ATTENTION,
     BLOCK_SIZE,
@@ -108,7 +108,9 @@ def paged_attention(
         if softcap is not None:
             raise RuntimeError("Paged attention doesn't support softcapping")

         input_lengths = seqlen.input_lengths + seqlen.cache_lengths
-        attention_kernels = load_kernel("kernels-community/attention")
+        attention_kernels = load_kernel(
+            module="attention", repo_id="kernels-community/attention"
+        )

         out = torch.empty_like(query)

View File

@ -1,13 +1,13 @@
from typing import Tuple from typing import Tuple
from dataclasses import dataclass, field from dataclasses import dataclass, field
from hf_kernels import load_kernel
from loguru import logger from loguru import logger
import torch import torch
from text_generation_server.layers.fp8 import fp8_quantize from text_generation_server.layers.fp8 import fp8_quantize
from text_generation_server.models.globals import ATTENTION, BLOCK_SIZE from text_generation_server.models.globals import ATTENTION, BLOCK_SIZE
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.kernels import load_kernel
from text_generation_server.utils.log import log_once from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import Weights from text_generation_server.utils.weights import Weights
@ -222,7 +222,9 @@ def paged_reshape_and_cache(
if SYSTEM == "cuda": if SYSTEM == "cuda":
try: try:
attention_kernels = load_kernel("kernels-community/attention") attention_kernels = load_kernel(
module="attention", repo_id="kernels-community/attention"
)
except Exception as e: except Exception as e:
raise ImportError( raise ImportError(
f"Could not import attention_kernels. Make sure your installation is correct. Complete error: {e}" f"Could not import attention_kernels. Make sure your installation is correct. Complete error: {e}"

View File

@@ -1,17 +1,19 @@
 from typing import List, Optional, Union, TypeVar

 from dataclasses import dataclass

-from hf_kernels import load_kernel
 from loguru import logger
 import torch
 from compressed_tensors.quantization import QuantizationArgs, QuantizationType

 from text_generation_server.layers.fp8 import _load_scalar_or_matrix_scale
+from text_generation_server.utils.kernels import load_kernel
 from text_generation_server.utils.log import log_once
 from text_generation_server.utils.weights import Weight, Weights, WeightsLoader

 try:
-    marlin_kernels = load_kernel("kernels-community/quantization")
+    marlin_kernels = load_kernel(
+        module="quantization", repo_id="kernels-community/quantization"
+    )
 except ImportError:
     marlin_kernels = None

View File

@@ -2,11 +2,11 @@ from dataclasses import dataclass
 import os
 from typing import Optional, Tuple, Type, Union, List

-from hf_kernels import load_kernel
 import torch
 from loguru import logger

 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.kernels import load_kernel
 from text_generation_server.utils.weights import (
     Weight,
     WeightsLoader,
@@ -16,7 +16,9 @@ from text_generation_server.utils.weights import (
 from text_generation_server.utils.log import log_once

 try:
-    marlin_kernels = load_kernel("kernels-community/quantization")
+    marlin_kernels = load_kernel(
+        module="quantization", repo_id="kernels-community/quantization"
+    )
 except ImportError:
     marlin_kernels = None

View File

@@ -2,16 +2,18 @@ from typing import Optional

 import torch
 import torch.nn as nn

-from hf_kernels import load_kernel
 from text_generation_server.layers.fp8 import fp8_quantize
 from text_generation_server.layers.marlin.gptq import _check_valid_shape
 from text_generation_server.layers.marlin.util import (
     _check_marlin_kernels,
     permute_scales,
 )
+from text_generation_server.utils.kernels import load_kernel

 try:
-    marlin_kernels = load_kernel("kernels-community/quantization")
+    marlin_kernels = load_kernel(
+        module="quantization", repo_id="kernels-community/quantization"
+    )
 except ImportError:
     marlin_kernels = None

View File

@@ -4,7 +4,6 @@ from typing import List, Optional, Union
 import numpy
 import torch
 import torch.nn as nn
-from hf_kernels import load_kernel
 from loguru import logger

 from text_generation_server.layers.marlin.util import (
@@ -13,11 +12,14 @@ from text_generation_server.layers.marlin.util import (
     unpack_cols,
 )
 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.kernels import load_kernel
 from text_generation_server.utils.log import log_once
 from text_generation_server.utils.weights import Weight, Weights, WeightsLoader

 try:
-    marlin_kernels = load_kernel("kernels-community/quantization")
+    marlin_kernels = load_kernel(
+        module="quantization", repo_id="kernels-community/quantization"
+    )
 except ImportError:
     marlin_kernels = None

View File

@@ -1,14 +1,17 @@
 from dataclasses import dataclass
 from typing import List, Optional, Union

-from hf_kernels import load_kernel
 import torch
 import torch.nn as nn

 from text_generation_server.layers.marlin.util import _check_marlin_kernels
+from text_generation_server.utils.kernels import load_kernel
 from text_generation_server.utils.weights import Weight, Weights, WeightsLoader

 try:
-    marlin_kernels = load_kernel("kernels-community/quantization")
+    marlin_kernels = load_kernel(
+        module="quantization", repo_id="kernels-community/quantization"
+    )
 except ImportError:
     marlin_kernels = None

View File

@@ -1,13 +1,15 @@
 import functools
 from typing import List, Tuple

-from hf_kernels import load_kernel
 import numpy
 import torch

 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.kernels import load_kernel

 try:
-    marlin_kernels = load_kernel("kernels-community/quantization")
+    marlin_kernels = load_kernel(
+        module="quantization", repo_id="kernels-community/quantization"
+    )
 except ImportError:
     marlin_kernels = None

View File

@@ -1,6 +1,5 @@
 from typing import Optional, Protocol, runtime_checkable

-from hf_kernels import load_kernel
 import torch
 import torch.nn as nn
 from loguru import logger
@@ -19,6 +18,7 @@ from text_generation_server.layers.moe.gptq_marlin import (
 from text_generation_server.layers.moe.unquantized import UnquantizedSparseMoELayer
 from text_generation_server.layers.moe.fp8 import FP8SparseMoELayer
 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.kernels import load_kernel
 from text_generation_server.utils.log import log_once
 from text_generation_server.utils.weights import (
     DefaultWeightsLoader,
@@ -29,7 +29,7 @@ from text_generation_server.utils.weights import (
 if SYSTEM == "ipex":
     from .fused_moe_ipex import fused_topk, grouped_topk
 if SYSTEM == "cuda":
-    moe_kernels = load_kernel("kernels-community/moe")
+    moe_kernels = load_kernel(module="moe", repo_id="kernels-community/moe")
     fused_topk = moe_kernels.fused_topk
     grouped_topk = moe_kernels.grouped_topk
 else:

View File

@ -1,12 +1,12 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional from typing import Any, Callable, Dict, List, Optional
from hf_kernels import load_kernel
import torch import torch
import torch.nn as nn import torch.nn as nn
from text_generation_server.layers import moe from text_generation_server.layers import moe
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.kernels import load_kernel
from text_generation_server.utils.weights import Weights from text_generation_server.utils.weights import Weights
from text_generation_server.layers.marlin.gptq import ( from text_generation_server.layers.marlin.gptq import (
GPTQMarlinWeight, GPTQMarlinWeight,
@ -14,7 +14,7 @@ from text_generation_server.layers.marlin.gptq import (
) )
if SYSTEM == "cuda": if SYSTEM == "cuda":
moe_kernels = load_kernel("kernels-community/moe") moe_kernels = load_kernel(module="moe", repo_id="kernels-community/moe")
else: else:
moe_kernels = None moe_kernels = None

View File

@@ -1,16 +1,16 @@
 from typing import Optional

-from hf_kernels import load_kernel
 import torch
 import torch.nn as nn

 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.kernels import load_kernel
 from text_generation_server.utils.weights import UnquantizedWeight, Weights

 if SYSTEM == "ipex":
     from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
 elif SYSTEM == "cuda":
-    moe_kernels = load_kernel("kernels-community/moe")
+    moe_kernels = load_kernel(module="moe", repo_id="kernels-community/moe")
 else:
     import moe_kernels

View File

@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from hf_kernels import load_kernel
 import torch
 import torch.distributed
@@ -23,11 +22,12 @@ from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple, Any

 from text_generation_server.layers.attention.kv_cache import get_kv_scales
 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.kernels import load_kernel

 if SYSTEM == "ipex":
     from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
 elif SYSTEM == "cuda":
-    moe_kernels = load_kernel("kernels-community/moe")
+    moe_kernels = load_kernel(module="moe", repo_id="kernels-community/moe")
 else:
     import moe_kernels

View File

@@ -0,0 +1,22 @@
+import importlib
+
+from loguru import logger
+from hf_kernels import load_kernel as hf_load_kernel
+
+from text_generation_server.utils.log import log_once
+
+
+def load_kernel(*, module: str, repo_id: str):
+    """
+    Load a kernel. First try to load it as the given module (e.g. for
+    local development), falling back to a locked Hub kernel.
+    """
+    try:
+        m = importlib.import_module(module)
+        log_once(logger.info, f"Using local module for `{module}`")
+        return m
+    except ModuleNotFoundError:
+        return hf_load_kernel(repo_id=repo_id)
+
+
+__all__ = ["load_kernel"]
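
A sketch of the intended development flow, assuming the checkout path
and the editable install are illustrative rather than prescribed by
this commit:

    # Make the kernel importable as a plain module first, e.g. with an
    # editable install of a local checkout (hypothetical path):
    #
    #     pip install -e ~/src/attention
    #
    from text_generation_server.utils.kernels import load_kernel

    # Uses the local `attention` module when it is importable; otherwise
    # falls back to the locked Hub kernel `kernels-community/attention`.
    attention_kernels = load_kernel(
        module="attention", repo_id="kernels-community/attention"
    )

Only ModuleNotFoundError triggers the Hub fallback; other failures in a
local kernel, such as an ImportError from a broken native extension,
still surface instead of being silently replaced by the Hub version.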