move the import to avoid device issue

Cyril Vallez 2025-01-20 16:11:41 +01:00
parent 9af3ea4b70
commit 6d9c011f51

text_generation_server/models/__init__.py

@@ -21,9 +21,7 @@ import transformers
 from text_generation_server.utils.speculate import get_speculate, set_speculate
 from text_generation_server.models.model import Model
 from text_generation_server.models.causal_lm import CausalLM, CausalLMBatchKeysLast
-from text_generation_server.models.transformers_flash_causal_lm import (
-    TransformersFlashCausalLM,
-)
 from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
 from text_generation_server.models.custom_modeling.mpt_modeling import (
     MPTForCausalLM,
@@ -85,6 +83,9 @@ FLASH_ATTENTION = True
 try:
     from text_generation_server.models.flash_causal_lm import FlashCausalLM
+    from text_generation_server.models.transformers_flash_causal_lm import (
+        TransformersFlashCausalLM,
+    )
     from text_generation_server.models.vlm_causal_lm import VlmCausalLM
     from text_generation_server.models.mllama_causal_lm import MllamaCausalLM
     from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import (
@@ -382,16 +383,10 @@ def get_model(
         transformers_model_class = getattr(
             transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]
         )
-        accelerator_available = torch.cuda.is_available() or (
-            hasattr(torch, "xpu") and torch.xpu.is_available()
-        )
-        if transformers_model_class._supports_flex_attn and accelerator_available:
+        if FLASH_ATTENTION and transformers_model_class._supports_flex_attn:
             transformers_causal_lm_class = TransformersFlashCausalLM
-            if (
-                not FLASH_ATTENTION
-                and lora_adapter_ids is not None
-                and len(lora_adapter_ids) > 0
-            ):
+            if lora_adapter_ids is not None and len(lora_adapter_ids) > 0:
                 raise ValueError(
                     "Flash `Transformers` modeling backend does not support `lora_adapter_ids`."
                 )
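
Why moving the import helps (an inference from the commit title, not stated in the diff itself): importing transformers_flash_causal_lm at module level apparently touches the accelerator, so it can fail or misbehave on hosts without CUDA/XPU. Placing it inside the existing try/except that already guards the other flash imports and sets FLASH_ATTENTION defers that cost, and get_model can then key off the flag instead of probing torch.cuda/torch.xpu itself. Below is a minimal, runnable sketch of that guarded-import pattern; the except body is simplified and not copied from the repository.

# Hedged sketch of the guarded-import pattern the diff relies on.
FLASH_ATTENTION = True
try:
    # This import may require a CUDA/XPU device (flash/paged attention
    # kernels) at import time, which is why it now lives inside the guard.
    from text_generation_server.models.transformers_flash_causal_lm import (
        TransformersFlashCausalLM,
    )
except ImportError:
    # Simplified fallback (assumption): disable the flash path entirely.
    TransformersFlashCausalLM = None
    FLASH_ATTENTION = False

With the flag and the class produced by the same guard, checking `FLASH_ATTENTION and transformers_model_class._supports_flex_attn` in get_model is sufficient, which is presumably why the separate torch.cuda/torch.xpu availability probe could be dropped.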