Mirror of https://github.com/huggingface/text-generation-inference.git
Update __init__.py
commit 2ef3002c2b
parent 6d9c011f51
@@ -83,9 +83,6 @@ FLASH_ATTENTION = True
 try:
     from text_generation_server.models.flash_causal_lm import FlashCausalLM
-    from text_generation_server.models.transformers_flash_causal_lm import (
-        TransformersFlashCausalLM,
-    )
     from text_generation_server.models.vlm_causal_lm import VlmCausalLM
     from text_generation_server.models.mllama_causal_lm import MllamaCausalLM
     from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import (
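The three-line import removed here is re-added in the next hunk under its own flag, so the Flash Transformers backend presumably no longer hinges on the flash-attention import group succeeding. A minimal sketch of the guarded-import pattern this module uses, inferred from the hunk contexts (`FLASH_ATTENTION = True` and `except ImportError as e:`); the body of the except branch is an assumption:

    FLASH_ATTENTION = True
    try:
        # One failing import in this block disables the whole flash model family.
        from text_generation_server.models.flash_causal_lm import FlashCausalLM
    except ImportError as e:  # `as e` mirrors the context line in the second hunk header
        # Assumed handler shape: record the failure and disable the flash path.
        FLASH_ATTENTION = False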
@@ -180,6 +177,14 @@ except ImportError as e:
 if MAMBA_AVAILABLE:
     __all__.append(Mamba)
 
+FLASH_TRANSFORMERS_BACKEND = True
+try:
+    from text_generation_server.models.transformers_flash_causal_lm import (
+        TransformersFlashCausalLM,
+    )
+except ImportError:
+    FLASH_TRANSFORMERS_BACKEND = False
+
 
 class ModelType(enum.Enum):
     DEEPSEEK_V2 = {
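The block added here gives the Transformers flash backend an availability flag of its own, decoupled from FLASH_ATTENTION. A self-contained sketch of the same optional-import pattern, using a hypothetical module name so it runs whether or not the dependency is installed:

    AVAILABLE = True
    try:
        # Probing the optional dependency at import time turns a hard
        # requirement into a capability flag.
        from some_optional_backend import OptionalModel  # hypothetical module
    except ImportError:
        OptionalModel = None
        AVAILABLE = False

    # Downstream code branches on the flag instead of re-attempting the import.
    model_class = OptionalModel if AVAILABLE else None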
@@ -384,12 +389,8 @@ def get_model(
         transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]
     )
 
-    if FLASH_ATTENTION and transformers_model_class._supports_flex_attn:
+    if FLASH_TRANSFORMERS_BACKEND and transformers_model_class._supports_flex_attn:
         transformers_causal_lm_class = TransformersFlashCausalLM
-        if lora_adapter_ids is not None and len(lora_adapter_ids) > 0:
-            raise ValueError(
-                "Flash `Transformers` modeling backend does not support `lora_adapter_ids`."
-            )
 
     quantization_config = config_dict.get("quantization_config", None)
     if quantization_config is None:
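With the new flag, get_model keys the backend choice on FLASH_TRANSFORMERS_BACKEND together with the model class's `_supports_flex_attn` attribute instead of on FLASH_ATTENTION. A rough sketch of that selection step; the getattr call is reconstructed from the context lines, and `TransformersCausalLM` as the default class is an assumption, since the hunk only shows the flash branch:

    # Resolve the transformers auto-mapped class for this architecture
    # (reconstructed from the `transformers, modeling_auto...` context lines).
    transformers_model_class = getattr(
        transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]
    )

    # Assumed default backend class.
    transformers_causal_lm_class = TransformersCausalLM

    # Switch to the flash backend only when its import succeeded and the
    # architecture supports flex attention.
    if FLASH_TRANSFORMERS_BACKEND and transformers_model_class._supports_flex_attn:
        transformers_causal_lm_class = TransformersFlashCausalLM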