Hoist another case of kernel loading out of a somewhat hot function

Daniël de Kok 2025-02-05 13:09:23 +00:00
parent f74a50d41b
commit 8aecc59eb0
2 changed files with 13 additions and 6 deletions
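
The pattern applied by this commit, as a minimal sketch: the load_kernel call for kernels-community/attention is hoisted out of the hot function and performed once at module import time, so the hot path only uses the already-loaded handle. The importlib example below is hypothetical and only illustrates the hoisting idea, not the actual TGI code:

import importlib

# Before: the somewhat hot function resolves its dependency on every call.
def hot_path_before(x):
    math_mod = importlib.import_module("math")  # repeated lookup per call
    return math_mod.sqrt(x)

# After: resolve once at import time and keep a module-level handle,
# so the hot path does no loading work at all.
_math_mod = importlib.import_module("math")

def hot_path_after(x):
    return _math_mod.sqrt(x)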


@ -14,6 +14,18 @@ major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5
_PARTITION_SIZE = 512
if SYSTEM == "cuda":
    try:
        attention_kernels = load_kernel(
            module="attention", repo_id="kernels-community/attention"
        )
    except Exception as e:
        raise ImportError(
            f"Could not import attention kernels. Make sure your installation is correct. Complete error: {e}"
        )
else:
    attention_kernels = None
def paged_attention(
    query: torch.Tensor,
@ -108,9 +120,6 @@ def paged_attention(
    if softcap is not None:
        raise RuntimeError("Paged attention doesn't support softcapping")
    input_lengths = seqlen.input_lengths + seqlen.cache_lengths
    attention_kernels = load_kernel(
        module="attention", repo_id="kernels-community/attention"
    )
    out = torch.empty_like(query)


@ -18,7 +18,7 @@ if SYSTEM == "cuda":
        )
    except Exception as e:
        raise ImportError(
            f"Could not import attention_kernels. Make sure your installation is correct. Complete error: {e}"
            f"Could not import attention kernels. Make sure your installation is correct. Complete error: {e}"
        )
else:
    attention_kernels = None
@ -233,8 +233,6 @@ def paged_reshape_and_cache(
):
    if SYSTEM == "cuda":
        assert attention_kernels is not None
        kv_cache_dtype = "auto"
        if key_cache.dtype == torch.float8_e4m3fn:
            kv_cache_dtype = "fp8"