mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
check for non-native models
This commit is contained in:
parent
2ef3002c2b
commit
70ada578b9
@ -16,7 +16,6 @@ from transformers.models.auto import modeling_auto
|
|||||||
from huggingface_hub import hf_hub_download, HfApi
|
from huggingface_hub import hf_hub_download, HfApi
|
||||||
from typing import Optional, List, Dict
|
from typing import Optional, List, Dict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import transformers
|
|
||||||
|
|
||||||
from text_generation_server.utils.speculate import get_speculate, set_speculate
|
from text_generation_server.utils.speculate import get_speculate, set_speculate
|
||||||
from text_generation_server.models.model import Model
|
from text_generation_server.models.model import Model
|
||||||
@ -385,11 +384,14 @@ def get_model(
|
|||||||
transformers_causal_lm_class = CausalLM
|
transformers_causal_lm_class = CausalLM
|
||||||
|
|
||||||
# Fast transformers path
|
# Fast transformers path
|
||||||
transformers_model_class = getattr(
|
transformers_model_class = modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.get(
|
||||||
transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]
|
model_type, None
|
||||||
)
|
)
|
||||||
|
if (
|
||||||
if FLASH_TRANSFORMERS_BACKEND and transformers_model_class._supports_flex_attn:
|
FLASH_TRANSFORMERS_BACKEND
|
||||||
|
and transformers_model_class is not None
|
||||||
|
and transformers_model_class._supports_flex_attn
|
||||||
|
):
|
||||||
transformers_causal_lm_class = TransformersFlashCausalLM
|
transformers_causal_lm_class = TransformersFlashCausalLM
|
||||||
|
|
||||||
quantization_config = config_dict.get("quantization_config", None)
|
quantization_config = config_dict.get("quantization_config", None)
|
||||||
|
Loading…
Reference in New Issue
Block a user