Enabling Flash Attention support for falcon model (#232)

Thanaji Rao Thakkalapelli authored 2024-10-15 10:50:17 -07:00, committed by GitHub
parent 0578bd917d
commit e06320f64e


@@ -694,11 +694,11 @@ class CausalLM(Model):
             "return_dict": True,
         }
-        if model.config.model_type in ["llama", "mistral", "starcoder2", "qwen2"]:
-            if model.config.model_type in ["llama", "mistral", "qwen2"]:
+        if model.config.model_type in ["llama", "mistral", "starcoder2", "qwen2", "falcon"]:
+            if model.config.model_type not in ["falcon"]:
                 kwargs["attn_softmax_bf16"] = True
-                kwargs["trim_logits"] = True
+            kwargs["trim_logits"] = True
             if os.getenv("USE_FLASH_ATTENTION", "false").lower() == "true":
                 kwargs["use_flash_attention"] = True