diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py
index f16004cd..a3afd422 100644
--- a/server/text_generation_server/layers/attention/cuda.py
+++ b/server/text_generation_server/layers/attention/cuda.py
@@ -14,6 +14,18 @@ major, minor = torch.cuda.get_device_capability()
 is_sm75 = major == 7 and minor == 5
 _PARTITION_SIZE = 512
 
+if SYSTEM == "cuda":
+    try:
+        attention_kernels = load_kernel(
+            module="attention", repo_id="kernels-community/attention"
+        )
+    except Exception as e:
+        raise ImportError(
+            f"Could not import attention kernels. Make sure your installation is correct. Complete error: {e}"
+        )
+else:
+    attention_kernels = None
+
 
 def paged_attention(
     query: torch.Tensor,
@@ -108,9 +120,6 @@ def paged_attention(
         if softcap is not None:
             raise RuntimeError("Paged attention doesn't support softcapping")
         input_lengths = seqlen.input_lengths + seqlen.cache_lengths
-        attention_kernels = load_kernel(
-            module="attention", repo_id="kernels-community/attention"
-        )
 
         out = torch.empty_like(query)
 
diff --git a/server/text_generation_server/layers/attention/kv_cache.py b/server/text_generation_server/layers/attention/kv_cache.py
index 2e1dd1c5..6e7cb713 100644
--- a/server/text_generation_server/layers/attention/kv_cache.py
+++ b/server/text_generation_server/layers/attention/kv_cache.py
@@ -18,7 +18,7 @@ if SYSTEM == "cuda":
         )
     except Exception as e:
         raise ImportError(
-            f"Could not import attention_kernels. Make sure your installation is correct. Complete error: {e}"
+            f"Could not import attention kernels. Make sure your installation is correct. Complete error: {e}"
         )
 else:
     attention_kernels = None
@@ -233,8 +233,6 @@ def paged_reshape_and_cache(
 ):
 
     if SYSTEM == "cuda":
-        assert attention_kernels is not None
-
         kv_cache_dtype = "auto"
         if key_cache.dtype == torch.float8_e4m3fn:
            kv_cache_dtype = "fp8"
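
For reference, the pattern both files now share is a guarded, module-level kernel load: on CUDA systems the attention kernels are fetched once at import time, so a broken installation fails fast with an `ImportError`, while non-CUDA systems simply get `attention_kernels = None`. A minimal standalone sketch of that pattern follows; the helper import paths (`text_generation_server.utils.import_utils` for `SYSTEM` and `text_generation_server.utils.kernels` for `load_kernel`) are assumptions based on the existing imports in these modules.

```python
# Sketch of the hoisted, module-level kernel loading used in cuda.py and
# kv_cache.py after this change. The import paths below are assumed to match
# the helpers these files already use.
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.kernels import load_kernel

if SYSTEM == "cuda":
    try:
        # Load the paged-attention kernels once, at import time, instead of
        # on every call to paged_attention().
        attention_kernels = load_kernel(
            module="attention", repo_id="kernels-community/attention"
        )
    except Exception as e:
        raise ImportError(
            f"Could not import attention kernels. Make sure your installation is correct. Complete error: {e}"
        )
else:
    # Non-CUDA systems never use these kernels.
    attention_kernels = None
```

Because the load now happens at import time, callers such as `paged_attention` and `paged_reshape_and_cache` can rely on `attention_kernels` being set whenever `SYSTEM == "cuda"`, which is why the per-call `load_kernel` and the `assert attention_kernels is not None` are dropped.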