commit da222900a1 (parent ade0f44aca)
Author: Cyril Vallez
Date: 2024-12-10 16:57:07 +01:00
2 changed files with 57 additions and 44 deletions

File 1 of 2

@ -16,6 +16,7 @@ from transformers.models.auto import modeling_auto
from huggingface_hub import hf_hub_download, HfApi
from typing import Optional, List, Dict
from pathlib import Path
+import transformers
from text_generation_server.utils.speculate import get_speculate, set_speculate
from text_generation_server.models.model import Model
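Note: the updated conditions and fallbacks below rely on two names that are defined outside the hunks shown here: USE_CUSTOM_MODELING and transformers_causal_lm_class. A minimal sketch of how they might be resolved, assuming an environment switch and a check for a transformers-native implementation of the architecture (the selection logic, the environment variable name, and the module path of TransformersFlashCausalLM are assumptions; only the class names appear in this commit):

import os
import transformers

from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.transformers_flash_causal_lm import (
    TransformersFlashCausalLM,  # class added in the second file of this commit
)

# Assumed opt-out switch; the real flag may be derived differently.
USE_CUSTOM_MODELING = os.getenv("USE_CUSTOM_MODELING", "true").lower() in ("1", "true")


def pick_fallback_class(architecture: str):
    # If transformers exposes this architecture (e.g. "GPT2LMHeadModel"), the
    # generic transformers-backed flash model can serve as the fallback;
    # otherwise keep the legacy CausalLM path.
    if getattr(transformers, architecture, None) is not None:
        return TransformersFlashCausalLM
    return CausalLM


# e.g. transformers_causal_lm_class = pick_fallback_class(config_dict["architectures"][0])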
@ -617,7 +618,7 @@ def get_model(
)
if model_type == DEEPSEEK_V2:
-if FLASH_ATTENTION:
+if FLASH_ATTENTION and USE_CUSTOM_MODELING:
head_size = max(
config_dict.get("qk_nope_dim", 128)
+ config_dict.get("qk_rope_dim", 64),
@ -642,7 +643,7 @@ def get_model(
FLASH_ATT_ERROR_MESSAGE.format("Sharded Deepseek V2")
)
else:
-return CausalLM.fallback(
+return transformers_causal_lm_class.fallback(
model_id,
revision,
quantize=quantize,
@ -682,7 +683,7 @@ def get_model(
or model_type == GPT2
and model_id.startswith("bigcode/")
):
-if FLASH_ATTENTION:
+if FLASH_ATTENTION and USE_CUSTOM_MODELING:
return FlashCausalLM(
model_id=model_id,
model_class=FlashSantacoderForCausalLM,
@ -701,7 +702,7 @@ def get_model(
FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
)
else:
-return CausalLM.fallback(
+return transformers_causal_lm_class.fallback(
model_id=model_id,
revision=revision,
quantize=quantize,
@ -733,7 +734,7 @@ def get_model(
batch_class=CausalLMBatchKeysLast,
)
elif model_type == GPT2:
-if FLASH_ATTENTION:
+if FLASH_ATTENTION and USE_CUSTOM_MODELING:
try:
return FlashCausalLM(
model_id=model_id,
@ -749,7 +750,7 @@ def get_model(
except RuntimeError as e:
# Lots of legacy models with various weight names.
log_master(logger.warning, f"Couldn't load flash gpt2 variant: {e}")
-return CausalLM.fallback(
+return transformers_causal_lm_class.fallback(
model_id,
revision,
quantize=quantize,
@ -760,7 +761,7 @@ def get_model(
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2"))
else:
-return CausalLM.fallback(
+return transformers_causal_lm_class.fallback(
model_id,
revision,
quantize=quantize,
@ -769,7 +770,7 @@ def get_model(
trust_remote_code=trust_remote_code,
)
elif model_type == GPTJ:
-if FLASH_ATTENTION:
+if FLASH_ATTENTION and USE_CUSTOM_MODELING:
try:
return FlashCausalLM(
model_id=model_id,
@ -785,7 +786,7 @@ def get_model(
except RuntimeError as e:
# Lots of legacy models with various weight names.
log_master(logger.warning, f"Couldn't load flash gptj variant: {e}")
-return CausalLM.fallback(
+return transformers_causal_lm_class.fallback(
model_id,
revision,
quantize=quantize,
@ -796,7 +797,7 @@ def get_model(
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-J"))
else:
-return CausalLM.fallback(
+return transformers_causal_lm_class.fallback(
model_id,
revision,
quantize=quantize,
@ -805,7 +806,7 @@ def get_model(
trust_remote_code=trust_remote_code,
)
elif model_type == GPT_NEOX:
-if FLASH_ATTENTION:
+if FLASH_ATTENTION and USE_CUSTOM_MODELING:
from text_generation_server.models.custom_modeling.flash_neox_modeling import (
GPTNeoXConfig,
)
@ -833,7 +834,7 @@ def get_model(
trust_remote_code=trust_remote_code,
)
else:
-return CausalLM.fallback(
+return transformers_causal_lm_class.fallback(
model_id,
revision,
quantize=quantize,
@ -843,7 +844,7 @@ def get_model(
)
elif model_type == PHI:
-if FLASH_ATTENTION:
+if FLASH_ATTENTION and USE_CUSTOM_MODELING:
return FlashCausalLM(
model_id=model_id,
model_class=FlashPhiForCausalLM,
@ -856,7 +857,7 @@ def get_model(
lora_adapter_ids=lora_adapter_ids,
)
else:
-return CausalLM.fallback(
+return transformers_causal_lm_class.fallback(
model_id,
revision,
quantize=quantize,
@ -866,7 +867,7 @@ def get_model(
)
elif model_type == PHI_MOE:
-if FLASH_ATTENTION:
+if FLASH_ATTENTION and USE_CUSTOM_MODELING:
return FlashCausalLM(
model_id=model_id,
model_class=FlashLlamaForCausalLM,
@ -880,7 +881,7 @@ def get_model(
lora_adapter_ids=lora_adapter_ids,
)
else:
-return CausalLM.fallback(
+return transformers_causal_lm_class.fallback(
model_id,
revision,
quantize=quantize,
@ -890,7 +891,7 @@ def get_model(
)
elif model_type == "phi-msft":
-if FLASH_ATTENTION:
+if FLASH_ATTENTION and USE_CUSTOM_MODELING:
raise NotImplementedError(
"Legacy phi-msft is not supported with Flash Attention"
)
@ -912,7 +913,7 @@ def get_model(
or model_type == PHI3
or model_type == GRANITE
):
-if FLASH_ATTENTION:
+if FLASH_ATTENTION and USE_CUSTOM_MODELING:
return FlashCausalLM(
model_id=model_id,
model_class=FlashLlamaForCausalLM,
@ -929,7 +930,7 @@ def get_model(
FLASH_ATT_ERROR_MESSAGE.format(f"Sharded {model_type}")
)
else:
-return CausalLM.fallback(
+return transformers_causal_lm_class.fallback(
model_id,
revision,
quantize=quantize,
@ -938,7 +939,7 @@ def get_model(
trust_remote_code=trust_remote_code,
)
if model_type == GEMMA:
-if FLASH_ATTENTION:
+if FLASH_ATTENTION and USE_CUSTOM_MODELING:
return FlashCausalLM(
model_id=model_id,
model_class=FlashGemmaForCausalLM,
@ -955,7 +956,7 @@ def get_model(
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
else:
-return CausalLM.fallback(
+return transformers_causal_lm_class.fallback(
model_id,
revision,
quantize=quantize,
@ -964,7 +965,7 @@ def get_model(
trust_remote_code=trust_remote_code,
)
elif model_type == GEMMA2:
-if FLASH_ATTENTION:
+if FLASH_ATTENTION and USE_CUSTOM_MODELING:
return FlashCausalLM(
model_id=model_id,
model_class=FlashGemma2ForCausalLM,
@ -981,7 +982,7 @@ def get_model(
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
else:
-return CausalLM.fallback(
+return transformers_causal_lm_class.fallback(
model_id,
revision,
quantize=quantize,
@ -991,7 +992,7 @@ def get_model(
)
if model_type == COHERE:
-if FLASH_ATTENTION:
+if FLASH_ATTENTION and USE_CUSTOM_MODELING:
return FlashCausalLM(
model_id=model_id,
model_class=FlashCohereForCausalLM,
@ -1006,7 +1007,7 @@ def get_model(
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
else:
-return CausalLM.fallback(
+return transformers_causal_lm_class.fallback(
model_id,
revision,
quantize=quantize,
@ -1016,7 +1017,7 @@ def get_model(
)
if model_type == DBRX:
-if FLASH_ATTENTION:
+if FLASH_ATTENTION and USE_CUSTOM_MODELING:
return FlashCausalLM(
model_id=model_id,
model_class=FlashDbrxForCausalLM,
@ -1034,7 +1035,7 @@ def get_model(
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
else:
-return CausalLM.fallback(
+return transformers_causal_lm_class.fallback(
model_id,
revision,
quantize=quantize,
@ -1045,7 +1046,7 @@ def get_model(
if model_type in ["RefinedWeb", "RefinedWebModel", FALCON]:
if sharded:
-if FLASH_ATTENTION:
+if FLASH_ATTENTION and USE_CUSTOM_MODELING:
if config_dict.get("alibi", False):
raise NotImplementedError("sharded is not supported for this model")
return FlashCausalLM(
@ -1084,7 +1085,7 @@ def get_model(
config_class=RWConfig,
)
else:
-return CausalLM.fallback(
+return transformers_causal_lm_class.fallback(
model_id,
revision,
quantize=quantize,
@ -1094,7 +1095,7 @@ def get_model(
)
if model_type == MISTRAL:
-if FLASH_ATTENTION:
+if FLASH_ATTENTION and USE_CUSTOM_MODELING:
return FlashCausalLM(
model_id=model_id,
model_class=FlashMistralForCausalLM,
@ -1109,7 +1110,7 @@ def get_model(
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
else:
-return CausalLM.fallback(
+return transformers_causal_lm_class.fallback(
model_id,
revision,
quantize=quantize,
@ -1119,7 +1120,7 @@ def get_model(
)
if model_type == MIXTRAL:
-if FLASH_ATTENTION:
+if FLASH_ATTENTION and USE_CUSTOM_MODELING:
return FlashCausalLM(
model_id=model_id,
model_class=FlashMixtralForCausalLM,
@ -1134,7 +1135,7 @@ def get_model(
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
else:
-return CausalLM.fallback(
+return transformers_causal_lm_class.fallback(
model_id,
revision,
quantize=quantize,
@ -1144,7 +1145,7 @@ def get_model(
)
if model_type == STARCODER2:
-if FLASH_ATTENTION:
+if FLASH_ATTENTION and USE_CUSTOM_MODELING:
return FlashCausalLM(
model_id=model_id,
model_class=FlashStarcoder2ForCausalLM,
@ -1161,7 +1162,7 @@ def get_model(
FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
)
else:
-return CausalLM.fallback(
+return transformers_causal_lm_class.fallback(
model_id,
revision,
quantize=quantize,
@ -1171,7 +1172,7 @@ def get_model(
)
if model_type == QWEN2:
-if FLASH_ATTENTION:
+if FLASH_ATTENTION and USE_CUSTOM_MODELING:
return FlashCausalLM(
model_id=model_id,
model_class=Qwen2ForCausalLM,
@ -1186,7 +1187,7 @@ def get_model(
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
else:
-return CausalLM.fallback(
+return transformers_causal_lm_class.fallback(
model_id,
revision,
quantize=quantize,
@ -1223,7 +1224,7 @@ def get_model(
},
)
if model_type == IDEFICS:
-if FLASH_ATTENTION:
+if FLASH_ATTENTION and USE_CUSTOM_MODELING:
return IdeficsCausalLM(
model_id,
revision,
@ -1247,7 +1248,7 @@ def get_model(
lora_adapter_ids=lora_adapter_ids,
)
if model_type == MLLAMA:
-if FLASH_ATTENTION:
+if FLASH_ATTENTION and USE_CUSTOM_MODELING:
return MllamaCausalLM(
model_id=model_id,
model_class=MllamaForConditionalGeneration,
@ -1263,7 +1264,7 @@ def get_model(
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Mllama"))
if model_type == IDEFICS2:
-if FLASH_ATTENTION:
+if FLASH_ATTENTION and USE_CUSTOM_MODELING:
return VlmCausalLM(
model_id=model_id,
model_class=Idefics2ForConditionalGeneration,
@ -1281,7 +1282,7 @@ def get_model(
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
if model_type == PALIGEMMA:
-if FLASH_ATTENTION:
+if FLASH_ATTENTION and USE_CUSTOM_MODELING:
return VlmCausalLM(
model_id=model_id,
model_class=PaliGemmaForConditionalGeneration,
@ -1300,7 +1301,7 @@ def get_model(
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
if model_type == LLAVA_NEXT:
-if FLASH_ATTENTION:
+if FLASH_ATTENTION and USE_CUSTOM_MODELING:
return VlmCausalLM(
model_class=LlavaNextForConditionalGeneration,
model_id=model_id,
@ -1329,7 +1330,7 @@ def get_model(
elif quantize == "exl2":
raise NotImplementedError("exl2 quantization is not supported for AutoModel")
if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
-return CausalLM.fallback(
+return transformers_causal_lm_class.fallback(
model_id,
revision,
quantize=quantize,
@ -1350,7 +1351,7 @@ def get_model(
auto_map = config_dict.get("auto_map", None)
if trust_remote_code and auto_map is not None:
if "AutoModelForCausalLM" in auto_map.keys():
-return CausalLM.fallback(
+return transformers_causal_lm_class.fallback(
model_id,
revision,
quantize=quantize,

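Net effect of the changes in this file: every non-flash branch of get_model() now defers to whichever class transformers_causal_lm_class names instead of hard-coding CausalLM. The dispatch code only relies on a fallback classmethod with the signature visible in the next file; written out as a typing.Protocol purely for illustration (the commit itself defines no such protocol):

from typing import Optional, Protocol

import torch


class SupportsFallback(Protocol):
    # Both CausalLM.fallback and the new TransformersFlashCausalLM.fallback
    # (added in the next file) match this interface.
    @classmethod
    def fallback(
        cls,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ): ...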
File 2 of 2

@ -188,6 +188,18 @@ class TransformersFlashCausalLM(FlashCausalLM):
device=device,
)
+@classmethod
+def fallback(
+cls,
+model_id: str,
+revision: Optional[str] = None,
+quantize: Optional[str] = None,
+speculator: Optional[str] = None,
+dtype: Optional[torch.dtype] = None,
+trust_remote_code: bool = False,
+):
+return cls(model_id, revision, quantize, speculator, dtype, trust_remote_code)
def warmup(self, batch: FlashCausalLMBatch):
patch_everywhere("_flash_attention_forward", _flash_attention_forward_patched)
super().warmup(batch)
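The overridden warmup() swaps transformers' _flash_attention_forward for the server's patched version before delegating to FlashCausalLM.warmup, so attention calls made inside stock transformers modeling code are routed through the patched kernel. A generic sketch of what a patch_everywhere-style helper usually does, assuming it follows the common sys.modules pattern (the helper actually imported above may differ in signature and location):

import sys
from typing import Optional


def patch_everywhere(attribute_name: str, patch, module_name_prefix: Optional[str] = None):
    # Replace `attribute_name` in every already-imported module that defines it,
    # so call sites that did `from some_module import _flash_attention_forward`
    # also pick up the patched function.
    for name, module in list(sys.modules.items()):
        if module is None:
            continue
        if module_name_prefix is not None and not name.startswith(module_name_prefix):
            continue
        if hasattr(module, attribute_name):
            setattr(module, attribute_name, patch)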