Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 20:34:54 +00:00)

Commit da222900a1 ("inits")
Parent: ade0f44aca
@@ -16,6 +16,7 @@ from transformers.models.auto import modeling_auto
 from huggingface_hub import hf_hub_download, HfApi
 from typing import Optional, List, Dict
 from pathlib import Path
+import transformers

 from text_generation_server.utils.speculate import get_speculate, set_speculate
 from text_generation_server.models.model import Model
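The remaining hunks in get_model() make two coordinated changes: every custom flash-attention branch is additionally gated on a USE_CUSTOM_MODELING flag, and every non-flash fallback is routed through a transformers_causal_lm_class variable instead of a hard-coded CausalLM. Neither name is defined in the hunks shown here, so the snippet below is only a minimal sketch of how such a toggle could be wired, assuming an environment variable; the commit's actual mechanism may differ.

import os

# Assumed wiring only: default to TGI's custom kernels, and allow opting out
# via the environment so get_model() can fall through to transformers-based
# modeling instead.
USE_CUSTOM_MODELING = os.getenv("USE_CUSTOM_MODELING", "true").lower() in (
    "1",
    "true",
    "yes",
)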
@@ -617,7 +618,7 @@ def get_model(
         )

     if model_type == DEEPSEEK_V2:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             head_size = max(
                 config_dict.get("qk_nope_dim", 128)
                 + config_dict.get("qk_rope_dim", 64),
@@ -642,7 +643,7 @@ def get_model(
                 FLASH_ATT_ERROR_MESSAGE.format("Sharded Deepseek V2")
             )
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -682,7 +683,7 @@ def get_model(
         or model_type == GPT2
         and model_id.startswith("bigcode/")
     ):
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashSantacoderForCausalLM,
@@ -701,7 +702,7 @@ def get_model(
                 FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
             )
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id=model_id,
                 revision=revision,
                 quantize=quantize,
@@ -733,7 +734,7 @@ def get_model(
                 batch_class=CausalLMBatchKeysLast,
             )
     elif model_type == GPT2:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             try:
                 return FlashCausalLM(
                     model_id=model_id,
@@ -749,7 +750,7 @@ def get_model(
             except RuntimeError as e:
                 # Lots of legacy models with various weight names.
                 log_master(logger.warning, f"Couldn't load flash gpt2 variant: {e}")
-                return CausalLM.fallback(
+                return transformers_causal_lm_class.fallback(
                     model_id,
                     revision,
                     quantize=quantize,
@@ -760,7 +761,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2"))
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -769,7 +770,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
     elif model_type == GPTJ:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             try:
                 return FlashCausalLM(
                     model_id=model_id,
@@ -785,7 +786,7 @@ def get_model(
             except RuntimeError as e:
                 # Lots of legacy models with various weight names.
                 log_master(logger.warning, f"Couldn't load flash gptj variant: {e}")
-                return CausalLM.fallback(
+                return transformers_causal_lm_class.fallback(
                     model_id,
                     revision,
                     quantize=quantize,
@@ -796,7 +797,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-J"))
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -805,7 +806,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
     elif model_type == GPT_NEOX:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             from text_generation_server.models.custom_modeling.flash_neox_modeling import (
                 GPTNeoXConfig,
             )
@@ -833,7 +834,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -843,7 +844,7 @@ def get_model(
             )

     elif model_type == PHI:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashPhiForCausalLM,
@@ -856,7 +857,7 @@ def get_model(
                 lora_adapter_ids=lora_adapter_ids,
             )
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -866,7 +867,7 @@ def get_model(
             )

     elif model_type == PHI_MOE:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashLlamaForCausalLM,
@@ -880,7 +881,7 @@ def get_model(
                 lora_adapter_ids=lora_adapter_ids,
             )
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -890,7 +891,7 @@ def get_model(
             )

     elif model_type == "phi-msft":
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             raise NotImplementedError(
                 "Legacy phi-msft is not supported with Flash Attention"
             )
@@ -912,7 +913,7 @@ def get_model(
         or model_type == PHI3
         or model_type == GRANITE
     ):
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashLlamaForCausalLM,
@@ -929,7 +930,7 @@ def get_model(
                 FLASH_ATT_ERROR_MESSAGE.format(f"Sharded {model_type}")
             )
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -938,7 +939,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
     if model_type == GEMMA:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashGemmaForCausalLM,
@@ -955,7 +956,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -964,7 +965,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
     elif model_type == GEMMA2:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashGemma2ForCausalLM,
@@ -981,7 +982,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -991,7 +992,7 @@ def get_model(
             )

     if model_type == COHERE:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashCohereForCausalLM,
@@ -1006,7 +1007,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1016,7 +1017,7 @@ def get_model(
             )

     if model_type == DBRX:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashDbrxForCausalLM,
@@ -1034,7 +1035,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1045,7 +1046,7 @@ def get_model(

     if model_type in ["RefinedWeb", "RefinedWebModel", FALCON]:
         if sharded:
-            if FLASH_ATTENTION:
+            if FLASH_ATTENTION and USE_CUSTOM_MODELING:
                 if config_dict.get("alibi", False):
                     raise NotImplementedError("sharded is not supported for this model")
                 return FlashCausalLM(
@@ -1084,7 +1085,7 @@ def get_model(
                     config_class=RWConfig,
                 )
             else:
-                return CausalLM.fallback(
+                return transformers_causal_lm_class.fallback(
                     model_id,
                     revision,
                     quantize=quantize,
@@ -1094,7 +1095,7 @@ def get_model(
                 )

     if model_type == MISTRAL:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashMistralForCausalLM,
@@ -1109,7 +1110,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1119,7 +1120,7 @@ def get_model(
             )

     if model_type == MIXTRAL:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashMixtralForCausalLM,
@@ -1134,7 +1135,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1144,7 +1145,7 @@ def get_model(
             )

     if model_type == STARCODER2:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashStarcoder2ForCausalLM,
@@ -1161,7 +1162,7 @@ def get_model(
                 FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
             )
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1171,7 +1172,7 @@ def get_model(
             )

     if model_type == QWEN2:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=Qwen2ForCausalLM,
@@ -1186,7 +1187,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1223,7 +1224,7 @@ def get_model(
             },
         )
     if model_type == IDEFICS:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return IdeficsCausalLM(
                 model_id,
                 revision,
@@ -1247,7 +1248,7 @@ def get_model(
                 lora_adapter_ids=lora_adapter_ids,
             )
     if model_type == MLLAMA:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return MllamaCausalLM(
                 model_id=model_id,
                 model_class=MllamaForConditionalGeneration,
@@ -1263,7 +1264,7 @@ def get_model(
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Mllama"))
     if model_type == IDEFICS2:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return VlmCausalLM(
                 model_id=model_id,
                 model_class=Idefics2ForConditionalGeneration,
@@ -1281,7 +1282,7 @@ def get_model(
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
     if model_type == PALIGEMMA:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return VlmCausalLM(
                 model_id=model_id,
                 model_class=PaliGemmaForConditionalGeneration,
@@ -1300,7 +1301,7 @@ def get_model(
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))

     if model_type == LLAVA_NEXT:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return VlmCausalLM(
                 model_class=LlavaNextForConditionalGeneration,
                 model_id=model_id,
@@ -1329,7 +1330,7 @@ def get_model(
     elif quantize == "exl2":
         raise NotImplementedError("exl2 quantization is not supported for AutoModel")
     if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
-        return CausalLM.fallback(
+        return transformers_causal_lm_class.fallback(
             model_id,
             revision,
             quantize=quantize,
@@ -1350,7 +1351,7 @@ def get_model(
     auto_map = config_dict.get("auto_map", None)
     if trust_remote_code and auto_map is not None:
         if "AutoModelForCausalLM" in auto_map.keys():
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
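Each else branch above now calls transformers_causal_lm_class.fallback(...) with the same arguments it previously passed to CausalLM.fallback(...), so whatever class that variable names must accept the same fallback signature; the second hunk in this commit adds exactly such a classmethod to TransformersFlashCausalLM. How the variable is chosen is not visible in this diff. A minimal, purely hypothetical sketch of one possible selection, assuming CausalLM and FLASH_ATTENTION are already available at module level in get_model()'s file and TransformersFlashCausalLM is importable there:

# Hypothetical selection; the commit's actual criteria are not shown here.
transformers_causal_lm_class = CausalLM
if FLASH_ATTENTION and not USE_CUSTOM_MODELING:
    # Prefer the transformers-backed flash implementation when custom
    # modeling is disabled but flash kernels are still available.
    transformers_causal_lm_class = TransformersFlashCausalLM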
@@ -188,6 +188,18 @@ class TransformersFlashCausalLM(FlashCausalLM):
             device=device,
         )

+    @classmethod
+    def fallback(
+        cls,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
+        speculator: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
+        trust_remote_code: bool = False,
+    ):
+        return cls(model_id, revision, quantize, speculator, dtype, trust_remote_code)
+
     def warmup(self, batch: FlashCausalLMBatch):
         patch_everywhere("_flash_attention_forward", _flash_attention_forward_patched)
         super().warmup(batch)
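The added fallback classmethod simply forwards its arguments to the constructor, so TransformersFlashCausalLM can stand in wherever get_model() previously called CausalLM.fallback. An illustrative call, mirroring the fallback sites above (the model id is an arbitrary example value):

model = TransformersFlashCausalLM.fallback(
    "gpt2",  # model_id, example value only
    revision=None,
    quantize=None,
    speculator=None,
    dtype=None,
    trust_remote_code=False,
)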