From f74a50d41b0ed78b118dfdf2b6ad78e133c52594 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?=
Date: Wed, 5 Feb 2025 12:58:44 +0000
Subject: [PATCH] Take `load_kernel` out of a frequently-called function

---
 .../layers/attention/kv_cache.py | 21 +++++++++++++--------
 1 file changed, 13 insertions(+), 8 deletions(-)

diff --git a/server/text_generation_server/layers/attention/kv_cache.py b/server/text_generation_server/layers/attention/kv_cache.py
index 522b10a7..2e1dd1c5 100644
--- a/server/text_generation_server/layers/attention/kv_cache.py
+++ b/server/text_generation_server/layers/attention/kv_cache.py
@@ -11,6 +11,18 @@ from text_generation_server.utils.kernels import load_kernel
 from text_generation_server.utils.log import log_once
 from text_generation_server.utils.weights import Weights
 
+if SYSTEM == "cuda":
+    try:
+        attention_kernels = load_kernel(
+            module="attention", repo_id="kernels-community/attention"
+        )
+    except Exception as e:
+        raise ImportError(
+            f"Could not import attention_kernels. Make sure your installation is correct. Complete error: {e}"
+        )
+else:
+    attention_kernels = None
+
 
 @dataclass
 class KVScales:
@@ -221,14 +233,7 @@ def paged_reshape_and_cache(
 ):
 
     if SYSTEM == "cuda":
-        try:
-            attention_kernels = load_kernel(
-                module="attention", repo_id="kernels-community/attention"
-            )
-        except Exception as e:
-            raise ImportError(
-                f"Could not import attention_kernels. Make sure your installation is correct. Complete error: {e}"
-            )
+        assert attention_kernels is not None
 
         kv_cache_dtype = "auto"
         if key_cache.dtype == torch.float8_e4m3fn:
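
For illustration only (not part of the patch): the change follows the common "load once at module import, reuse in the hot path" pattern, so the per-call try/except in paged_reshape_and_cache is reduced to a cheap assert. A minimal, self-contained Python sketch of that pattern, with a hypothetical expensive_load() standing in for load_kernel() and HAS_CUDA standing in for the SYSTEM check, might look like this:

    # Illustration of the pattern applied by the patch (hypothetical names,
    # not the TGI code): pay the loading cost once at import time, then only
    # assert availability inside the frequently-called function.

    from typing import Any, Optional

    HAS_CUDA = True  # stand-in for `SYSTEM == "cuda"`


    def expensive_load() -> Any:
        # Hypothetical stand-in for load_kernel(module=..., repo_id=...).
        return object()


    # Module level: runs exactly once, when the module is first imported.
    if HAS_CUDA:
        try:
            kernels: Optional[Any] = expensive_load()
        except Exception as e:
            raise ImportError(f"Could not load kernels: {e}")
    else:
        kernels = None


    def hot_path() -> None:
        # Called per request/token: no loading here, just a cheap check.
        if HAS_CUDA:
            assert kernels is not None
            # ... use `kernels` here ...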