Fixup some imports

This commit is contained in:
Daniël de Kok 2025-02-04 13:22:24 +00:00
parent a60d1e614f
commit f25a7aad89
6 changed files with 17 additions and 12 deletions

View File

@@ -6,15 +6,16 @@ import torch
from compressed_tensors.quantization import QuantizationArgs, QuantizationType from compressed_tensors.quantization import QuantizationArgs, QuantizationType
from text_generation_server.layers.fp8 import _load_scalar_or_matrix_scale from text_generation_server.layers.fp8 import _load_scalar_or_matrix_scale
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.kernels import load_kernel from text_generation_server.utils.kernels import load_kernel
from text_generation_server.utils.log import log_once from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
try: if SYSTEM == "cuda":
marlin_kernels = load_kernel( marlin_kernels = load_kernel(
module="quantization", repo_id="kernels-community/quantization" module="quantization", repo_id="kernels-community/quantization"
) )
except ImportError: else:
marlin_kernels = None marlin_kernels = None

View File

@@ -15,14 +15,15 @@ from text_generation_server.utils.weights import (
) )
from text_generation_server.utils.log import log_once from text_generation_server.utils.log import log_once
try: if SYSTEM == "cuda":
marlin_kernels = load_kernel( marlin_kernels = load_kernel(
module="quantization", repo_id="kernels-community/quantization" module="quantization", repo_id="kernels-community/quantization"
) )
except ImportError: else:
marlin_kernels = None marlin_kernels = None
try: try:
# TODO: needs to be ported over to MoE and used on CUDA.
from moe_kernels.fp8_utils import w8a8_block_fp8_matmul, per_token_group_quant_fp8 from moe_kernels.fp8_utils import w8a8_block_fp8_matmul, per_token_group_quant_fp8
except ImportError: except ImportError:
w8a8_block_fp8_matmul = None w8a8_block_fp8_matmul = None

View File

@@ -8,13 +8,14 @@ from text_generation_server.layers.marlin.util import (
_check_marlin_kernels, _check_marlin_kernels,
permute_scales, permute_scales,
) )
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.kernels import load_kernel from text_generation_server.utils.kernels import load_kernel
try: if SYSTEM == "cuda":
marlin_kernels = load_kernel( marlin_kernels = load_kernel(
module="quantization", repo_id="kernels-community/quantization" module="quantization", repo_id="kernels-community/quantization"
) )
except ImportError: else:
marlin_kernels = None marlin_kernels = None

View File

@@ -16,13 +16,14 @@ from text_generation_server.utils.kernels import load_kernel
from text_generation_server.utils.log import log_once from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
try: if SYSTEM == "cuda":
marlin_kernels = load_kernel( marlin_kernels = load_kernel(
module="quantization", repo_id="kernels-community/quantization" module="quantization", repo_id="kernels-community/quantization"
) )
except ImportError: else:
marlin_kernels = None marlin_kernels = None
try: try:
major, _minor = torch.cuda.get_device_capability() major, _minor = torch.cuda.get_device_capability()
has_sm_8_0 = major >= 8 has_sm_8_0 = major >= 8

View File

@@ -5,14 +5,15 @@ import torch
import torch.nn as nn import torch.nn as nn
from text_generation_server.layers.marlin.util import _check_marlin_kernels from text_generation_server.layers.marlin.util import _check_marlin_kernels
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.kernels import load_kernel from text_generation_server.utils.kernels import load_kernel
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
try: if SYSTEM == "cuda":
marlin_kernels = load_kernel( marlin_kernels = load_kernel(
module="quantization", repo_id="kernels-community/quantization" module="quantization", repo_id="kernels-community/quantization"
) )
except ImportError: else:
marlin_kernels = None marlin_kernels = None

View File

@@ -6,11 +6,11 @@ import torch
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.kernels import load_kernel from text_generation_server.utils.kernels import load_kernel
try: if SYSTEM == "cuda":
marlin_kernels = load_kernel( marlin_kernels = load_kernel(
module="quantization", repo_id="kernels-community/quantization" module="quantization", repo_id="kernels-community/quantization"
) )
except ImportError: else:
marlin_kernels = None marlin_kernels = None
try: try: