tp monkey patch

Cyril Vallez 2025-01-24 15:03:14 +01:00
parent 6cb41a80a1
commit de83178bc3
2 changed files with 71 additions and 3 deletions


@@ -848,7 +848,7 @@ def get_model(
        lora_adapter_ids=lora_adapter_ids,
    )
else:
-    return TransformersFlashCausalLM.fallback(
+    return CausalLM.fallback(
        model_id,
        revision,
        quantize=quantize,
@@ -898,7 +898,7 @@ def get_model(
        trust_remote_code=trust_remote_code,
    )
-elif model_type == LLAMA or model_type == PHI3 or model_type == GRANITE:
+elif model_type == LLAMA or model_type == GRANITE:
    if FLASH_ATTENTION:
        return FlashCausalLM(
            model_id=model_id,
@@ -934,7 +934,7 @@ def get_model(
        trust_remote_code=trust_remote_code,
    )
-elif model_type == BAICHUAN:
+elif model_type == BAICHUAN or model_type == PHI3:
    if FLASH_ATTENTION:
        return FlashCausalLM(
            model_id=model_id,
@@ -1433,6 +1433,7 @@ def get_model(
    FLASH_TRANSFORMERS_BACKEND
    and transformers_model_class is not None
    and transformers_model_class._supports_flex_attn
+   and hasattr(transformers_model_class.config_class, "base_model_tp_plan")
):
    return TransformersFlashCausalLM.fallback(
        model_id,

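The hunk above tightens the routing condition for the Transformers flash backend: on top of flex-attention support, the auto-mapped model class's config must now expose a base_model_tp_plan, since that plan is what the backend uses to shard weights under tensor parallelism. Below is a minimal sketch of the same check, assuming transformers is installed; can_use_transformers_flash_backend is a hypothetical helper, not a function from this repository.

# Illustrative sketch only; not text-generation-inference code.
from transformers import LlamaForCausalLM


def can_use_transformers_flash_backend(model_class, flash_transformers_backend: bool) -> bool:
    return (
        flash_transformers_backend
        and model_class is not None
        # The model class must implement flex attention...
        and getattr(model_class, "_supports_flex_attn", False)
        # ...and (the new condition) its config class must ship a tensor-parallel plan,
        # because the backend shards layers according to base_model_tp_plan.
        and hasattr(model_class.config_class, "base_model_tp_plan")
    )


# LlamaConfig defines base_model_tp_plan upstream, so Llama satisfies the new hasattr check.
print(can_use_transformers_flash_backend(LlamaForCausalLM, flash_transformers_backend=True))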

@@ -82,6 +82,73 @@ def tgi_flash_attention_forward(
transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS["tgi"] = tgi_flash_attention_forward
+# Those are actually missing in Transformers so hardcoded here for now!
+transformers.models.olmo.configuration_olmo.OlmoConfig.base_model_tp_plan = {
+    "layers.*.self_attn.q_proj": "colwise",
+    "layers.*.self_attn.k_proj": "colwise",
+    "layers.*.self_attn.v_proj": "colwise",
+    "layers.*.self_attn.o_proj": "rowwise",
+    "layers.*.mlp.gate_proj": "colwise",
+    "layers.*.mlp.up_proj": "colwise",
+    "layers.*.mlp.down_proj": "rowwise",
+}
+transformers.models.olmo2.configuration_olmo2.Olmo2Config.base_model_tp_plan = {
+    "layers.*.self_attn.q_proj": "colwise",
+    "layers.*.self_attn.k_proj": "colwise",
+    "layers.*.self_attn.v_proj": "colwise",
+    "layers.*.self_attn.o_proj": "rowwise",
+    "layers.*.mlp.gate_proj": "colwise",
+    "layers.*.mlp.up_proj": "colwise",
+    "layers.*.mlp.down_proj": "rowwise",
+}
+transformers.models.cohere.configuration_cohere.CohereConfig.base_model_tp_plan = {
+    "layers.*.self_attn.q_proj": "colwise",
+    "layers.*.self_attn.k_proj": "colwise",
+    "layers.*.self_attn.v_proj": "colwise",
+    "layers.*.self_attn.o_proj": "rowwise",
+    "layers.*.mlp.gate_proj": "colwise",
+    "layers.*.mlp.up_proj": "colwise",
+    "layers.*.mlp.down_proj": "rowwise",
+}
+transformers.models.cohere2.configuration_cohere2.Cohere2Config.base_model_tp_plan = {
+    "layers.*.self_attn.q_proj": "colwise",
+    "layers.*.self_attn.k_proj": "colwise",
+    "layers.*.self_attn.v_proj": "colwise",
+    "layers.*.self_attn.o_proj": "rowwise",
+    "layers.*.mlp.gate_proj": "colwise",
+    "layers.*.mlp.up_proj": "colwise",
+    "layers.*.mlp.down_proj": "rowwise",
+}
+transformers.models.gemma.configuration_gemma.GemmaConfig.base_model_tp_plan = {
+    "layers.*.self_attn.q_proj": "colwise",
+    "layers.*.self_attn.k_proj": "colwise",
+    "layers.*.self_attn.v_proj": "colwise",
+    "layers.*.self_attn.o_proj": "rowwise",
+    "layers.*.mlp.gate_proj": "colwise",
+    "layers.*.mlp.up_proj": "colwise",
+    "layers.*.mlp.down_proj": "rowwise",
+}
+transformers.models.helium.configuration_helium.HeliumConfig.base_model_tp_plan = {
+    "layers.*.self_attn.q_proj": "colwise",
+    "layers.*.self_attn.k_proj": "colwise",
+    "layers.*.self_attn.v_proj": "colwise",
+    "layers.*.self_attn.o_proj": "rowwise",
+    "layers.*.mlp.gate_proj": "colwise",
+    "layers.*.mlp.up_proj": "colwise",
+    "layers.*.mlp.down_proj": "rowwise",
+}
+transformers.models.mixtral.configuration_mixtral.MixtralConfig.base_model_tp_plan = {
+    "layers.*.self_attn.q_proj": "colwise",
+    "layers.*.self_attn.k_proj": "colwise",
+    "layers.*.self_attn.v_proj": "colwise",
+    "layers.*.self_attn.o_proj": "rowwise",
+    "layers.*.block_sparse_moe.gate": "colwise",
+    "layers.*.block_sparse_moe.experts.*.w1": "colwise",
+    "layers.*.block_sparse_moe.experts.*.w2": "rowwise",
+    "layers.*.block_sparse_moe.experts.*.w3": "colwise",
+}
class TransformersFlashCausalLM(FlashCausalLM):
    def __init__(
        self,
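Each of the patched base_model_tp_plan dictionaries maps a glob-style module-name pattern (the "*" stands for a layer or expert index) to a sharding style: "colwise" splits a Linear layer along its output dimension across tensor-parallel ranks, while "rowwise" splits it along the input dimension and all-reduces the partial results. Below is a minimal sketch of how such a plan can be matched against a concrete module path; tp_style_for is a hypothetical helper, not the lookup used by Transformers or TGI.

# Minimal sketch; tp_style_for is illustrative, not library code.
import re


def tp_style_for(module_name: str, tp_plan: dict[str, str]) -> str | None:
    """Return the sharding style for module_name, or None if the plan does not cover it."""
    for pattern, style in tp_plan.items():
        # Treat "*" as a stand-in for an integer index (layer or expert number).
        regex = "^" + re.escape(pattern).replace(r"\*", r"\d+") + "$"
        if re.match(regex, module_name):
            return style
    return None


# Subset of the OLMo plan patched above.
olmo_plan = {
    "layers.*.self_attn.q_proj": "colwise",
    "layers.*.self_attn.o_proj": "rowwise",
    "layers.*.mlp.down_proj": "rowwise",
}
print(tp_style_for("layers.11.self_attn.q_proj", olmo_plan))  # colwise
print(tp_style_for("layers.3.mlp.down_proj", olmo_plan))      # rowwise
print(tp_style_for("embed_tokens", olmo_plan))                # None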