mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00

Fixup some imports

This commit is contained in:
parent a60d1e614f
commit f25a7aad89
@@ -6,15 +6,16 @@ import torch
 from compressed_tensors.quantization import QuantizationArgs, QuantizationType
 
 from text_generation_server.layers.fp8 import _load_scalar_or_matrix_scale
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.kernels import load_kernel
 from text_generation_server.utils.log import log_once
 from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
 
-try:
+if SYSTEM == "cuda":
     marlin_kernels = load_kernel(
         module="quantization", repo_id="kernels-community/quantization"
     )
-except ImportError:
+else:
     marlin_kernels = None
 
 
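Note on the recurring pattern in this commit: each file previously treated a failed kernel import as the signal that Marlin was unavailable; the availability decision is now made explicitly from SYSTEM, so on CUDA systems a missing kernel repo should surface as an error instead of being silently swallowed. A minimal sketch of the new shape, assuming SYSTEM and load_kernel behave as in text_generation_server.utils; the require_marlin helper is hypothetical, added only to show the caller side:

from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.kernels import load_kernel

if SYSTEM == "cuda":
    # On CUDA, kernel-loading failures now propagate instead of
    # silently falling back to None.
    marlin_kernels = load_kernel(
        module="quantization", repo_id="kernels-community/quantization"
    )
else:
    # On non-CUDA systems, None keeps the module importable.
    marlin_kernels = None


def require_marlin():
    # Hypothetical caller-side guard, for illustration only.
    if marlin_kernels is None:
        raise NotImplementedError("Marlin kernels require a CUDA system")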
@@ -15,14 +15,15 @@ from text_generation_server.utils.weights import (
 )
 from text_generation_server.utils.log import log_once
 
-try:
+if SYSTEM == "cuda":
     marlin_kernels = load_kernel(
         module="quantization", repo_id="kernels-community/quantization"
     )
-except ImportError:
+else:
     marlin_kernels = None
 
 try:
+    # TODO: needs to be ported over to MoE and used on CUDA.
     from moe_kernels.fp8_utils import w8a8_block_fp8_matmul, per_token_group_quant_fp8
 except ImportError:
     w8a8_block_fp8_matmul = None
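The second guard in this hunk stays exception-based: moe_kernels remains a plain optional dependency, and the new TODO marks it for migration to the same load_kernel path. The sentinel assigned in the except branch doubles as a feature flag; a hedged sketch of how such a check typically looks on the caller side (the helper name is hypothetical, not TGI's API):

def supports_blockwise_fp8() -> bool:
    # Hypothetical helper: w8a8_block_fp8_matmul is None whenever the
    # optional moe_kernels package could not be imported.
    return w8a8_block_fp8_matmul is not None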
@@ -8,13 +8,14 @@ from text_generation_server.layers.marlin.util import (
     _check_marlin_kernels,
     permute_scales,
 )
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.kernels import load_kernel
 
-try:
+if SYSTEM == "cuda":
     marlin_kernels = load_kernel(
         module="quantization", repo_id="kernels-community/quantization"
     )
-except ImportError:
+else:
     marlin_kernels = None
 
 
@@ -16,13 +16,14 @@ from text_generation_server.utils.kernels import load_kernel
 from text_generation_server.utils.log import log_once
 from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
 
-try:
+if SYSTEM == "cuda":
     marlin_kernels = load_kernel(
         module="quantization", repo_id="kernels-community/quantization"
     )
-except ImportError:
+else:
     marlin_kernels = None
 
+
 try:
     major, _minor = torch.cuda.get_device_capability()
     has_sm_8_0 = major >= 8
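The compute-capability probe below the kernel guard keeps its try/except, since torch.cuda.get_device_capability raises when no CUDA device is available. A self-contained sketch of the same check; the fallback value in the except branch is an assumption, as the hunk cuts off before it:

import torch

try:
    # SM 8.0+ (Ampere or newer) is what has_sm_8_0 gates on.
    major, _minor = torch.cuda.get_device_capability()
    has_sm_8_0 = major >= 8
except Exception:
    # Assumed fallback: treat missing or undetectable GPUs as unsupported.
    has_sm_8_0 = False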
@@ -5,14 +5,15 @@ import torch
 import torch.nn as nn
 
 from text_generation_server.layers.marlin.util import _check_marlin_kernels
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.kernels import load_kernel
 from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
 
-try:
+if SYSTEM == "cuda":
     marlin_kernels = load_kernel(
         module="quantization", repo_id="kernels-community/quantization"
     )
-except ImportError:
+else:
     marlin_kernels = None
 
 
@@ -6,11 +6,11 @@ import torch
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.kernels import load_kernel
 
-try:
+if SYSTEM == "cuda":
     marlin_kernels = load_kernel(
         module="quantization", repo_id="kernels-community/quantization"
     )
-except ImportError:
+else:
     marlin_kernels = None
 
 try: