mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
move the import to avoid device issue
This commit is contained in:
parent
9af3ea4b70
commit
6d9c011f51
@ -21,9 +21,7 @@ import transformers
|
||||
from text_generation_server.utils.speculate import get_speculate, set_speculate
|
||||
from text_generation_server.models.model import Model
|
||||
from text_generation_server.models.causal_lm import CausalLM, CausalLMBatchKeysLast
|
||||
from text_generation_server.models.transformers_flash_causal_lm import (
|
||||
TransformersFlashCausalLM,
|
||||
)
|
||||
|
||||
from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
|
||||
from text_generation_server.models.custom_modeling.mpt_modeling import (
|
||||
MPTForCausalLM,
|
||||
@ -85,6 +83,9 @@ FLASH_ATTENTION = True
|
||||
|
||||
try:
|
||||
from text_generation_server.models.flash_causal_lm import FlashCausalLM
|
||||
from text_generation_server.models.transformers_flash_causal_lm import (
|
||||
TransformersFlashCausalLM,
|
||||
)
|
||||
from text_generation_server.models.vlm_causal_lm import VlmCausalLM
|
||||
from text_generation_server.models.mllama_causal_lm import MllamaCausalLM
|
||||
from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import (
|
||||
@ -382,16 +383,10 @@ def get_model(
|
||||
transformers_model_class = getattr(
|
||||
transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]
|
||||
)
|
||||
accelerator_available = torch.cuda.is_available() or (
|
||||
hasattr(torch, "xpu") and torch.xpu.is_available()
|
||||
)
|
||||
if transformers_model_class._supports_flex_attn and accelerator_available:
|
||||
|
||||
if FLASH_ATTENTION and transformers_model_class._supports_flex_attn:
|
||||
transformers_causal_lm_class = TransformersFlashCausalLM
|
||||
if (
|
||||
not FLASH_ATTENTION
|
||||
and lora_adapter_ids is not None
|
||||
and len(lora_adapter_ids) > 0
|
||||
):
|
||||
if lora_adapter_ids is not None and len(lora_adapter_ids) > 0:
|
||||
raise ValueError(
|
||||
"Flash `Transformers` modeling backend does not support `lora_adapter_ids`."
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user