Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-11 04:14:52 +00:00
Take load_kernel out of a frequently-called function

This commit is contained in:
parent 875ce6d521
commit f74a50d41b
@@ -11,6 +11,18 @@ from text_generation_server.utils.kernels import load_kernel
 from text_generation_server.utils.log import log_once
 from text_generation_server.utils.weights import Weights
 
+if SYSTEM == "cuda":
+    try:
+        attention_kernels = load_kernel(
+            module="attention", repo_id="kernels-community/attention"
+        )
+    except Exception as e:
+        raise ImportError(
+            f"Could not import attention_kernels. Make sure your installation is correct. Complete error: {e}"
+        )
+else:
+    attention_kernels = None
+
 
 @dataclass
 class KVScales:
@@ -221,14 +233,7 @@ def paged_reshape_and_cache(
 ):
 
     if SYSTEM == "cuda":
-        try:
-            attention_kernels = load_kernel(
-                module="attention", repo_id="kernels-community/attention"
-            )
-        except Exception as e:
-            raise ImportError(
-                f"Could not import attention_kernels. Make sure your installation is correct. Complete error: {e}"
-            )
+        assert attention_kernels is not None
 
     kv_cache_dtype = "auto"
     if key_cache.dtype == torch.float8_e4m3fn: