From 32488c1a11f7593490bcfbe000fb4e1fc6c89400 Mon Sep 17 00:00:00 2001
From: Cyril Vallez
Date: Fri, 17 Jan 2025 12:26:51 +0000
Subject: [PATCH] remove flag

---
 .../text_generation_server/models/__init__.py | 70 ++++++++-----------
 .../text_generation_server/models/globals.py  |  3 -
 2 files changed, 28 insertions(+), 45 deletions(-)

diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index c97b0006..66be0be2 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -30,7 +30,7 @@ from text_generation_server.models.bloom import BloomCausalLMBatch
 from text_generation_server.models.custom_modeling.bloom_modeling import (
     BloomForCausalLM,
 )
-from text_generation_server.models.globals import ATTENTION, USE_CUSTOM_MODELING
+from text_generation_server.models.globals import ATTENTION
 from text_generation_server.models.seq2seq_lm import Seq2SeqLM
 from text_generation_server.models.galactica import GalacticaCausalLMBatch
 from text_generation_server.models.custom_modeling.neox_modeling import (
@@ -368,7 +368,6 @@ def get_model(
     max_input_tokens: int,
 ) -> Model:
     global FLASH_ATTENTION
-    global USE_CUSTOM_MODELING
 
     config_dict, _ = PretrainedConfig.get_config_dict(
         model_id, revision=revision, trust_remote_code=trust_remote_code
@@ -376,24 +375,11 @@ def get_model(
     model_type = config_dict.get("model_type", None)
 
     transformers_causal_lm_class = CausalLM
 
-    if not USE_CUSTOM_MODELING:
-        logger.info(
-            "TGI's flash enabled models could either not be loaded or are disabled, using Transformers fallback."
-        )
-        transformers_model_class = getattr(transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type])
-
-        if transformers_model_class._supports_flex_attn:
-            logger.info(
-                f"Transformers' {model_type} implementation supports ragged tensors format (single dimension for "
-                "batch and sequence length). All TGI's batching/caching optimizations are enabled."
-            )
-            transformers_causal_lm_class = TransformersFlashCausalLM
-        else:
-            logger.info(
-                f"Transformers' {model_type} implementation does not supports ragged tensors format. Will use classic "
-                "format with padding (two dimensions for batch size and sequence length). This is expected to be slow."
-            )
+    # Fast transformers path
+    transformers_model_class = getattr(transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type])
+    if transformers_model_class._supports_flex_attn:
+        transformers_causal_lm_class = TransformersFlashCausalLM
 
     quantization_config = config_dict.get("quantization_config", None)
     if quantization_config is None:
@@ -613,7 +599,7 @@ def get_model(
         )
 
     if model_type == DEEPSEEK_V2:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             head_size = max(
                 config_dict.get("qk_nope_dim", 128)
                 + config_dict.get("qk_rope_dim", 64),
@@ -678,7 +664,7 @@ def get_model(
         or model_type == GPT2
        and model_id.startswith("bigcode/")
     ):
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashSantacoderForCausalLM,
@@ -729,7 +715,7 @@ def get_model(
                 batch_class=CausalLMBatchKeysLast,
             )
     elif model_type == GPT2:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             try:
                 return FlashCausalLM(
                     model_id=model_id,
@@ -765,7 +751,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
     elif model_type == GPTJ:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             try:
                 return FlashCausalLM(
                     model_id=model_id,
@@ -801,7 +787,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
     elif model_type == GPT_NEOX:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             from text_generation_server.models.custom_modeling.flash_neox_modeling import (
                 GPTNeoXConfig,
             )
@@ -839,7 +825,7 @@ def get_model(
            )
 
     elif model_type == PHI:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashPhiForCausalLM,
@@ -862,7 +848,7 @@ def get_model(
            )
 
     elif model_type == PHI_MOE:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashLlamaForCausalLM,
@@ -886,7 +872,7 @@ def get_model(
            )
 
     elif model_type == "phi-msft":
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             raise NotImplementedError(
                 "Legacy phi-msft is not supported with Flash Attention"
             )
@@ -908,7 +894,7 @@ def get_model(
         or model_type == PHI3
         or model_type == GRANITE
     ):
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashLlamaForCausalLM,
@@ -934,7 +920,7 @@ def get_model(
            trust_remote_code=trust_remote_code,
        )
     if model_type == GEMMA:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashGemmaForCausalLM,
@@ -960,7 +946,7 @@ def get_model(
            trust_remote_code=trust_remote_code,
        )
     elif model_type == GEMMA2:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashGemma2ForCausalLM,
@@ -987,7 +973,7 @@ def get_model(
        )
 
     if model_type == COHERE:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashCohereForCausalLM,
@@ -1012,7 +998,7 @@ def get_model(
        )
 
     if model_type == DBRX:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashDbrxForCausalLM,
@@ -1041,7 +1027,7 @@ def get_model(
 
     if model_type in ["RefinedWeb", "RefinedWebModel", FALCON]:
         if sharded:
-            if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+            if FLASH_ATTENTION:
                 if config_dict.get("alibi", False):
                     raise NotImplementedError("sharded is not supported for this model")
                 return FlashCausalLM(
@@ -1090,7 +1076,7 @@ def get_model(
        )
 
     if model_type == MISTRAL:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashMistralForCausalLM,
@@ -1115,7 +1101,7 @@ def get_model(
        )
 
     if model_type == MIXTRAL:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashMixtralForCausalLM,
@@ -1140,7 +1126,7 @@ def get_model(
        )
 
     if model_type == STARCODER2:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashStarcoder2ForCausalLM,
@@ -1167,7 +1153,7 @@ def get_model(
        )
 
     if model_type == QWEN2:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=Qwen2ForCausalLM,
@@ -1219,7 +1205,7 @@ def get_model(
            },
        )
     if model_type == IDEFICS:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return IdeficsCausalLM(
                 model_id,
                 revision,
@@ -1243,7 +1229,7 @@ def get_model(
            lora_adapter_ids=lora_adapter_ids,
        )
     if model_type == MLLAMA:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return MllamaCausalLM(
                 model_id=model_id,
                 model_class=MllamaForConditionalGeneration,
@@ -1259,7 +1245,7 @@ def get_model(
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Mllama"))
     if model_type == IDEFICS2:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return VlmCausalLM(
                 model_id=model_id,
                 model_class=Idefics2ForConditionalGeneration,
@@ -1277,7 +1263,7 @@ def get_model(
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
     if model_type == PALIGEMMA:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return VlmCausalLM(
                 model_id=model_id,
                 model_class=PaliGemmaForConditionalGeneration,
@@ -1296,7 +1282,7 @@ def get_model(
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
 
     if model_type == LLAVA_NEXT:
-        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
+        if FLASH_ATTENTION:
             return VlmCausalLM(
                 model_class=LlavaNextForConditionalGeneration,
                 model_id=model_id,
diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py
index 89f920bb..8a33fb32 100644
--- a/server/text_generation_server/models/globals.py
+++ b/server/text_generation_server/models/globals.py
@@ -68,6 +68,3 @@ def get_adapter_to_index():
     global ADAPTER_TO_INDEX
     return ADAPTER_TO_INDEX
 
-
-USE_CUSTOM_MODELING = os.getenv("USE_CUSTOM_MODELING", "true")
-USE_CUSTOM_MODELING = USE_CUSTOM_MODELING.lower() == "true" or USE_CUSTOM_MODELING == "1"