diff --git a/server/text_generation_server/layers/moe/fp8.py b/server/text_generation_server/layers/moe/fp8.py
index 7ccddb5b..4d0295f6 100644
--- a/server/text_generation_server/layers/moe/fp8.py
+++ b/server/text_generation_server/layers/moe/fp8.py
@@ -10,7 +10,11 @@ from text_generation_server.layers.fp8 import (
     quant_dtype,
     normalize_e4m3fn_to_native_float8,
 )
-from moe_kernels.fused_moe import fused_moe
+
+try:
+    from moe_kernels.fused_moe import fused_moe
+except Exception:
+    fused_moe = None
 
 
 class FP8SparseMoELayer(nn.Module):
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index eb5a8de7..2d735227 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -180,7 +180,7 @@ except ImportError as e:
 if MAMBA_AVAILABLE:
     __all__.append(Mamba)
 
-FLASH_TRANSFORMERS_BACKEND = True
+FLASH_TRANSFORMERS_BACKEND = torch.cuda.is_available()
 try:
     from text_generation_server.models.transformers_flash_causal_lm import (
         TransformersFlashCausalLM,
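For reference, a minimal sketch of the two guards this patch introduces, shown outside the diff: an optional-import fallback for `moe_kernels.fused_moe` and a capability-gated `FLASH_TRANSFORMERS_BACKEND` flag. The `run_fused_moe` wrapper and its error message below are hypothetical illustrations, not code from the repository.

```python
import torch

# Guard 1: treat moe_kernels as an optional dependency. If the package is
# missing (e.g. a CPU-only install), keep this module importable and leave a
# sentinel behind instead of failing at import time.
try:
    from moe_kernels.fused_moe import fused_moe
except Exception:
    fused_moe = None

# Guard 2: only enable the flash Transformers backend when CUDA is present,
# so CPU-only hosts skip that code path entirely.
FLASH_TRANSFORMERS_BACKEND = torch.cuda.is_available()


def run_fused_moe(*args, **kwargs):
    # Hypothetical call site: fail with a clear message if the optional
    # kernel package was not installed, rather than crashing later.
    if fused_moe is None:
        raise RuntimeError(
            "moe_kernels is not installed; the fused MoE kernel is unavailable."
        )
    return fused_moe(*args, **kwargs)
```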