commit e93ab925f9
parent 715b2d19ed
Author: Cyril Vallez
Date:   2024-12-13 14:02:45 +00:00

@@ -12,7 +12,7 @@ import os
 from loguru import logger
 from transformers.configuration_utils import PretrainedConfig
-from transformers.models.auto import modeling_auto
+from transformers.models.auto import modeling_auto, modeling_task
 from huggingface_hub import hf_hub_download, HfApi
 from typing import Optional, List, Dict
 from pathlib import Path
@@ -375,30 +375,26 @@ def get_model(
     )
     model_type = config_dict.get("model_type", None)

-    # transformers_causal_lm_class = CausalLM
-    transformers_causal_lm_class = TransformersFlashCausalLM
-    if (
-        not USE_CUSTOM_MODELING
-        and model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
-    ):
+    transformers_causal_lm_class = CausalLM
+    if not USE_CUSTOM_MODELING:
         logger.info(
             "TGI's flash enabled models could either not be loaded or are disabled, using Transformers fallback."
         )
-        transformers_model_class = getattr(
-            transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]
-        )
+        try:
+            transformers_model_class = getattr(transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type])
+        except KeyError:
+            transformers_model_class = modeling_task.AutoForCausalLM

-        if (
-            transformers_model_class._supports_flash_attn_2
-            and transformers_model_class._supports_cache_class
-        ):
+        if transformers_model_class._supports_flash_attn_2:
             logger.info(
-                f"Transformers' {model_type} implementation supports custom cache and flash/paged attention. Using TransformersFlashCausalLM with ragged tensors (single dimension for batch and sequence length)."
+                f"Transformers' {model_type} implementation supports ragged tensors format (single dimension for "
+                "batch and sequence length). All TGI's batching/caching optimizations are enabled."
             )
             transformers_causal_lm_class = TransformersFlashCausalLM
         else:
             logger.info(
-                f"Transformers' {model_type} implementation supports custom cache and flash/paged attention. Using TransformersCausalLM with classic tensors with padding (two dimensions for batch size and sequence length)."
+                f"Transformers' {model_type} implementation does not support ragged tensors format. Will use classic "
+                "format with padding (two dimensions for batch size and sequence length). This is expected to be slow."
             )

     quantization_config = config_dict.get("quantization_config", None)
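
For context, the decision the new hunk encodes can be exercised in isolation. The sketch below mirrors that selection logic under stated assumptions: the returned strings stand in for TGI's CausalLM and TransformersFlashCausalLM classes, pick_transformers_fallback is a hypothetical helper that is not part of TGI, and the modeling_task.AutoForCausalLM branch is omitted because that module only exists in matching Transformers versions.

# Minimal, standalone sketch of the fallback selection above; illustrative only,
# not TGI's actual code path.
import transformers
from transformers.models.auto import modeling_auto


def pick_transformers_fallback(model_type: str, use_custom_modeling: bool) -> str:
    """Return which fallback wrapper the logic above would pick for a model type."""
    if use_custom_modeling:
        # TGI's custom models are in use; the Transformers fallback class
        # stays at its default (CausalLM) and is not exercised.
        return "CausalLM"

    # Resolve the concrete Transformers class for this architecture,
    # e.g. "llama" -> transformers.LlamaForCausalLM.
    class_name = modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.get(model_type)
    model_class = getattr(transformers, class_name, None) if class_name else None

    # The ragged-tensor path (single dimension for batch and sequence length)
    # is only usable when the implementation advertises flash-attention-2 support.
    if model_class is not None and getattr(model_class, "_supports_flash_attn_2", False):
        return "TransformersFlashCausalLM"
    return "CausalLM"


# Example: recent Llama implementations advertise flash-attention-2 support,
# so this would typically print "TransformersFlashCausalLM".
print(pick_transformers_fallback("llama", use_custom_modeling=False))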