Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 12:24:53 +00:00)
move the import to avoid device issue

parent 9af3ea4b70
commit 6d9c011f51
@@ -21,9 +21,7 @@ import transformers
 from text_generation_server.utils.speculate import get_speculate, set_speculate
 from text_generation_server.models.model import Model
 from text_generation_server.models.causal_lm import CausalLM, CausalLMBatchKeysLast
-from text_generation_server.models.transformers_flash_causal_lm import (
-    TransformersFlashCausalLM,
-)
 from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
 from text_generation_server.models.custom_modeling.mpt_modeling import (
     MPTForCausalLM,
@@ -85,6 +83,9 @@ FLASH_ATTENTION = True

 try:
     from text_generation_server.models.flash_causal_lm import FlashCausalLM
+    from text_generation_server.models.transformers_flash_causal_lm import (
+        TransformersFlashCausalLM,
+    )
     from text_generation_server.models.vlm_causal_lm import VlmCausalLM
     from text_generation_server.models.mllama_causal_lm import MllamaCausalLM
     from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import (
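For readers outside the repo, a minimal, self-contained sketch of the guarded-import pattern this hunk applies. Per the commit message, importing these classes at module level caused a device issue, so the import now lives inside the existing try block; the except branch below is an assumption about how FLASH_ATTENTION is downgraded elsewhere in the module and is not part of the hunk itself.

# Sketch only: the import of the flash-attention-backed classes is deferred
# into the try block instead of running unconditionally at module import time.
FLASH_ATTENTION = True
try:
    from text_generation_server.models.flash_causal_lm import FlashCausalLM
    from text_generation_server.models.transformers_flash_causal_lm import (
        TransformersFlashCausalLM,
    )
except ImportError:
    # Assumed fallback (not shown in this hunk): if the flash stack cannot be
    # imported on this host, clear the flag and never reference the classes.
    FLASH_ATTENTION = False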
@@ -382,16 +383,10 @@ def get_model(
     transformers_model_class = getattr(
         transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]
     )
-    accelerator_available = torch.cuda.is_available() or (
-        hasattr(torch, "xpu") and torch.xpu.is_available()
-    )
-    if transformers_model_class._supports_flex_attn and accelerator_available:
+    if FLASH_ATTENTION and transformers_model_class._supports_flex_attn:
         transformers_causal_lm_class = TransformersFlashCausalLM
-        if (
-            not FLASH_ATTENTION
-            and lora_adapter_ids is not None
-            and len(lora_adapter_ids) > 0
-        ):
+        if lora_adapter_ids is not None and len(lora_adapter_ids) > 0:
             raise ValueError(
                 "Flash `Transformers` modeling backend does not support `lora_adapter_ids`."
             )
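A standalone sketch of the selection logic this hunk leaves behind. The function name and boolean return are illustrative (the real code assigns transformers_causal_lm_class inside get_model), and flash_attention stands in for the module-level FLASH_ATTENTION flag, which is assumed to reflect whether the guarded flash imports above succeeded.

from typing import Optional, Sequence


def should_use_flash_transformers(
    model_class,
    flash_attention: bool,
    lora_adapter_ids: Optional[Sequence[str]] = None,
) -> bool:
    # The explicit torch.cuda / torch.xpu probe removed by this hunk is gone;
    # the flash_attention flag takes its place.
    if flash_attention and getattr(model_class, "_supports_flex_attn", False):
        # Same guard as in the diff: the flash `Transformers` backend rejects
        # LoRA adapters outright.
        if lora_adapter_ids is not None and len(lora_adapter_ids) > 0:
            raise ValueError(
                "Flash `Transformers` modeling backend does not support `lora_adapter_ids`."
            )
        return True
    return False


# Example: a class advertising flex-attention support, no adapters requested.
class _Demo:
    _supports_flex_attn = True


assert should_use_flash_transformers(_Demo, flash_attention=True)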