From b126bf47853e33d3b8950f6593462f1ca9a9d300 Mon Sep 17 00:00:00 2001
From: Thanaji Rao Thakkalapelli
Date: Wed, 23 Oct 2024 01:58:57 -0700
Subject: [PATCH] Revert pr 235 as flash attention is not really enabled for gemma (#239)

---
 server/text_generation_server/models/causal_lm.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 10ebd41c..c15e6e4e 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -694,13 +694,12 @@ class CausalLM(Model):
             "return_dict": True,
         }
 
-        if model.config.model_type in ["llama", "mistral", "starcoder2", "qwen2", "falcon", "gemma"]:
+        if model.config.model_type in ["llama", "mistral", "starcoder2", "qwen2", "falcon"]:
             if model.config.model_type not in ["falcon"]:
                 kwargs["attn_softmax_bf16"] = True
 
-            if model.config.model_type not in ["gemma"]:
-                kwargs["trim_logits"] = True
+            kwargs["trim_logits"] = True
 
             if os.getenv("USE_FLASH_ATTENTION", "false").lower() == "true":
                 kwargs["use_flash_attention"] = True
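
Note (not part of the patch): the sketch below illustrates the effective kwargs selection after this revert. It is a minimal standalone approximation, not the repository's code; build_generation_kwargs and the bare model_type string are hypothetical stand-ins for the real CausalLM code, which reads model.config.model_type and builds the kwargs dict in place.

import os

def build_generation_kwargs(model_type: str) -> dict:
    # Hypothetical helper mirroring the post-revert logic in causal_lm.py.
    kwargs = {"return_dict": True}

    if model_type in ["llama", "mistral", "starcoder2", "qwen2", "falcon"]:
        # gemma is no longer in this list, so after the revert it gets
        # neither attn_softmax_bf16 nor trim_logits, and the flash-attention
        # env var is not consulted for it.
        if model_type not in ["falcon"]:
            kwargs["attn_softmax_bf16"] = True

        kwargs["trim_logits"] = True

        if os.getenv("USE_FLASH_ATTENTION", "false").lower() == "true":
            kwargs["use_flash_attention"] = True

    return kwargs

# Example: gemma now falls through the whole branch, llama keeps its flags.
assert build_generation_kwargs("gemma") == {"return_dict": True}
assert build_generation_kwargs("llama")["trim_logits"] is True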