Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-09-11 04:14:52 +00:00
Support loading local kernels for development
This commit is contained in:
parent b35ab54fd4
commit d39f896c5c
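Summary: call sites stop importing `load_kernel` from `hf_kernels` directly and instead go through a new wrapper, `text_generation_server.utils.kernels.load_kernel(*, module, repo_id)`, which prefers a locally importable module of that name and only falls back to the locked Hub kernel. A minimal sketch of the development workflow this enables (the editable install below is an assumed setup, not something this commit performs):

    # Assumed setup for local kernel development (illustrative): make a
    # package named `attention` importable, e.g. from a local checkout:
    #
    #     pip install -e /path/to/local/attention-kernel
    #
    from text_generation_server.utils.kernels import load_kernel

    # With the local package importable, the wrapper returns it directly;
    # otherwise it falls back to the locked Hub kernel via hf_kernels.
    attention_kernels = load_kernel(
        module="attention", repo_id="kernels-community/attention"
    )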
@@ -1,7 +1,7 @@
-from hf_kernels import load_kernel
 import torch
 from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.kernels import load_kernel
 from text_generation_server.models.globals import (
     ATTENTION,
     BLOCK_SIZE,
@@ -108,7 +108,9 @@ def paged_attention(
     if softcap is not None:
         raise RuntimeError("Paged attention doesn't support softcapping")
     input_lengths = seqlen.input_lengths + seqlen.cache_lengths
-    attention_kernels = load_kernel("kernels-community/attention")
+    attention_kernels = load_kernel(
+        module="attention", repo_id="kernels-community/attention"
+    )

     out = torch.empty_like(query)
@@ -1,13 +1,13 @@
 from typing import Tuple
 from dataclasses import dataclass, field

-from hf_kernels import load_kernel
 from loguru import logger
 import torch

 from text_generation_server.layers.fp8 import fp8_quantize
 from text_generation_server.models.globals import ATTENTION, BLOCK_SIZE
 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.kernels import load_kernel
 from text_generation_server.utils.log import log_once
 from text_generation_server.utils.weights import Weights
@@ -222,7 +222,9 @@ def paged_reshape_and_cache(

     if SYSTEM == "cuda":
         try:
-            attention_kernels = load_kernel("kernels-community/attention")
+            attention_kernels = load_kernel(
+                module="attention", repo_id="kernels-community/attention"
+            )
         except Exception as e:
             raise ImportError(
                 f"Could not import attention_kernels. Make sure your installation is correct. Complete error: {e}"
@@ -1,17 +1,19 @@
 from typing import List, Optional, Union, TypeVar
 from dataclasses import dataclass

-from hf_kernels import load_kernel
 from loguru import logger
 import torch
 from compressed_tensors.quantization import QuantizationArgs, QuantizationType

 from text_generation_server.layers.fp8 import _load_scalar_or_matrix_scale
+from text_generation_server.utils.kernels import load_kernel
 from text_generation_server.utils.log import log_once
 from text_generation_server.utils.weights import Weight, Weights, WeightsLoader

 try:
-    marlin_kernels = load_kernel("kernels-community/quantization")
+    marlin_kernels = load_kernel(
+        module="quantization", repo_id="kernels-community/quantization"
+    )
 except ImportError:
     marlin_kernels = None
@@ -2,11 +2,11 @@ from dataclasses import dataclass
 import os
 from typing import Optional, Tuple, Type, Union, List

-from hf_kernels import load_kernel
 import torch
 from loguru import logger

 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.kernels import load_kernel
 from text_generation_server.utils.weights import (
     Weight,
     WeightsLoader,
@@ -16,7 +16,9 @@ from text_generation_server.utils.weights import (
 from text_generation_server.utils.log import log_once

 try:
-    marlin_kernels = load_kernel("kernels-community/quantization")
+    marlin_kernels = load_kernel(
+        module="quantization", repo_id="kernels-community/quantization"
+    )
 except ImportError:
     marlin_kernels = None
@@ -2,16 +2,18 @@ from typing import Optional

 import torch
 import torch.nn as nn
-from hf_kernels import load_kernel
 from text_generation_server.layers.fp8 import fp8_quantize
 from text_generation_server.layers.marlin.gptq import _check_valid_shape
 from text_generation_server.layers.marlin.util import (
     _check_marlin_kernels,
     permute_scales,
 )
+from text_generation_server.utils.kernels import load_kernel

 try:
-    marlin_kernels = load_kernel("kernels-community/quantization")
+    marlin_kernels = load_kernel(
+        module="quantization", repo_id="kernels-community/quantization"
+    )
 except ImportError:
     marlin_kernels = None
@@ -4,7 +4,6 @@ from typing import List, Optional, Union
 import numpy
 import torch
 import torch.nn as nn
-from hf_kernels import load_kernel
 from loguru import logger
 from text_generation_server.layers.marlin.util import (
     _check_marlin_kernels,
@@ -13,11 +12,14 @@ from text_generation_server.layers.marlin.util import (
     unpack_cols,
 )
 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.kernels import load_kernel
 from text_generation_server.utils.log import log_once
 from text_generation_server.utils.weights import Weight, Weights, WeightsLoader

 try:
-    marlin_kernels = load_kernel("kernels-community/quantization")
+    marlin_kernels = load_kernel(
+        module="quantization", repo_id="kernels-community/quantization"
+    )
 except ImportError:
     marlin_kernels = None
@@ -1,14 +1,17 @@
 from dataclasses import dataclass
 from typing import List, Optional, Union

-from hf_kernels import load_kernel
 import torch
 import torch.nn as nn

 from text_generation_server.layers.marlin.util import _check_marlin_kernels
+from text_generation_server.utils.kernels import load_kernel
 from text_generation_server.utils.weights import Weight, Weights, WeightsLoader

 try:
-    marlin_kernels = load_kernel("kernels-community/quantization")
+    marlin_kernels = load_kernel(
+        module="quantization", repo_id="kernels-community/quantization"
+    )
 except ImportError:
     marlin_kernels = None
@@ -1,13 +1,15 @@
 import functools
 from typing import List, Tuple

-from hf_kernels import load_kernel
 import numpy
 import torch
 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.kernels import load_kernel

 try:
-    marlin_kernels = load_kernel("kernels-community/quantization")
+    marlin_kernels = load_kernel(
+        module="quantization", repo_id="kernels-community/quantization"
+    )
 except ImportError:
     marlin_kernels = None
@@ -1,6 +1,5 @@
 from typing import Optional, Protocol, runtime_checkable

-from hf_kernels import load_kernel
 import torch
 import torch.nn as nn
 from loguru import logger
@@ -19,6 +18,7 @@ from text_generation_server.layers.moe.gptq_marlin import (
 from text_generation_server.layers.moe.unquantized import UnquantizedSparseMoELayer
 from text_generation_server.layers.moe.fp8 import FP8SparseMoELayer
 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.kernels import load_kernel
 from text_generation_server.utils.log import log_once
 from text_generation_server.utils.weights import (
     DefaultWeightsLoader,
@@ -29,7 +29,7 @@ from text_generation_server.utils.weights import (
 if SYSTEM == "ipex":
     from .fused_moe_ipex import fused_topk, grouped_topk
 if SYSTEM == "cuda":
-    moe_kernels = load_kernel("kernels-community/moe")
+    moe_kernels = load_kernel(module="moe", repo_id="kernels-community/moe")
     fused_topk = moe_kernels.fused_topk
     grouped_topk = moe_kernels.grouped_topk
 else:
@@ -1,12 +1,12 @@
 from dataclasses import dataclass
 from typing import Any, Callable, Dict, List, Optional

-from hf_kernels import load_kernel
 import torch
 import torch.nn as nn

 from text_generation_server.layers import moe
 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.kernels import load_kernel
 from text_generation_server.utils.weights import Weights
 from text_generation_server.layers.marlin.gptq import (
     GPTQMarlinWeight,
@@ -14,7 +14,7 @@ from text_generation_server.layers.marlin.gptq import (
 )

 if SYSTEM == "cuda":
-    moe_kernels = load_kernel("kernels-community/moe")
+    moe_kernels = load_kernel(module="moe", repo_id="kernels-community/moe")
 else:
     moe_kernels = None
@@ -1,16 +1,16 @@
 from typing import Optional

-from hf_kernels import load_kernel
 import torch
 import torch.nn as nn

 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.kernels import load_kernel
 from text_generation_server.utils.weights import UnquantizedWeight, Weights

 if SYSTEM == "ipex":
     from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
 elif SYSTEM == "cuda":
-    moe_kernels = load_kernel("kernels-community/moe")
+    moe_kernels = load_kernel(module="moe", repo_id="kernels-community/moe")
 else:
     import moe_kernels
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from hf_kernels import load_kernel
 import torch
 import torch.distributed
@@ -23,11 +22,12 @@ from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple, Any
 from text_generation_server.layers.attention.kv_cache import get_kv_scales
 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.kernels import load_kernel

 if SYSTEM == "ipex":
     from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
 elif SYSTEM == "cuda":
-    moe_kernels = load_kernel("kernels-community/moe")
+    moe_kernels = load_kernel(module="moe", repo_id="kernels-community/moe")
 else:
     import moe_kernels
server/text_generation_server/utils/kernels.py (new file, +22)
@@ -0,0 +1,22 @@
+import importlib
+
+from loguru import logger
+from hf_kernels import load_kernel as hf_load_kernel
+
+from text_generation_server.utils.log import log_once
+
+
+def load_kernel(*, module: str, repo_id: str):
+    """
+    Load a kernel. First try to load it as the given module (e.g. for
+    local development), falling back to a locked Hub kernel.
+    """
+    try:
+        m = importlib.import_module(module)
+        log_once(logger.info, f"Using local module for `{module}`")
+        return m
+    except ModuleNotFoundError:
+        return hf_load_kernel(repo_id=repo_id)
+
+
+__all__ = ["load_kernel"]