Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-11 04:14:52 +00:00
Support loading local kernels for development

commit d39f896c5c (parent b35ab54fd4)
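This commit routes every kernel load through a new helper, text_generation_server.utils.kernels.load_kernel, instead of calling hf_kernels.load_kernel directly. The helper takes keyword-only module and repo_id arguments: it first tries importlib.import_module(module), so a locally installed kernel takes precedence during development, and otherwise falls back to the locked Hub kernel. The snippet below restates the before/after pattern from the hunks that follow, using the attention kernel as the example:

    # Before: always resolve the kernel from the Hub.
    from hf_kernels import load_kernel

    attention_kernels = load_kernel("kernels-community/attention")

    # After: prefer a locally importable `attention` module, then fall
    # back to the locked Hub kernel.
    from text_generation_server.utils.kernels import load_kernel

    attention_kernels = load_kernel(
        module="attention", repo_id="kernels-community/attention"
    )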
@@ -1,7 +1,7 @@
-from hf_kernels import load_kernel
 import torch
 from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.kernels import load_kernel
 from text_generation_server.models.globals import (
     ATTENTION,
     BLOCK_SIZE,
@@ -108,7 +108,9 @@ def paged_attention(
     if softcap is not None:
         raise RuntimeError("Paged attention doesn't support softcapping")
     input_lengths = seqlen.input_lengths + seqlen.cache_lengths
-    attention_kernels = load_kernel("kernels-community/attention")
+    attention_kernels = load_kernel(
+        module="attention", repo_id="kernels-community/attention"
+    )

     out = torch.empty_like(query)

@@ -1,13 +1,13 @@
 from typing import Tuple
 from dataclasses import dataclass, field

-from hf_kernels import load_kernel
 from loguru import logger
 import torch

 from text_generation_server.layers.fp8 import fp8_quantize
 from text_generation_server.models.globals import ATTENTION, BLOCK_SIZE
 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.kernels import load_kernel
 from text_generation_server.utils.log import log_once
 from text_generation_server.utils.weights import Weights

@@ -222,7 +222,9 @@ def paged_reshape_and_cache(

     if SYSTEM == "cuda":
         try:
-            attention_kernels = load_kernel("kernels-community/attention")
+            attention_kernels = load_kernel(
+                module="attention", repo_id="kernels-community/attention"
+            )
         except Exception as e:
             raise ImportError(
                 f"Could not import attention_kernels. Make sure your installation is correct. Complete error: {e}"

@@ -1,17 +1,19 @@
 from typing import List, Optional, Union, TypeVar
 from dataclasses import dataclass

-from hf_kernels import load_kernel
 from loguru import logger
 import torch
 from compressed_tensors.quantization import QuantizationArgs, QuantizationType

 from text_generation_server.layers.fp8 import _load_scalar_or_matrix_scale
+from text_generation_server.utils.kernels import load_kernel
 from text_generation_server.utils.log import log_once
 from text_generation_server.utils.weights import Weight, Weights, WeightsLoader

 try:
-    marlin_kernels = load_kernel("kernels-community/quantization")
+    marlin_kernels = load_kernel(
+        module="quantization", repo_id="kernels-community/quantization"
+    )
 except ImportError:
     marlin_kernels = None

@@ -2,11 +2,11 @@ from dataclasses import dataclass
 import os
 from typing import Optional, Tuple, Type, Union, List

-from hf_kernels import load_kernel
 import torch
 from loguru import logger

 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.kernels import load_kernel
 from text_generation_server.utils.weights import (
     Weight,
     WeightsLoader,
@@ -16,7 +16,9 @@ from text_generation_server.utils.weights import (
 from text_generation_server.utils.log import log_once

 try:
-    marlin_kernels = load_kernel("kernels-community/quantization")
+    marlin_kernels = load_kernel(
+        module="quantization", repo_id="kernels-community/quantization"
+    )
 except ImportError:
     marlin_kernels = None

@@ -2,16 +2,18 @@ from typing import Optional

 import torch
 import torch.nn as nn
-from hf_kernels import load_kernel
 from text_generation_server.layers.fp8 import fp8_quantize
 from text_generation_server.layers.marlin.gptq import _check_valid_shape
 from text_generation_server.layers.marlin.util import (
     _check_marlin_kernels,
     permute_scales,
 )
+from text_generation_server.utils.kernels import load_kernel

 try:
-    marlin_kernels = load_kernel("kernels-community/quantization")
+    marlin_kernels = load_kernel(
+        module="quantization", repo_id="kernels-community/quantization"
+    )
 except ImportError:
     marlin_kernels = None

@@ -4,7 +4,6 @@ from typing import List, Optional, Union
 import numpy
 import torch
 import torch.nn as nn
-from hf_kernels import load_kernel
 from loguru import logger
 from text_generation_server.layers.marlin.util import (
     _check_marlin_kernels,
@@ -13,11 +12,14 @@ from text_generation_server.layers.marlin.util import (
     unpack_cols,
 )
 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.kernels import load_kernel
 from text_generation_server.utils.log import log_once
 from text_generation_server.utils.weights import Weight, Weights, WeightsLoader

 try:
-    marlin_kernels = load_kernel("kernels-community/quantization")
+    marlin_kernels = load_kernel(
+        module="quantization", repo_id="kernels-community/quantization"
+    )
 except ImportError:
     marlin_kernels = None

@@ -1,14 +1,17 @@
 from dataclasses import dataclass
 from typing import List, Optional, Union

-from hf_kernels import load_kernel
 import torch
 import torch.nn as nn

 from text_generation_server.layers.marlin.util import _check_marlin_kernels
+from text_generation_server.utils.kernels import load_kernel
 from text_generation_server.utils.weights import Weight, Weights, WeightsLoader

 try:
-    marlin_kernels = load_kernel("kernels-community/quantization")
+    marlin_kernels = load_kernel(
+        module="quantization", repo_id="kernels-community/quantization"
+    )
 except ImportError:
     marlin_kernels = None

@@ -1,13 +1,15 @@
 import functools
 from typing import List, Tuple

-from hf_kernels import load_kernel
 import numpy
 import torch
 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.kernels import load_kernel

 try:
-    marlin_kernels = load_kernel("kernels-community/quantization")
+    marlin_kernels = load_kernel(
+        module="quantization", repo_id="kernels-community/quantization"
+    )
 except ImportError:
     marlin_kernels = None

@@ -1,6 +1,5 @@
 from typing import Optional, Protocol, runtime_checkable

-from hf_kernels import load_kernel
 import torch
 import torch.nn as nn
 from loguru import logger
@@ -19,6 +18,7 @@ from text_generation_server.layers.moe.gptq_marlin import (
 from text_generation_server.layers.moe.unquantized import UnquantizedSparseMoELayer
 from text_generation_server.layers.moe.fp8 import FP8SparseMoELayer
 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.kernels import load_kernel
 from text_generation_server.utils.log import log_once
 from text_generation_server.utils.weights import (
     DefaultWeightsLoader,
@@ -29,7 +29,7 @@ from text_generation_server.utils.weights import (
 if SYSTEM == "ipex":
     from .fused_moe_ipex import fused_topk, grouped_topk
 if SYSTEM == "cuda":
-    moe_kernels = load_kernel("kernels-community/moe")
+    moe_kernels = load_kernel(module="moe", repo_id="kernels-community/moe")
     fused_topk = moe_kernels.fused_topk
     grouped_topk = moe_kernels.grouped_topk
 else:

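Because the helper tries importlib.import_module(module) before touching the Hub, any of the call sites above picks up a local kernel whenever a module with the matching name is importable, e.g. a development checkout installed with pip install -e. A minimal sketch to check which path load_kernel would take for the moe kernel (the editable-install workflow is an assumption for illustration, not something this commit sets up):

    import importlib.util

    # If a local `moe` package is importable (for instance an editable
    # install of a kernel checkout), load_kernel(module="moe", ...) will
    # return that module instead of fetching from the Hub.
    if importlib.util.find_spec("moe") is not None:
        print("local `moe` module found; load_kernel will use it")
    else:
        print("no local module; load_kernel falls back to kernels-community/moe")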
@@ -1,12 +1,12 @@
 from dataclasses import dataclass
 from typing import Any, Callable, Dict, List, Optional

-from hf_kernels import load_kernel
 import torch
 import torch.nn as nn

 from text_generation_server.layers import moe
 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.kernels import load_kernel
 from text_generation_server.utils.weights import Weights
 from text_generation_server.layers.marlin.gptq import (
     GPTQMarlinWeight,
@@ -14,7 +14,7 @@ from text_generation_server.layers.marlin.gptq import (
 )

 if SYSTEM == "cuda":
-    moe_kernels = load_kernel("kernels-community/moe")
+    moe_kernels = load_kernel(module="moe", repo_id="kernels-community/moe")
 else:
     moe_kernels = None

@@ -1,16 +1,16 @@
 from typing import Optional

-from hf_kernels import load_kernel
 import torch
 import torch.nn as nn

 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.kernels import load_kernel
 from text_generation_server.utils.weights import UnquantizedWeight, Weights

 if SYSTEM == "ipex":
     from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
 elif SYSTEM == "cuda":
-    moe_kernels = load_kernel("kernels-community/moe")
+    moe_kernels = load_kernel(module="moe", repo_id="kernels-community/moe")
 else:
     import moe_kernels

@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from hf_kernels import load_kernel
 import torch
 import torch.distributed

@@ -23,11 +22,12 @@ from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple, Any
 from text_generation_server.layers.attention.kv_cache import get_kv_scales
 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.kernels import load_kernel

 if SYSTEM == "ipex":
     from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
 elif SYSTEM == "cuda":
-    moe_kernels = load_kernel("kernels-community/moe")
+    moe_kernels = load_kernel(module="moe", repo_id="kernels-community/moe")
 else:
     import moe_kernels

server/text_generation_server/utils/kernels.py (new file, 22 lines)
@@ -0,0 +1,22 @@
+import importlib
+
+from loguru import logger
+from hf_kernels import load_kernel as hf_load_kernel
+
+from text_generation_server.utils.log import log_once
+
+
+def load_kernel(*, module: str, repo_id: str):
+    """
+    Load a kernel. First try to load it as the given module (e.g. for
+    local development), falling back to a locked Hub kernel.
+    """
+    try:
+        m = importlib.import_module(module)
+        log_once(logger.info, f"Using local module for `{module}`")
+        return m
+    except ModuleNotFoundError:
+        return hf_load_kernel(repo_id=repo_id)
+
+
+__all__ = ["load_kernel"]