diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 612ad8b3..cfe9d025 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -21,9 +21,7 @@ import transformers
 from text_generation_server.utils.speculate import get_speculate, set_speculate
 from text_generation_server.models.model import Model
 from text_generation_server.models.causal_lm import CausalLM, CausalLMBatchKeysLast
-from text_generation_server.models.transformers_flash_causal_lm import (
-    TransformersFlashCausalLM,
-)
+
 from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
 from text_generation_server.models.custom_modeling.mpt_modeling import (
     MPTForCausalLM,
@@ -85,6 +83,9 @@ FLASH_ATTENTION = True
 try:
     from text_generation_server.models.flash_causal_lm import FlashCausalLM
+    from text_generation_server.models.transformers_flash_causal_lm import (
+        TransformersFlashCausalLM,
+    )
     from text_generation_server.models.vlm_causal_lm import VlmCausalLM
     from text_generation_server.models.mllama_causal_lm import MllamaCausalLM
     from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import (
@@ -382,16 +383,10 @@ def get_model(
         transformers_model_class = getattr(
             transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]
         )
-        accelerator_available = torch.cuda.is_available() or (
-            hasattr(torch, "xpu") and torch.xpu.is_available()
-        )
-        if transformers_model_class._supports_flex_attn and accelerator_available:
+
+        if FLASH_ATTENTION and transformers_model_class._supports_flex_attn:
             transformers_causal_lm_class = TransformersFlashCausalLM
-        if (
-            not FLASH_ATTENTION
-            and lora_adapter_ids is not None
-            and len(lora_adapter_ids) > 0
-        ):
+        if lora_adapter_ids is not None and len(lora_adapter_ids) > 0:
             raise ValueError(
                 "Flash `Transformers` modeling backend does not support `lora_adapter_ids`."
             )