Enables Flash Attention in TGI for gemma models (#235)

Thanaji Rao Thakkalapelli 2024-10-18 09:20:42 -07:00 committed by GitHub
parent 9ae5ad5057
commit c5e3881051


@@ -694,10 +694,12 @@ class CausalLM(Model):
             "return_dict": True,
         }
-        if model.config.model_type in ["llama", "mistral", "starcoder2", "qwen2", "falcon"]:
+        if model.config.model_type in ["llama", "mistral", "starcoder2", "qwen2", "falcon", "gemma"]:
             if model.config.model_type not in ["falcon"]:
                 kwargs["attn_softmax_bf16"] = True
-            kwargs["trim_logits"] = True
+            if model.config.model_type not in ["gemma"]:
+                kwargs["trim_logits"] = True
             if os.getenv("USE_FLASH_ATTENTION", "false").lower() == "true":
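For context, the following is a minimal, self-contained sketch (not the upstream TGI source) of how the generation kwargs are selected after this change. The build_generation_kwargs helper and the use_flash_attention key name are assumptions for illustration only, since the hunk above is truncated before the body of the final if statement.

    import os

    # Sketch of the kwargs-selection logic shown in the diff above.
    # Model type names mirror the list in the changed condition.
    def build_generation_kwargs(model_type: str) -> dict:
        kwargs = {"return_dict": True}
        if model_type in ["llama", "mistral", "starcoder2", "qwen2", "falcon", "gemma"]:
            # bf16 attention softmax is enabled for every listed type except falcon
            if model_type not in ["falcon"]:
                kwargs["attn_softmax_bf16"] = True
            # gemma is the only listed type that skips logit trimming
            if model_type not in ["gemma"]:
                kwargs["trim_logits"] = True
            # Flash Attention is opt-in via the USE_FLASH_ATTENTION environment variable
            if os.getenv("USE_FLASH_ATTENTION", "false").lower() == "true":
                # assumed flag name; the diff is truncated before this line in the source
                kwargs["use_flash_attention"] = True
        return kwargs

    # With USE_FLASH_ATTENTION=true set in the environment:
    # build_generation_kwargs("gemma")
    # -> {'return_dict': True, 'attn_softmax_bf16': True, 'use_flash_attention': True}

Note that gemma now takes the same path as llama, mistral, starcoder2, and qwen2, except that trim_logits is not set for it.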