Update transformers_flash_causal_lm.py

This commit is contained in:
Cyril Vallez 2025-01-24 15:06:50 +01:00
parent de83178bc3
commit bafbd06744
No known key found for this signature in database

View File

@ -110,7 +110,7 @@ transformers.models.cohere.configuration_cohere.CohereConfig.base_model_tp_plan
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
transformers.models.cohere.configuration_cohere2.Cohere2Config.base_model_tp_plan = {
transformers.models.cohere2.configuration_cohere2.Cohere2Config.base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",