Mirror of https://github.com/huggingface/text-generation-inference.git
tp monkey patch

commit de83178bc3
parent 6cb41a80a1
@@ -848,7 +848,7 @@ def get_model(
                 lora_adapter_ids=lora_adapter_ids,
             )
         else:
-            return TransformersFlashCausalLM.fallback(
+            return CausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -898,7 +898,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )

-    elif model_type == LLAMA or model_type == PHI3 or model_type == GRANITE:
+    elif model_type == LLAMA or model_type == GRANITE:
         if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
@@ -934,7 +934,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )

-    elif model_type == BAICHUAN:
+    elif model_type == BAICHUAN or model_type == PHI3:
         if FLASH_ATTENTION:
             return FlashCausalLM(
                 model_id=model_id,
@@ -1433,6 +1433,7 @@ def get_model(
         FLASH_TRANSFORMERS_BACKEND
         and transformers_model_class is not None
         and transformers_model_class._supports_flex_attn
+        and hasattr(transformers_model_class.config_class, "base_model_tp_plan")
     ):
         return TransformersFlashCausalLM.fallback(
             model_id,
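The added condition routes a model to TransformersFlashCausalLM only when its Transformers config class exposes a tensor-parallel sharding plan. Because hasattr() is evaluated on the class, the class-level assignments in the second file below are enough to satisfy it. A minimal sketch of that mechanism, using a hypothetical DummyConfig rather than a real Transformers config class:

class DummyConfig:
    """Hypothetical stand-in for a Transformers PretrainedConfig subclass."""

# Before patching there is no TP plan, so the hasattr() gate above is False
# and get_model falls through to another backend.
assert not hasattr(DummyConfig, "base_model_tp_plan")

# Class-level monkey patch, the same pattern as the hunk below.
DummyConfig.base_model_tp_plan = {
    "layers.*.self_attn.q_proj": "colwise",
    "layers.*.self_attn.o_proj": "rowwise",
}

# The class (and every instance) now reports a TP plan, so the
# TransformersFlashCausalLM path is taken.
assert hasattr(DummyConfig, "base_model_tp_plan")
assert hasattr(DummyConfig(), "base_model_tp_plan")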
@@ -82,6 +82,73 @@ def tgi_flash_attention_forward(
 transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS["tgi"] = tgi_flash_attention_forward


+# Those are actually missing in Transformers so hardcoded here for now!
+transformers.models.olmo.configuration_olmo.OlmoConfig.base_model_tp_plan = {
+    "layers.*.self_attn.q_proj": "colwise",
+    "layers.*.self_attn.k_proj": "colwise",
+    "layers.*.self_attn.v_proj": "colwise",
+    "layers.*.self_attn.o_proj": "rowwise",
+    "layers.*.mlp.gate_proj": "colwise",
+    "layers.*.mlp.up_proj": "colwise",
+    "layers.*.mlp.down_proj": "rowwise",
+}
+transformers.models.olmo2.configuration_olmo2.Olmo2Config.base_model_tp_plan = {
+    "layers.*.self_attn.q_proj": "colwise",
+    "layers.*.self_attn.k_proj": "colwise",
+    "layers.*.self_attn.v_proj": "colwise",
+    "layers.*.self_attn.o_proj": "rowwise",
+    "layers.*.mlp.gate_proj": "colwise",
+    "layers.*.mlp.up_proj": "colwise",
+    "layers.*.mlp.down_proj": "rowwise",
+}
+transformers.models.cohere.configuration_cohere.CohereConfig.base_model_tp_plan = {
+    "layers.*.self_attn.q_proj": "colwise",
+    "layers.*.self_attn.k_proj": "colwise",
+    "layers.*.self_attn.v_proj": "colwise",
+    "layers.*.self_attn.o_proj": "rowwise",
+    "layers.*.mlp.gate_proj": "colwise",
+    "layers.*.mlp.up_proj": "colwise",
+    "layers.*.mlp.down_proj": "rowwise",
+}
+transformers.models.cohere2.configuration_cohere2.Cohere2Config.base_model_tp_plan = {
+    "layers.*.self_attn.q_proj": "colwise",
+    "layers.*.self_attn.k_proj": "colwise",
+    "layers.*.self_attn.v_proj": "colwise",
+    "layers.*.self_attn.o_proj": "rowwise",
+    "layers.*.mlp.gate_proj": "colwise",
+    "layers.*.mlp.up_proj": "colwise",
+    "layers.*.mlp.down_proj": "rowwise",
+}
+transformers.models.gemma.configuration_gemma.GemmaConfig.base_model_tp_plan = {
+    "layers.*.self_attn.q_proj": "colwise",
+    "layers.*.self_attn.k_proj": "colwise",
+    "layers.*.self_attn.v_proj": "colwise",
+    "layers.*.self_attn.o_proj": "rowwise",
+    "layers.*.mlp.gate_proj": "colwise",
+    "layers.*.mlp.up_proj": "colwise",
+    "layers.*.mlp.down_proj": "rowwise",
+}
+transformers.models.helium.configuration_helium.HeliumConfig.base_model_tp_plan = {
+    "layers.*.self_attn.q_proj": "colwise",
+    "layers.*.self_attn.k_proj": "colwise",
+    "layers.*.self_attn.v_proj": "colwise",
+    "layers.*.self_attn.o_proj": "rowwise",
+    "layers.*.mlp.gate_proj": "colwise",
+    "layers.*.mlp.up_proj": "colwise",
+    "layers.*.mlp.down_proj": "rowwise",
+}
+transformers.models.mixtral.configuration_mixtral.MixtralConfig.base_model_tp_plan = {
+    "layers.*.self_attn.q_proj": "colwise",
+    "layers.*.self_attn.k_proj": "colwise",
+    "layers.*.self_attn.v_proj": "colwise",
+    "layers.*.self_attn.o_proj": "rowwise",
+    "layers.*.block_sparse_moe.gate": "colwise",
+    "layers.*.block_sparse_moe.experts.*.w1": "colwise",
+    "layers.*.block_sparse_moe.experts.*.w2": "rowwise",
+    "layers.*.block_sparse_moe.experts.*.w3": "colwise",
+}
+
+
 class TransformersFlashCausalLM(FlashCausalLM):
     def __init__(
         self,
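For reference, the hardcoded plans above follow Transformers' tensor-parallel plan format: "colwise" shards a Linear layer's output features across ranks, while "rowwise" shards its input features and all-reduces the result, so each attention or MLP block needs a single all-reduce on its output. A hedged sketch of how such a plan is consumed end to end, assuming a recent transformers release that accepts tp_plan="auto" and a torchrun launch; the checkpoint name is only an example:

# Launch with e.g.:  torchrun --nproc-per-node=2 tp_example.py
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "allenai/OLMo-2-1124-7B"  # example checkpoint; any model whose config carries a TP plan

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    tp_plan="auto",  # shard weights across ranks per the config's base_model_tp_plan
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Every rank runs the same script; inputs go to this rank's GPU.
rank = int(os.environ.get("LOCAL_RANK", 0))
inputs = tokenizer("Tensor parallelism splits", return_tensors="pt").to(f"cuda:{rank}")
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))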