diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 9deef2e1..28d06790 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -43,6 +43,8 @@ __all__ = [
 
 FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
 
+FLASH_ATTENTION = True
+
 # FlashCausalLM requires CUDA Graphs to be enabled on the system. This will throw a RuntimeError
 # if CUDA Graphs are not available when calling `torch.cuda.graph_pool_handle()` in the FlashCausalLM
 HAS_CUDA_GRAPH = False
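
For context, module-level flags like `HAS_CUDA_GRAPH` are typically set by probing the runtime at import time rather than hard-coded. A minimal sketch of what such a probe might look like, assuming a try/except pattern around `torch.cuda.graph_pool_handle()` (the actual detection logic in this file may differ):

```python
import torch

HAS_CUDA_GRAPH = False
try:
    # torch.cuda.graph_pool_handle() raises a RuntimeError when CUDA
    # Graphs are not available on the system, so a successful call
    # doubles as a capability check.
    torch.cuda.graph_pool_handle()
    HAS_CUDA_GRAPH = True
except RuntimeError:
    # CUDA Graphs unsupported; FlashCausalLM-based models stay disabled.
    HAS_CUDA_GRAPH = False
```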