Cyril Vallez 2024-12-10 16:57:07 +01:00
parent ade0f44aca
commit da222900a1
2 changed files with 57 additions and 44 deletions


@@ -16,6 +16,7 @@ from transformers.models.auto import modeling_auto
 from huggingface_hub import hf_hub_download, HfApi
 from typing import Optional, List, Dict
 from pathlib import Path
+import transformers

 from text_generation_server.utils.speculate import get_speculate, set_speculate
 from text_generation_server.models.model import Model
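
Note: neither USE_CUSTOM_MODELING nor transformers_causal_lm_class is defined in the hunks below, so their exact wiring is not visible in this diff. A minimal sketch of one plausible setup, assuming an environment flag and the TransformersFlashCausalLM class touched in the second file of this commit:

# Sketch only: the flag source, its default, and the import path are assumptions, not part of this commit.
import os

from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.transformers_flash_causal_lm import (
    TransformersFlashCausalLM,
)

USE_CUSTOM_MODELING = os.getenv("USE_CUSTOM_MODELING", "true").lower() in ("1", "true")

# With custom modeling enabled, the existing CausalLM fallback is kept; otherwise every
# fallback call in get_model routes to the transformers-backed flash implementation.
transformers_causal_lm_class = CausalLM if USE_CUSTOM_MODELING else TransformersFlashCausalLM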
@@ -617,7 +618,7 @@ def get_model(
         )
     if model_type == DEEPSEEK_V2:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             head_size = max(
                 config_dict.get("qk_nope_dim", 128)
                 + config_dict.get("qk_rope_dim", 64),
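
With the defaults shown in this hunk, the first candidate passed to max() works out to 128 + 64 = 192; a config that provides either key overrides the corresponding default. A throwaway check of that arithmetic, independent of the server code:

# Worked example of the defaults above; a real DeepSeek-V2 config would normally set these keys itself.
config_dict = {}  # hypothetical config with neither key present
assert config_dict.get("qk_nope_dim", 128) + config_dict.get("qk_rope_dim", 64) == 192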
@@ -642,7 +643,7 @@ def get_model(
                 FLASH_ATT_ERROR_MESSAGE.format("Sharded Deepseek V2")
             )
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -682,7 +683,7 @@ def get_model(
         or model_type == GPT2
         and model_id.startswith("bigcode/")
     ):
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashSantacoderForCausalLM,
@@ -701,7 +702,7 @@ def get_model(
                 FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
             )
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id=model_id,
                 revision=revision,
                 quantize=quantize,
@@ -733,7 +734,7 @@ def get_model(
                 batch_class=CausalLMBatchKeysLast,
             )
     elif model_type == GPT2:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             try:
                 return FlashCausalLM(
                     model_id=model_id,
@@ -749,7 +750,7 @@ def get_model(
             except RuntimeError as e:
                 # Lots of legacy models with various weight names.
                 log_master(logger.warning, f"Couldn't load flash gpt2 variant: {e}")
-                return CausalLM.fallback(
+                return transformers_causal_lm_class.fallback(
                     model_id,
                     revision,
                     quantize=quantize,
@@ -760,7 +761,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2"))
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -769,7 +770,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
     elif model_type == GPTJ:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             try:
                 return FlashCausalLM(
                     model_id=model_id,
@@ -785,7 +786,7 @@ def get_model(
             except RuntimeError as e:
                 # Lots of legacy models with various weight names.
                 log_master(logger.warning, f"Couldn't load flash gptj variant: {e}")
-                return CausalLM.fallback(
+                return transformers_causal_lm_class.fallback(
                     model_id,
                     revision,
                     quantize=quantize,
@@ -796,7 +797,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-J"))
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -805,7 +806,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
     elif model_type == GPT_NEOX:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             from text_generation_server.models.custom_modeling.flash_neox_modeling import (
                 GPTNeoXConfig,
             )
@@ -833,7 +834,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -843,7 +844,7 @@ def get_model(
             )
     elif model_type == PHI:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashPhiForCausalLM,
@@ -856,7 +857,7 @@ def get_model(
                 lora_adapter_ids=lora_adapter_ids,
             )
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -866,7 +867,7 @@ def get_model(
             )
     elif model_type == PHI_MOE:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashLlamaForCausalLM,
@@ -880,7 +881,7 @@ def get_model(
                 lora_adapter_ids=lora_adapter_ids,
             )
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -890,7 +891,7 @@ def get_model(
             )
     elif model_type == "phi-msft":
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             raise NotImplementedError(
                 "Legacy phi-msft is not supported with Flash Attention"
             )
@@ -912,7 +913,7 @@ def get_model(
         or model_type == PHI3
         or model_type == GRANITE
     ):
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashLlamaForCausalLM,
@@ -929,7 +930,7 @@ def get_model(
                 FLASH_ATT_ERROR_MESSAGE.format(f"Sharded {model_type}")
             )
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -938,7 +939,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
     if model_type == GEMMA:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashGemmaForCausalLM,
@@ -955,7 +956,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -964,7 +965,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
     elif model_type == GEMMA2:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashGemma2ForCausalLM,
@@ -981,7 +982,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -991,7 +992,7 @@ def get_model(
             )
     if model_type == COHERE:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashCohereForCausalLM,
@@ -1006,7 +1007,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1016,7 +1017,7 @@ def get_model(
             )
     if model_type == DBRX:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashDbrxForCausalLM,
@@ -1034,7 +1035,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1045,7 +1046,7 @@ def get_model(
     if model_type in ["RefinedWeb", "RefinedWebModel", FALCON]:
         if sharded:
-            if FLASH_ATTENTION:
+            if FLASH_ATTENTION and USE_CUSTOM_MODELING:
                 if config_dict.get("alibi", False):
                     raise NotImplementedError("sharded is not supported for this model")
                 return FlashCausalLM(
@@ -1084,7 +1085,7 @@ def get_model(
                     config_class=RWConfig,
                 )
             else:
-                return CausalLM.fallback(
+                return transformers_causal_lm_class.fallback(
                     model_id,
                     revision,
                     quantize=quantize,
@@ -1094,7 +1095,7 @@ def get_model(
                 )
     if model_type == MISTRAL:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashMistralForCausalLM,
@@ -1109,7 +1110,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1119,7 +1120,7 @@ def get_model(
             )
     if model_type == MIXTRAL:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashMixtralForCausalLM,
@@ -1134,7 +1135,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1144,7 +1145,7 @@ def get_model(
             )
     if model_type == STARCODER2:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=FlashStarcoder2ForCausalLM,
@@ -1161,7 +1162,7 @@ def get_model(
                 FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
             )
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1171,7 +1172,7 @@ def get_model(
             )
     if model_type == QWEN2:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=Qwen2ForCausalLM,
@@ -1186,7 +1187,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
         else:
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1223,7 +1224,7 @@ def get_model(
             },
         )
     if model_type == IDEFICS:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return IdeficsCausalLM(
                 model_id,
                 revision,
@@ -1247,7 +1248,7 @@ def get_model(
             lora_adapter_ids=lora_adapter_ids,
         )
     if model_type == MLLAMA:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return MllamaCausalLM(
                 model_id=model_id,
                 model_class=MllamaForConditionalGeneration,
@@ -1263,7 +1264,7 @@ def get_model(
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Mllama"))
     if model_type == IDEFICS2:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return VlmCausalLM(
                 model_id=model_id,
                 model_class=Idefics2ForConditionalGeneration,
@@ -1281,7 +1282,7 @@ def get_model(
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
     if model_type == PALIGEMMA:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return VlmCausalLM(
                 model_id=model_id,
                 model_class=PaliGemmaForConditionalGeneration,
@@ -1300,7 +1301,7 @@ def get_model(
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
     if model_type == LLAVA_NEXT:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return VlmCausalLM(
                 model_class=LlavaNextForConditionalGeneration,
                 model_id=model_id,
@@ -1329,7 +1330,7 @@ def get_model(
     elif quantize == "exl2":
         raise NotImplementedError("exl2 quantization is not supported for AutoModel")
     if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
-        return CausalLM.fallback(
+        return transformers_causal_lm_class.fallback(
             model_id,
             revision,
             quantize=quantize,
@@ -1350,7 +1351,7 @@ def get_model(
     auto_map = config_dict.get("auto_map", None)
     if trust_remote_code and auto_map is not None:
         if "AutoModelForCausalLM" in auto_map.keys():
-            return CausalLM.fallback(
+            return transformers_causal_lm_class.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
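
For context on the last hunk above: auto_map comes from the model's config.json and maps Auto classes to trust_remote_code entry points. A hedged illustration of the shape being checked (the class paths are placeholders, not values from this commit):

# Illustrative only: a config_dict whose auto_map would satisfy the check in the hunk above.
config_dict = {
    "auto_map": {
        "AutoConfig": "configuration_custom.CustomConfig",            # placeholder
        "AutoModelForCausalLM": "modeling_custom.CustomForCausalLM",  # placeholder
    }
}
auto_map = config_dict.get("auto_map", None)
assert auto_map is not None and "AutoModelForCausalLM" in auto_map.keys()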


@@ -188,6 +188,18 @@ class TransformersFlashCausalLM(FlashCausalLM):
             device=device,
         )

+    @classmethod
+    def fallback(
+        cls,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
+        speculator: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
+        trust_remote_code: bool = False,
+    ):
+        return cls(model_id, revision, quantize, speculator, dtype, trust_remote_code)
+
     def warmup(self, batch: FlashCausalLMBatch):
         patch_everywhere("_flash_attention_forward", _flash_attention_forward_patched)
         super().warmup(batch)
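
The new fallback classmethod mirrors the keyword interface of CausalLM.fallback, so get_model can call either class the same way. A hedged usage sketch; the model id, dtype, and import path are placeholders rather than values from this commit:

import torch

# Assumed import path for the class modified above.
from text_generation_server.models.transformers_flash_causal_lm import (
    TransformersFlashCausalLM,
)

model = TransformersFlashCausalLM.fallback(
    "org/some-causal-lm",  # placeholder model_id
    revision=None,
    quantize=None,
    speculator=None,
    dtype=torch.float16,
    trust_remote_code=False,
)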