From 8aecc59eb09533c26bc0dee7d36f900bacf096a4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?=
Date: Wed, 5 Feb 2025 13:09:23 +0000
Subject: [PATCH] Hoist another case of kernel loading out of a somewhat hot function

---
 .../layers/attention/cuda.py     | 15 ++++++++++++---
 .../layers/attention/kv_cache.py |  4 +---
 2 files changed, 13 insertions(+), 6 deletions(-)

diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py
index f16004cd..a3afd422 100644
--- a/server/text_generation_server/layers/attention/cuda.py
+++ b/server/text_generation_server/layers/attention/cuda.py
@@ -14,6 +14,18 @@ major, minor = torch.cuda.get_device_capability()
 is_sm75 = major == 7 and minor == 5
 _PARTITION_SIZE = 512
 
+if SYSTEM == "cuda":
+    try:
+        attention_kernels = load_kernel(
+            module="attention", repo_id="kernels-community/attention"
+        )
+    except Exception as e:
+        raise ImportError(
+            f"Could not import attention kernels. Make sure your installation is correct. Complete error: {e}"
+        )
+else:
+    attention_kernels = None
+
 
 def paged_attention(
     query: torch.Tensor,
@@ -108,9 +120,6 @@ def paged_attention(
         if softcap is not None:
             raise RuntimeError("Paged attention doesn't support softcapping")
         input_lengths = seqlen.input_lengths + seqlen.cache_lengths
-        attention_kernels = load_kernel(
-            module="attention", repo_id="kernels-community/attention"
-        )
 
         out = torch.empty_like(query)
 
diff --git a/server/text_generation_server/layers/attention/kv_cache.py b/server/text_generation_server/layers/attention/kv_cache.py
index 2e1dd1c5..6e7cb713 100644
--- a/server/text_generation_server/layers/attention/kv_cache.py
+++ b/server/text_generation_server/layers/attention/kv_cache.py
@@ -18,7 +18,7 @@ if SYSTEM == "cuda":
         )
     except Exception as e:
         raise ImportError(
-            f"Could not import attention_kernels. Make sure your installation is correct. Complete error: {e}"
+            f"Could not import attention kernels. Make sure your installation is correct. Complete error: {e}"
        )
 else:
     attention_kernels = None
@@ -233,8 +233,6 @@ def paged_reshape_and_cache(
 ):
 
     if SYSTEM == "cuda":
-        assert attention_kernels is not None
-
         kv_cache_dtype = "auto"
         if key_cache.dtype == torch.float8_e4m3fn:
             kv_cache_dtype = "fp8"
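
A minimal sketch of the hoisting pattern applied above: the attention kernel is loaded once at module import, guarded by the SYSTEM check, and the hot paged_attention path only reuses the module-level handle instead of calling load_kernel on every invocation. The import paths for SYSTEM and load_kernel are assumptions inferred from the identifiers in this diff (they are not shown in the patch), and the paged_attention signature is trimmed for illustration.

# Sketch only; import paths and the trimmed signature below are assumptions,
# not the exact code in the repository.
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.kernels import load_kernel

# Hoisted: load the kernel once at import time, not inside the hot function.
if SYSTEM == "cuda":
    try:
        attention_kernels = load_kernel(
            module="attention", repo_id="kernels-community/attention"
        )
    except Exception as e:
        raise ImportError(
            f"Could not import attention kernels. "
            f"Make sure your installation is correct. Complete error: {e}"
        )
else:
    attention_kernels = None


def paged_attention(query, *args, **kwargs):
    # Hypothetical, trimmed signature. The hot path now only dereferences the
    # module-level handle; the per-call load_kernel() from the old code is gone.
    assert attention_kernels is not None
    ...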