diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index eb5a8de7..f14cab43 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -848,7 +848,7 @@ def get_model(
                 lora_adapter_ids=lora_adapter_ids,
             )
         else:
-            return TransformersFlashCausalLM.fallback(
+            return CausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -898,7 +898,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )

-    elif model_type == LLAMA or model_type == PHI3 or model_type == GRANITE:
+    elif model_type == LLAMA or model_type == GRANITE:
         if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
@@ -934,7 +934,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )

-    elif model_type == BAICHUAN:
+    elif model_type == BAICHUAN or model_type == PHI3:
         if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
@@ -1433,6 +1433,7 @@ def get_model(
         FLASH_TRANSFORMERS_BACKEND
         and transformers_model_class is not None
         and transformers_model_class._supports_flex_attn
+        and hasattr(transformers_model_class.config_class, "base_model_tp_plan")
     ):
         return TransformersFlashCausalLM.fallback(
             model_id,
diff --git a/server/text_generation_server/models/transformers_flash_causal_lm.py b/server/text_generation_server/models/transformers_flash_causal_lm.py
index 36de89b4..be035417 100644
--- a/server/text_generation_server/models/transformers_flash_causal_lm.py
+++ b/server/text_generation_server/models/transformers_flash_causal_lm.py
@@ -82,6 +82,73 @@ def tgi_flash_attention_forward(

 transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS["tgi"] = tgi_flash_attention_forward

+# Those are actually missing in Transformers, so they are hardcoded here for now!
+transformers.models.olmo.configuration_olmo.OlmoConfig.base_model_tp_plan = {
+    "layers.*.self_attn.q_proj": "colwise",
+    "layers.*.self_attn.k_proj": "colwise",
+    "layers.*.self_attn.v_proj": "colwise",
+    "layers.*.self_attn.o_proj": "rowwise",
+    "layers.*.mlp.gate_proj": "colwise",
+    "layers.*.mlp.up_proj": "colwise",
+    "layers.*.mlp.down_proj": "rowwise",
+}
+transformers.models.olmo2.configuration_olmo2.Olmo2Config.base_model_tp_plan = {
+    "layers.*.self_attn.q_proj": "colwise",
+    "layers.*.self_attn.k_proj": "colwise",
+    "layers.*.self_attn.v_proj": "colwise",
+    "layers.*.self_attn.o_proj": "rowwise",
+    "layers.*.mlp.gate_proj": "colwise",
+    "layers.*.mlp.up_proj": "colwise",
+    "layers.*.mlp.down_proj": "rowwise",
+}
+transformers.models.cohere.configuration_cohere.CohereConfig.base_model_tp_plan = {
+    "layers.*.self_attn.q_proj": "colwise",
+    "layers.*.self_attn.k_proj": "colwise",
+    "layers.*.self_attn.v_proj": "colwise",
+    "layers.*.self_attn.o_proj": "rowwise",
+    "layers.*.mlp.gate_proj": "colwise",
+    "layers.*.mlp.up_proj": "colwise",
+    "layers.*.mlp.down_proj": "rowwise",
+}
+transformers.models.cohere2.configuration_cohere2.Cohere2Config.base_model_tp_plan = {
+    "layers.*.self_attn.q_proj": "colwise",
+    "layers.*.self_attn.k_proj": "colwise",
+    "layers.*.self_attn.v_proj": "colwise",
+    "layers.*.self_attn.o_proj": "rowwise",
+    "layers.*.mlp.gate_proj": "colwise",
+    "layers.*.mlp.up_proj": "colwise",
+    "layers.*.mlp.down_proj": "rowwise",
+}
+transformers.models.gemma.configuration_gemma.GemmaConfig.base_model_tp_plan = {
+    "layers.*.self_attn.q_proj": "colwise",
+    "layers.*.self_attn.k_proj": "colwise",
+    "layers.*.self_attn.v_proj": "colwise",
+    "layers.*.self_attn.o_proj": "rowwise",
+    "layers.*.mlp.gate_proj": "colwise",
+    "layers.*.mlp.up_proj": "colwise",
+    "layers.*.mlp.down_proj": "rowwise",
+}
+transformers.models.helium.configuration_helium.HeliumConfig.base_model_tp_plan = {
+    "layers.*.self_attn.q_proj": "colwise",
+    "layers.*.self_attn.k_proj": "colwise",
+    "layers.*.self_attn.v_proj": "colwise",
+    "layers.*.self_attn.o_proj": "rowwise",
+    "layers.*.mlp.gate_proj": "colwise",
+    "layers.*.mlp.up_proj": "colwise",
+    "layers.*.mlp.down_proj": "rowwise",
+}
+transformers.models.mixtral.configuration_mixtral.MixtralConfig.base_model_tp_plan = {
+    "layers.*.self_attn.q_proj": "colwise",
+    "layers.*.self_attn.k_proj": "colwise",
+    "layers.*.self_attn.v_proj": "colwise",
+    "layers.*.self_attn.o_proj": "rowwise",
+    "layers.*.block_sparse_moe.gate": "colwise",
+    "layers.*.block_sparse_moe.experts.*.w1": "colwise",
+    "layers.*.block_sparse_moe.experts.*.w2": "rowwise",
+    "layers.*.block_sparse_moe.experts.*.w3": "colwise",
+}
+
+
 class TransformersFlashCausalLM(FlashCausalLM):
     def __init__(
         self,