mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
Fixup some imports
This commit is contained in:
parent
a60d1e614f
commit
f25a7aad89
@ -6,15 +6,16 @@ import torch
|
||||
from compressed_tensors.quantization import QuantizationArgs, QuantizationType
|
||||
|
||||
from text_generation_server.layers.fp8 import _load_scalar_or_matrix_scale
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.utils.kernels import load_kernel
|
||||
from text_generation_server.utils.log import log_once
|
||||
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
|
||||
|
||||
try:
|
||||
if SYSTEM == "cuda":
|
||||
marlin_kernels = load_kernel(
|
||||
module="quantization", repo_id="kernels-community/quantization"
|
||||
)
|
||||
except ImportError:
|
||||
else:
|
||||
marlin_kernels = None
|
||||
|
||||
|
||||
|
@ -15,14 +15,15 @@ from text_generation_server.utils.weights import (
|
||||
)
|
||||
from text_generation_server.utils.log import log_once
|
||||
|
||||
try:
|
||||
if SYSTEM == "cuda":
|
||||
marlin_kernels = load_kernel(
|
||||
module="quantization", repo_id="kernels-community/quantization"
|
||||
)
|
||||
except ImportError:
|
||||
else:
|
||||
marlin_kernels = None
|
||||
|
||||
try:
|
||||
# TODO: needs to be ported over to MoE and used on CUDA.
|
||||
from moe_kernels.fp8_utils import w8a8_block_fp8_matmul, per_token_group_quant_fp8
|
||||
except ImportError:
|
||||
w8a8_block_fp8_matmul = None
|
||||
|
@ -8,13 +8,14 @@ from text_generation_server.layers.marlin.util import (
|
||||
_check_marlin_kernels,
|
||||
permute_scales,
|
||||
)
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.utils.kernels import load_kernel
|
||||
|
||||
try:
|
||||
if SYSTEM == "cuda":
|
||||
marlin_kernels = load_kernel(
|
||||
module="quantization", repo_id="kernels-community/quantization"
|
||||
)
|
||||
except ImportError:
|
||||
else:
|
||||
marlin_kernels = None
|
||||
|
||||
|
||||
|
@ -16,13 +16,14 @@ from text_generation_server.utils.kernels import load_kernel
|
||||
from text_generation_server.utils.log import log_once
|
||||
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
|
||||
|
||||
try:
|
||||
if SYSTEM == "cuda":
|
||||
marlin_kernels = load_kernel(
|
||||
module="quantization", repo_id="kernels-community/quantization"
|
||||
)
|
||||
except ImportError:
|
||||
else:
|
||||
marlin_kernels = None
|
||||
|
||||
|
||||
try:
|
||||
major, _minor = torch.cuda.get_device_capability()
|
||||
has_sm_8_0 = major >= 8
|
||||
|
@ -5,14 +5,15 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from text_generation_server.layers.marlin.util import _check_marlin_kernels
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.utils.kernels import load_kernel
|
||||
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
|
||||
|
||||
try:
|
||||
if SYSTEM == "cuda":
|
||||
marlin_kernels = load_kernel(
|
||||
module="quantization", repo_id="kernels-community/quantization"
|
||||
)
|
||||
except ImportError:
|
||||
else:
|
||||
marlin_kernels = None
|
||||
|
||||
|
||||
|
@ -6,11 +6,11 @@ import torch
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.utils.kernels import load_kernel
|
||||
|
||||
try:
|
||||
if SYSTEM == "cuda":
|
||||
marlin_kernels = load_kernel(
|
||||
module="quantization", repo_id="kernels-community/quantization"
|
||||
)
|
||||
except ImportError:
|
||||
else:
|
||||
marlin_kernels = None
|
||||
|
||||
try:
|
||||
|
Loading…
Reference in New Issue
Block a user