mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
Update __init__.py
This commit is contained in:
parent
6d9c011f51
commit
2ef3002c2b
@ -83,9 +83,6 @@ FLASH_ATTENTION = True
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
from text_generation_server.models.flash_causal_lm import FlashCausalLM
|
from text_generation_server.models.flash_causal_lm import FlashCausalLM
|
||||||
from text_generation_server.models.transformers_flash_causal_lm import (
|
|
||||||
TransformersFlashCausalLM,
|
|
||||||
)
|
|
||||||
from text_generation_server.models.vlm_causal_lm import VlmCausalLM
|
from text_generation_server.models.vlm_causal_lm import VlmCausalLM
|
||||||
from text_generation_server.models.mllama_causal_lm import MllamaCausalLM
|
from text_generation_server.models.mllama_causal_lm import MllamaCausalLM
|
||||||
from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import (
|
from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import (
|
||||||
@ -180,6 +177,14 @@ except ImportError as e:
|
|||||||
if MAMBA_AVAILABLE:
|
if MAMBA_AVAILABLE:
|
||||||
__all__.append(Mamba)
|
__all__.append(Mamba)
|
||||||
|
|
||||||
|
FLASH_TRANSFORMERS_BACKEND = True
|
||||||
|
try:
|
||||||
|
from text_generation_server.models.transformers_flash_causal_lm import (
|
||||||
|
TransformersFlashCausalLM,
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
FLASH_TRANSFORMERS_BACKEND = False
|
||||||
|
|
||||||
|
|
||||||
class ModelType(enum.Enum):
|
class ModelType(enum.Enum):
|
||||||
DEEPSEEK_V2 = {
|
DEEPSEEK_V2 = {
|
||||||
@ -384,12 +389,8 @@ def get_model(
|
|||||||
transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]
|
transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]
|
||||||
)
|
)
|
||||||
|
|
||||||
if FLASH_ATTENTION and transformers_model_class._supports_flex_attn:
|
if FLASH_TRANSFORMERS_BACKEND and transformers_model_class._supports_flex_attn:
|
||||||
transformers_causal_lm_class = TransformersFlashCausalLM
|
transformers_causal_lm_class = TransformersFlashCausalLM
|
||||||
if lora_adapter_ids is not None and len(lora_adapter_ids) > 0:
|
|
||||||
raise ValueError(
|
|
||||||
"Flash `Transformers` modeling backend does not support `lora_adapter_ids`."
|
|
||||||
)
|
|
||||||
|
|
||||||
quantization_config = config_dict.get("quantization_config", None)
|
quantization_config = config_dict.get("quantization_config", None)
|
||||||
if quantization_config is None:
|
if quantization_config is None:
|
||||||
|
Loading…
Reference in New Issue
Block a user