From 2ef3002c2b1ec9e3c3347fd3ae61c87db00b6266 Mon Sep 17 00:00:00 2001
From: Cyril Vallez
Date: Mon, 20 Jan 2025 16:37:41 +0100
Subject: [PATCH] Update __init__.py

---
 .../text_generation_server/models/__init__.py | 17 +++++++++--------
 1 file changed, 9 insertions(+), 8 deletions(-)

diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index cfe9d025..a7c8c4a7 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -83,9 +83,6 @@ FLASH_ATTENTION = True
 
 try:
     from text_generation_server.models.flash_causal_lm import FlashCausalLM
-    from text_generation_server.models.transformers_flash_causal_lm import (
-        TransformersFlashCausalLM,
-    )
     from text_generation_server.models.vlm_causal_lm import VlmCausalLM
     from text_generation_server.models.mllama_causal_lm import MllamaCausalLM
     from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import (
@@ -180,6 +177,14 @@ except ImportError as e:
 if MAMBA_AVAILABLE:
     __all__.append(Mamba)
 
+FLASH_TRANSFORMERS_BACKEND = True
+try:
+    from text_generation_server.models.transformers_flash_causal_lm import (
+        TransformersFlashCausalLM,
+    )
+except ImportError:
+    FLASH_TRANSFORMERS_BACKEND = False
+
 
 class ModelType(enum.Enum):
     DEEPSEEK_V2 = {
@@ -384,12 +389,8 @@ def get_model(
             transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]
         )
 
-        if FLASH_ATTENTION and transformers_model_class._supports_flex_attn:
+        if FLASH_TRANSFORMERS_BACKEND and transformers_model_class._supports_flex_attn:
             transformers_causal_lm_class = TransformersFlashCausalLM
-            if lora_adapter_ids is not None and len(lora_adapter_ids) > 0:
-                raise ValueError(
-                    "Flash `Transformers` modeling backend does not support `lora_adapter_ids`."
-                )
 
     quantization_config = config_dict.get("quantization_config", None)
     if quantization_config is None: