diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 2f3ccc2d..35ab8ede 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -16,6 +16,7 @@ from transformers.models.auto import modeling_auto
 from huggingface_hub import hf_hub_download, HfApi
 from typing import Optional, List, Dict
 from pathlib import Path
+import transformers
 
 from text_generation_server.utils.speculate import get_speculate, set_speculate
 from text_generation_server.models.model import Model
@@ -617,7 +618,7 @@ def get_model(
         )
 
     if model_type == DEEPSEEK_V2:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             head_size = max(
                 config_dict.get("qk_nope_dim", 128)
                 + config_dict.get("qk_rope_dim", 64),
@@ -642,7 +643,7 @@ def get_model(
                 FLASH_ATT_ERROR_MESSAGE.format("Sharded Deepseek V2")
             )
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -682,7 +683,7 @@ def get_model(
         or model_type == GPT2
         and model_id.startswith("bigcode/")
     ):
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashSantacoderForCausalLM,
@@ -701,7 +702,7 @@ def get_model(
                 FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
             )
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id=model_id,
                 revision=revision,
                 quantize=quantize,
@@ -733,7 +734,7 @@ def get_model(
             batch_class=CausalLMBatchKeysLast,
         )
     elif model_type == GPT2:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             try:
                 return FlashCausalLM(
                     model_id=model_id,
@@ -749,7 +750,7 @@ def get_model(
             except RuntimeError as e:
                 # Lots of legacy models with various weight names.
                 log_master(logger.warning, f"Couldn't load flash gpt2 variant: {e}")
-                return CausalLM.fallback(
+                return transformers_causal_lm_class.fallback(
                     model_id,
                     revision,
                     quantize=quantize,
@@ -760,7 +761,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2"))
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -769,7 +770,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
     elif model_type == GPTJ:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             try:
                 return FlashCausalLM(
                     model_id=model_id,
@@ -785,7 +786,7 @@ def get_model(
             except RuntimeError as e:
                 # Lots of legacy models with various weight names.
                 log_master(logger.warning, f"Couldn't load flash gptj variant: {e}")
-                return CausalLM.fallback(
+                return transformers_causal_lm_class.fallback(
                     model_id,
                     revision,
                     quantize=quantize,
@@ -796,7 +797,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-J"))
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -805,7 +806,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
     elif model_type == GPT_NEOX:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             from text_generation_server.models.custom_modeling.flash_neox_modeling import (
                 GPTNeoXConfig,
             )
@@ -833,7 +834,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -843,7 +844,7 @@ def get_model(
             )
 
     elif model_type == PHI:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashPhiForCausalLM,
@@ -856,7 +857,7 @@ def get_model(
                 lora_adapter_ids=lora_adapter_ids,
             )
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -866,7 +867,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
     elif model_type == PHI_MOE:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashLlamaForCausalLM,
@@ -880,7 +881,7 @@ def get_model(
                 lora_adapter_ids=lora_adapter_ids,
             )
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -890,7 +891,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
     elif model_type == "phi-msft":
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             raise NotImplementedError(
                 "Legacy phi-msft is not supported with Flash Attention"
             )
@@ -912,7 +913,7 @@ def get_model(
         or model_type == PHI3
        or model_type == GRANITE
     ):
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashLlamaForCausalLM,
@@ -929,7 +930,7 @@ def get_model(
                 FLASH_ATT_ERROR_MESSAGE.format(f"Sharded {model_type}")
             )
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -938,7 +939,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
     if model_type == GEMMA:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashGemmaForCausalLM,
@@ -955,7 +956,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -964,7 +965,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
     elif model_type == GEMMA2:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashGemma2ForCausalLM,
@@ -981,7 +982,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -991,7 +992,7 @@ def get_model(
             )
 
     if model_type == COHERE:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashCohereForCausalLM,
@@ -1006,7 +1007,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1016,7 +1017,7 @@ def get_model(
             )
 
     if model_type == DBRX:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashDbrxForCausalLM,
@@ -1034,7 +1035,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1045,7 +1046,7 @@ def get_model(
 
     if model_type in ["RefinedWeb", "RefinedWebModel", FALCON]:
         if sharded:
-            if FLASH_ATTENTION:
+            if FLASH_ATTENTION and USE_CUSTOM_MODELING:
                 if config_dict.get("alibi", False):
                     raise NotImplementedError("sharded is not supported for this model")
                 return FlashCausalLM(
@@ -1084,7 +1085,7 @@ def get_model(
                     config_class=RWConfig,
                 )
             else:
-                return CausalLM.fallback(
+                return transformers_causal_lm_class.fallback(
                     model_id,
                     revision,
                     quantize=quantize,
@@ -1094,7 +1095,7 @@ def get_model(
                 )
 
     if model_type == MISTRAL:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashMistralForCausalLM,
@@ -1109,7 +1110,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1119,7 +1120,7 @@ def get_model(
             )
 
     if model_type == MIXTRAL:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashMixtralForCausalLM,
@@ -1134,7 +1135,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1144,7 +1145,7 @@ def get_model(
             )
 
     if model_type == STARCODER2:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashStarcoder2ForCausalLM,
@@ -1161,7 +1162,7 @@ def get_model(
                 FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
             )
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1171,7 +1172,7 @@ def get_model(
             )
 
     if model_type == QWEN2:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=Qwen2ForCausalLM,
@@ -1186,7 +1187,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1223,7 +1224,7 @@ def get_model(
             },
         )
     if model_type == IDEFICS:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return IdeficsCausalLM(
                 model_id,
                 revision,
@@ -1247,7 +1248,7 @@ def get_model(
                 lora_adapter_ids=lora_adapter_ids,
             )
     if model_type == MLLAMA:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return MllamaCausalLM(
                 model_id=model_id,
                 model_class=MllamaForConditionalGeneration,
@@ -1263,7 +1264,7 @@ def get_model(
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Mllama"))
     if model_type == IDEFICS2:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
            return VlmCausalLM(
                model_id=model_id,
                model_class=Idefics2ForConditionalGeneration,
@@ -1281,7 +1282,7 @@ def get_model(
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
     if model_type == PALIGEMMA:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return VlmCausalLM(
                 model_id=model_id,
                 model_class=PaliGemmaForConditionalGeneration,
@@ -1300,7 +1301,7 @@ def get_model(
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
 
     if model_type == LLAVA_NEXT:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return VlmCausalLM(
                 model_class=LlavaNextForConditionalGeneration,
                 model_id=model_id,
@@ -1329,7 +1330,7 @@ def get_model(
     elif quantize == "exl2":
         raise NotImplementedError("exl2 quantization is not supported for AutoModel")
     if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
-        return CausalLM.fallback(
+        return transformers_causal_lm_class.fallback(
            model_id,
            revision,
            quantize=quantize,
@@ -1350,7 +1351,7 @@ def get_model(
     auto_map = config_dict.get("auto_map", None)
     if trust_remote_code and auto_map is not None:
         if "AutoModelForCausalLM" in auto_map.keys():
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
diff --git a/server/text_generation_server/models/transformers_flash_causal_lm.py b/server/text_generation_server/models/transformers_flash_causal_lm.py
index ff76b2cc..de2570b0 100644
--- a/server/text_generation_server/models/transformers_flash_causal_lm.py
+++ b/server/text_generation_server/models/transformers_flash_causal_lm.py
@@ -188,6 +188,18 @@ class TransformersFlashCausalLM(FlashCausalLM):
             device=device,
         )
 
+    @classmethod
+    def fallback(
+        cls,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
+        speculator: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
+        trust_remote_code: bool = False,
+    ):
+        return cls(model_id, revision, quantize, speculator, dtype, trust_remote_code)
+
     def warmup(self, batch: FlashCausalLMBatch):
         patch_everywhere("_flash_attention_forward", _flash_attention_forward_patched)
         super().warmup(batch)
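Note: the `__init__.py` hunks above rely on `USE_CUSTOM_MODELING` and `transformers_causal_lm_class` being defined earlier in `get_model`; that part of the file is not shown in this excerpt. The sketch below is only an illustration of how that wiring could look, assuming an environment-variable switch and a hypothetical helper name `pick_fallback_class`; neither is taken from the patch.

```python
# Hypothetical wiring for the names used by the call sites above; the real patch
# may differ. Only CausalLM, TransformersFlashCausalLM, `import transformers`,
# and modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES appear in the diff itself.
import os

import transformers
from transformers.models.auto import modeling_auto

from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.transformers_flash_causal_lm import (
    TransformersFlashCausalLM,
)

# Assumed opt-out switch: keep TGI's custom modeling code unless explicitly disabled.
USE_CUSTOM_MODELING = os.getenv("USE_CUSTOM_MODELING", "true").lower() in ("1", "true")


def pick_fallback_class(model_type: str):
    """Choose the class whose .fallback(...) the call sites above dispatch to."""
    if USE_CUSTOM_MODELING:
        # Custom modeling enabled: the plain transformers-based CausalLM stays the fallback.
        return CausalLM
    # Custom modeling disabled: prefer transformers' own flash-attention code path
    # whenever the architecture is covered by AutoModelForCausalLM.
    architecture = modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.get(model_type)
    if architecture is not None and hasattr(transformers, architecture):
        return TransformersFlashCausalLM
    return CausalLM
```

Both candidates expose the same entry point, which is why the call sites only swap the class name: `CausalLM.fallback(...)` already existed, and the second hunk adds a matching `fallback(model_id, revision, quantize, speculator, dtype, trust_remote_code)` classmethod to `TransformersFlashCausalLM`.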