Update __init__.py

Cyril Vallez 2025-01-20 16:37:41 +01:00
parent 6d9c011f51
commit 2ef3002c2b


@@ -83,9 +83,6 @@ FLASH_ATTENTION = True
 try:
     from text_generation_server.models.flash_causal_lm import FlashCausalLM
-    from text_generation_server.models.transformers_flash_causal_lm import (
-        TransformersFlashCausalLM,
-    )
     from text_generation_server.models.vlm_causal_lm import VlmCausalLM
     from text_generation_server.models.mllama_causal_lm import MllamaCausalLM
     from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import (
@@ -180,6 +177,14 @@ except ImportError as e:
 if MAMBA_AVAILABLE:
     __all__.append(Mamba)
 
+FLASH_TRANSFORMERS_BACKEND = True
+try:
+    from text_generation_server.models.transformers_flash_causal_lm import (
+        TransformersFlashCausalLM,
+    )
+except ImportError:
+    FLASH_TRANSFORMERS_BACKEND = False
+
 
 class ModelType(enum.Enum):
     DEEPSEEK_V2 = {
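
For context: the block added above reintroduces the import that the first hunk removed from the large try/except guarding FLASH_ATTENTION, so an ImportError now disables only the Transformers flash backend instead of every flash model. It is the standard guarded-import feature flag; a minimal self-contained sketch of the same pattern follows (the fast_backend module and the class names are made up for illustration):

# Guarded-import feature flag: probe an optional dependency once at import
# time and expose the result as a module-level boolean.
FAST_BACKEND_AVAILABLE = True
try:
    from fast_backend import FastModel  # hypothetical optional dependency
except ImportError:
    # The module (or one of its own imports) is missing; disable the
    # feature instead of crashing at startup.
    FAST_BACKEND_AVAILABLE = False


class SlowModel:
    """Always-available fallback implementation."""


def pick_model_class():
    # Callers consult the flag rather than retrying the import; FastModel
    # is only referenced when the import above succeeded.
    return FastModel if FAST_BACKEND_AVAILABLE else SlowModel
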
@@ -384,12 +389,8 @@ def get_model(
             transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]
         )
-        if FLASH_ATTENTION and transformers_model_class._supports_flex_attn:
+        if FLASH_TRANSFORMERS_BACKEND and transformers_model_class._supports_flex_attn:
             transformers_causal_lm_class = TransformersFlashCausalLM
-            if lora_adapter_ids is not None and len(lora_adapter_ids) > 0:
-                raise ValueError(
-                    "Flash `Transformers` modeling backend does not support `lora_adapter_ids`."
-                )
 
     quantization_config = config_dict.get("quantization_config", None)
     if quantization_config is None:
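
The final hunk swaps the gate from the global FLASH_ATTENTION flag to the backend-specific FLASH_TRANSFORMERS_BACKEND flag introduced above, and drops the lora_adapter_ids restriction. A self-contained sketch of the resulting selection logic, with stub classes standing in for TGI's types and a select_backend wrapper that is illustrative rather than TGI's actual API:

class CausalLM:
    """Stub for TGI's default transformers-based backend class."""

class TransformersFlashCausalLM:
    """Stub for the flash-attention Transformers backend class."""

class DummyHFModel:
    """Stub for an auto-resolved transformers model class; the diff
    checks its _supports_flex_attn attribute."""
    _supports_flex_attn = True

FLASH_TRANSFORMERS_BACKEND = True  # set by the guarded import shown above

def select_backend(hf_model_class):
    # Gate on the backend-specific availability flag rather than the
    # server-wide FLASH_ATTENTION flag; the lora_adapter_ids check is gone.
    if FLASH_TRANSFORMERS_BACKEND and getattr(hf_model_class, "_supports_flex_attn", False):
        return TransformersFlashCausalLM
    return CausalLM

assert select_backend(DummyHFModel) is TransformersFlashCausalLM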