diff --git a/server/text_generation_server/models/transformers_flash_causal_lm.py b/server/text_generation_server/models/transformers_flash_causal_lm.py index be035417..7513ce01 100644 --- a/server/text_generation_server/models/transformers_flash_causal_lm.py +++ b/server/text_generation_server/models/transformers_flash_causal_lm.py @@ -110,7 +110,7 @@ transformers.models.cohere.configuration_cohere.CohereConfig.base_model_tp_plan "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } -transformers.models.cohere.configuration_cohere2.Cohere2Config.base_model_tp_plan = { +transformers.models.cohere2.configuration_cohere2.Cohere2Config.base_model_tp_plan = { "layers.*.self_attn.q_proj": "colwise", "layers.*.self_attn.k_proj": "colwise", "layers.*.self_attn.v_proj": "colwise",