From e93ab925f921c2119702ce83c5ab8d3420f81a83 Mon Sep 17 00:00:00 2001
From: Cyril Vallez
Date: Fri, 13 Dec 2024 14:02:45 +0000
Subject: [PATCH] init

---
 .../text_generation_server/models/__init__.py | 28 ++++++++-----------
 1 file changed, 12 insertions(+), 16 deletions(-)

diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 54665083..bf481e29 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -12,7 +12,7 @@ import os
 
 from loguru import logger
 from transformers.configuration_utils import PretrainedConfig
-from transformers.models.auto import modeling_auto
+from transformers.models.auto import modeling_auto, modeling_task
 from huggingface_hub import hf_hub_download, HfApi
 from typing import Optional, List, Dict
 from pathlib import Path
@@ -375,30 +375,26 @@ def get_model(
     )
     model_type = config_dict.get("model_type", None)
 
-    # transformers_causal_lm_class = CausalLM
-    transformers_causal_lm_class = TransformersFlashCausalLM
-    if (
-        not USE_CUSTOM_MODELING
-        and model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
-    ):
+    transformers_causal_lm_class = CausalLM
+    if not USE_CUSTOM_MODELING:
         logger.info(
             "TGI's flash enabled models could either not be loaded or are disabled, using Transformers fallback."
         )
-        transformers_model_class = getattr(
-            transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]
-        )
+        try:
+            transformers_model_class = getattr(transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type])
+        except KeyError:
+            transformers_model_class = modeling_task.AutoForCausalLM
 
-        if (
-            transformers_model_class._supports_flash_attn_2
-            and transformers_model_class._supports_cache_class
-        ):
+        if transformers_model_class._supports_flash_attn_2:
             logger.info(
-                f"Transformers' {model_type} implementation supports custom cache and flash/paged attention. Using TransformersFlashCausalLM with ragged tensors (single dimension for batch and sequence length)."
+                f"Transformers' {model_type} implementation supports ragged tensors format (single dimension for "
+                "batch and sequence length). All TGI's batching/caching optimizations are enabled."
             )
             transformers_causal_lm_class = TransformersFlashCausalLM
         else:
             logger.info(
-                f"Transformers' {model_type} implementation supports custom cache and flash/paged attention. Using TransformersCausalLM with classic tensors with padding (two dimensions for batch size and sequence length)."
+                f"Transformers' {model_type} implementation does not support ragged tensors format. Will use classic "
+                "format with padding (two dimensions for batch size and sequence length). This is expected to be slow."
             )
 
     quantization_config = config_dict.get("quantization_config", None)
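
Note (not part of the patch): for review convenience, this is how the model-class selection block in get_model() reads after the patch, i.e. the second hunk's context and "+" lines stitched back together, with a couple of explanatory comments added. Every name (USE_CUSTOM_MODELING, CausalLM, TransformersFlashCausalLM, modeling_auto, modeling_task, model_type) comes from the diff and the surrounding TGI file; modeling_task.AutoForCausalLM only exists on the transformers branch this patch targets, so treat this as a reading aid rather than standalone runnable code.

    transformers_causal_lm_class = CausalLM
    if not USE_CUSTOM_MODELING:
        logger.info(
            "TGI's flash enabled models could either not be loaded or are disabled, using Transformers fallback."
        )
        # Look the architecture up in Transformers' causal-LM auto-mapping; model types
        # missing from the mapping fall back to the generic modeling_task.AutoForCausalLM.
        try:
            transformers_model_class = getattr(
                transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]
            )
        except KeyError:
            transformers_model_class = modeling_task.AutoForCausalLM

        # Flash-attention-2 support gates the ragged tensor layout (single dimension for
        # batch and sequence length) and hence TGI's batching/caching optimizations;
        # without it the padded CausalLM wrapper is kept.
        if transformers_model_class._supports_flash_attn_2:
            logger.info(
                f"Transformers' {model_type} implementation supports ragged tensors format (single dimension for "
                "batch and sequence length). All TGI's batching/caching optimizations are enabled."
            )
            transformers_causal_lm_class = TransformersFlashCausalLM
        else:
            logger.info(
                f"Transformers' {model_type} implementation does not support ragged tensors format. Will use classic "
                "format with padding (two dimensions for batch size and sequence length). This is expected to be slow."
            )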