diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index cfe9d025..a7c8c4a7 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -83,9 +83,6 @@ FLASH_ATTENTION = True
 
 try:
     from text_generation_server.models.flash_causal_lm import FlashCausalLM
-    from text_generation_server.models.transformers_flash_causal_lm import (
-        TransformersFlashCausalLM,
-    )
     from text_generation_server.models.vlm_causal_lm import VlmCausalLM
     from text_generation_server.models.mllama_causal_lm import MllamaCausalLM
     from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import (
@@ -180,6 +177,14 @@ except ImportError as e:
 if MAMBA_AVAILABLE:
     __all__.append(Mamba)
 
+FLASH_TRANSFORMERS_BACKEND = True
+try:
+    from text_generation_server.models.transformers_flash_causal_lm import (
+        TransformersFlashCausalLM,
+    )
+except ImportError:
+    FLASH_TRANSFORMERS_BACKEND = False
+
 
 class ModelType(enum.Enum):
     DEEPSEEK_V2 = {
@@ -384,12 +389,8 @@ def get_model(
             transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]
         )
 
-        if FLASH_ATTENTION and transformers_model_class._supports_flex_attn:
+        if FLASH_TRANSFORMERS_BACKEND and transformers_model_class._supports_flex_attn:
            transformers_causal_lm_class = TransformersFlashCausalLM
-            if lora_adapter_ids is not None and len(lora_adapter_ids) > 0:
-                raise ValueError(
-                    "Flash `Transformers` modeling backend does not support `lora_adapter_ids`."
-                )
 
     quantization_config = config_dict.get("quantization_config", None)
     if quantization_config is None:
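
For context, the net effect of the diff is the optional-import pattern sketched below: the flash Transformers backend is imported once at module load, a module-level flag (`FLASH_TRANSFORMERS_BACKEND`) records whether that import succeeded, and the model-selection code checks that flag instead of `FLASH_ATTENTION`. This is only a minimal sketch; `choose_transformers_class` and `fallback_class` are hypothetical names used to illustrate how the flag is consumed, not functions from the repository.

# Illustrative sketch of the pattern introduced by this diff.
FLASH_TRANSFORMERS_BACKEND = True
try:
    from text_generation_server.models.transformers_flash_causal_lm import (
        TransformersFlashCausalLM,
    )
except ImportError:
    # A missing dependency (e.g. a transformers version without flex attention)
    # lands here; the server keeps starting and uses the non-flash fallback.
    FLASH_TRANSFORMERS_BACKEND = False


def choose_transformers_class(transformers_model_class, fallback_class):
    # Hypothetical helper mirroring the branch in get_model(): pick the flash
    # backend only when its import succeeded and the architecture reports
    # flex-attention support.
    if FLASH_TRANSFORMERS_BACKEND and getattr(
        transformers_model_class, "_supports_flex_attn", False
    ):
        return TransformersFlashCausalLM
    return fallback_class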