Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-09-11 04:14:52 +00:00
Hoist another case of kernel loading out of a somewhat hot function
This commit is contained in:
parent f74a50d41b
commit 8aecc59eb0
@@ -14,6 +14,18 @@ major, minor = torch.cuda.get_device_capability()
 is_sm75 = major == 7 and minor == 5
 
 _PARTITION_SIZE = 512
 
+if SYSTEM == "cuda":
+    try:
+        attention_kernels = load_kernel(
+            module="attention", repo_id="kernels-community/attention"
+        )
+    except Exception as e:
+        raise ImportError(
+            f"Could not import attention kernels. Make sure your installation is correct. Complete error: {e}"
+        )
+else:
+    attention_kernels = None
+
 
 def paged_attention(
     query: torch.Tensor,
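The pattern in the hunk above is to pay the kernel-hub lookup once, at import time, and stash the result in a module global that is None on non-CUDA systems. A minimal, self-contained sketch of the same shape; the load_kernel below is a stand-in (the sleep simulates lookup cost), not TGI's real helper from text_generation_server.utils.kernels:

import time
from types import ModuleType
from typing import Optional

SYSTEM = "cuda"  # in TGI this flag comes from the import-utils module


def load_kernel(module: str, repo_id: str) -> ModuleType:
    # Stand-in loader: the real one resolves a compiled kernel from the
    # kernels-community hub. Cheap, but not free, so it should not run
    # inside a per-token hot path.
    time.sleep(0.001)  # simulated lookup cost
    return ModuleType(module)


if SYSTEM == "cuda":
    try:
        attention_kernels: Optional[ModuleType] = load_kernel(
            module="attention", repo_id="kernels-community/attention"
        )
    except Exception as e:
        raise ImportError(
            f"Could not import attention kernels. Make sure your installation is correct. Complete error: {e}"
        )
else:
    attention_kernels = None

Call sites then assert the global is not None, as the paged_reshape_and_cache hunk further down shows.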
@@ -108,9 +120,6 @@ def paged_attention(
     if softcap is not None:
         raise RuntimeError("Paged attention doesn't support softcapping")
     input_lengths = seqlen.input_lengths + seqlen.cache_lengths
-    attention_kernels = load_kernel(
-        module="attention", repo_id="kernels-community/attention"
-    )
 
     out = torch.empty_like(query)
 
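Why the hoist matters: paged_attention runs for every layer on every decoding step, so even a memoized loader still charges a wrapper call plus a cache lookup per invocation. A hypothetical micro-benchmark contrasting the two shapes, with an illustrative stand-in loader rather than TGI's:

import timeit
from functools import lru_cache


@lru_cache(maxsize=None)
def load_kernel(module: str, repo_id: str) -> object:
    # Memoized stand-in: repeat calls are cheap, but each one still pays
    # argument hashing and the lru_cache wrapper call.
    return object()


def hot_path_per_call() -> object:
    # Old shape: look the kernels up inside the hot function.
    return load_kernel("attention", "kernels-community/attention")


attention_kernels = load_kernel("attention", "kernels-community/attention")


def hot_path_hoisted() -> object:
    # New shape: just read the module-level global.
    return attention_kernels


n = 1_000_000
print("per-call:", timeit.timeit(hot_path_per_call, number=n))
print("hoisted: ", timeit.timeit(hot_path_hoisted, number=n))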
@@ -18,7 +18,7 @@ if SYSTEM == "cuda":
         )
     except Exception as e:
         raise ImportError(
-            f"Could not import attention_kernels. Make sure your installation is correct. Complete error: {e}"
+            f"Could not import attention kernels. Make sure your installation is correct. Complete error: {e}"
         )
 else:
     attention_kernels = None
@@ -233,8 +233,6 @@ def paged_reshape_and_cache(
 ):
-
     if SYSTEM == "cuda":
         assert attention_kernels is not None
-
         kv_cache_dtype = "auto"
         if key_cache.dtype == torch.float8_e4m3fn:
             kv_cache_dtype = "fp8"
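A note on the assert kept in paged_reshape_and_cache: the module global is Optional (None on non-CUDA systems), so the assert both documents the import-time invariant and narrows the type before the kernels are used. A small sketch of the idiom, with a dummy module standing in for the loaded kernels:

from types import ModuleType
from typing import Optional

# None on non-CUDA systems; set once at import time on CUDA systems.
attention_kernels: Optional[ModuleType] = ModuleType("attention")


def paged_reshape_and_cache() -> None:
    # Fails fast if the import-time load never happened, and narrows
    # Optional[ModuleType] to ModuleType for static type checkers.
    assert attention_kernels is not None
    print(f"using kernels from module {attention_kernels.__name__!r}")


paged_reshape_and_cache()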