Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-11 04:14:52 +00:00
Take load_kernel out of a frequently-called function
commit f74a50d41b
parent 875ce6d521
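Summary (from the diff below): the load_kernel call that fetches the CUDA attention kernels moves out of paged_reshape_and_cache, the frequently-called function named in the title, up to module import time; the function body keeps only an assertion that the kernels were loaded.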
@@ -11,6 +11,18 @@ from text_generation_server.utils.kernels import load_kernel
 from text_generation_server.utils.log import log_once
 from text_generation_server.utils.weights import Weights
 
+if SYSTEM == "cuda":
+    try:
+        attention_kernels = load_kernel(
+            module="attention", repo_id="kernels-community/attention"
+        )
+    except Exception as e:
+        raise ImportError(
+            f"Could not import attention_kernels. Make sure your installation is correct. Complete error: {e}"
+        )
+else:
+    attention_kernels = None
+
 
 @dataclass
 class KVScales:
@@ -221,14 +233,7 @@ def paged_reshape_and_cache(
 ):
 
     if SYSTEM == "cuda":
-        try:
-            attention_kernels = load_kernel(
-                module="attention", repo_id="kernels-community/attention"
-            )
-        except Exception as e:
-            raise ImportError(
-                f"Could not import attention_kernels. Make sure your installation is correct. Complete error: {e}"
-            )
+        assert attention_kernels is not None
 
         kv_cache_dtype = "auto"
         if key_cache.dtype == torch.float8_e4m3fn:
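For context, the pattern applied here is hoisting a costly one-time initialization out of a hot path into module scope. A minimal, self-contained sketch of that pattern under assumed names (expensive_load and the reshape_and_cache_* functions are hypothetical stand-ins, not code from this repository):

# Illustrative sketch only; names are hypothetical stand-ins for
# load_kernel and paged_reshape_and_cache.
import time


def expensive_load():
    """Stand-in for load_kernel(): simulate a costly kernel lookup/import."""
    time.sleep(0.01)
    return object()


# Before the commit: the load is paid on every call of the hot path.
def reshape_and_cache_before():
    kernels = expensive_load()
    return kernels


# After the commit: load once at import time and fail fast if it breaks;
# the hot path keeps only a cheap assert.
try:
    _kernels = expensive_load()
except Exception as e:
    raise ImportError(f"Could not load kernels. Complete error: {e}")


def reshape_and_cache_after():
    assert _kernels is not None
    return _kernels


if __name__ == "__main__":
    n = 50
    start = time.perf_counter()
    for _ in range(n):
        reshape_and_cache_before()
    mid = time.perf_counter()
    for _ in range(n):
        reshape_and_cache_after()
    end = time.perf_counter()
    print(f"per-call load: {mid - start:.3f}s, import-time load: {end - mid:.3f}s")

One trade-off worth noting: with the load at import time, a broken kernel installation surfaces as an ImportError at server startup rather than on the first request, and repeated per-call lookups are avoided entirely.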