Take load_kernel out of a frequently-called function

This commit is contained in:
Daniël de Kok 2025-02-05 12:58:44 +00:00
parent 875ce6d521
commit f74a50d41b

View File

@@ -11,6 +11,18 @@ from text_generation_server.utils.kernels import load_kernel
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import Weights
# Load the attention kernel once at module import time, rather than inside a
# frequently-called function, so the load cost is paid only once per process.
if SYSTEM == "cuda":
    try:
        attention_kernels = load_kernel(
            module="attention", repo_id="kernels-community/attention"
        )
    except Exception as e:
        # Re-raise any load failure as an ImportError, preserving the
        # original error text so installation problems are diagnosable.
        raise ImportError(
            f"Could not import attention_kernels. Make sure your installation is correct. Complete error: {e}"
        )
else:
    # No CUDA: no kernel is available. Call sites that require the kernel
    # assert it is not None before use (see paged_reshape_and_cache).
    attention_kernels = None
@dataclass
class KVScales:
@@ -221,14 +233,7 @@ def paged_reshape_and_cache(
):
if SYSTEM == "cuda":
try:
attention_kernels = load_kernel(
module="attention", repo_id="kernels-community/attention"
)
except Exception as e:
raise ImportError(
f"Could not import attention_kernels. Make sure your installation is correct. Complete error: {e}"
)
assert attention_kernels is not None
kv_cache_dtype = "auto"
if key_cache.dtype == torch.float8_e4m3fn: