From e06320f64e5df68e064831175bb959ed369e4b53 Mon Sep 17 00:00:00 2001
From: Thanaji Rao Thakkalapelli
Date: Tue, 15 Oct 2024 10:50:17 -0700
Subject: [PATCH] Enabling Flash Attention support for falcon model (#232)

---
 server/text_generation_server/models/causal_lm.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 22d7a71e..66d0fcc0 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -694,11 +694,11 @@ class CausalLM(Model):
             "return_dict": True,
         }

-        if model.config.model_type in ["llama", "mistral", "starcoder2", "qwen2"]:
-
-            if model.config.model_type in ["llama", "mistral", "qwen2"]:
+        if model.config.model_type in ["llama", "mistral", "starcoder2", "qwen2", "falcon"]:
+
+            if model.config.model_type not in ["falcon"]:
                 kwargs["attn_softmax_bf16"] = True
-            kwargs["trim_logits"] = True
+            kwargs["trim_logits"] = True

             if os.getenv("USE_FLASH_ATTENTION", "false").lower() == "true":
                 kwargs["use_flash_attention"] = True
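
For reference, below is a minimal, self-contained sketch of the kwargs gating this patch results in. The function build_generation_kwargs and its bare model_type argument are hypothetical stand-ins for the in-method logic of CausalLM in causal_lm.py, not helpers that exist in the repository, and only the keys touched by this patch are shown.

import os

def build_generation_kwargs(model_type: str) -> dict:
    # Hypothetical stand-in for the kwargs construction inside CausalLM.
    kwargs = {
        "return_dict": True,
    }

    if model_type in ["llama", "mistral", "starcoder2", "qwen2", "falcon"]:
        # falcon skips the bf16 attention softmax but still trims logits.
        if model_type not in ["falcon"]:
            kwargs["attn_softmax_bf16"] = True
        kwargs["trim_logits"] = True

        # Flash attention stays opt-in via the USE_FLASH_ATTENTION env var;
        # with falcon added to the list above, the flag now reaches falcon too.
        if os.getenv("USE_FLASH_ATTENTION", "false").lower() == "true":
            kwargs["use_flash_attention"] = True

    return kwargs

if __name__ == "__main__":
    os.environ["USE_FLASH_ATTENTION"] = "true"
    print(build_generation_kwargs("falcon"))
    # -> {'return_dict': True, 'trim_logits': True, 'use_flash_attention': True}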