Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-11 12:24:53 +00:00
Hoist another case of kernel loading out of a somewhat hot function

This commit is contained in:
parent f74a50d41b
commit 8aecc59eb0
@@ -14,6 +14,18 @@ major, minor = torch.cuda.get_device_capability()
 is_sm75 = major == 7 and minor == 5
 _PARTITION_SIZE = 512
 
+if SYSTEM == "cuda":
+    try:
+        attention_kernels = load_kernel(
+            module="attention", repo_id="kernels-community/attention"
+        )
+    except Exception as e:
+        raise ImportError(
+            f"Could not import attention kernels. Make sure your installation is correct. Complete error: {e}"
+        )
+else:
+    attention_kernels = None
+
 
 def paged_attention(
     query: torch.Tensor,
@@ -108,9 +120,6 @@ def paged_attention(
     if softcap is not None:
        raise RuntimeError("Paged attention doesn't support softcapping")
     input_lengths = seqlen.input_lengths + seqlen.cache_lengths
-    attention_kernels = load_kernel(
-        module="attention", repo_id="kernels-community/attention"
-    )
 
     out = torch.empty_like(query)
 
@@ -18,7 +18,7 @@ if SYSTEM == "cuda":
         )
     except Exception as e:
         raise ImportError(
-            f"Could not import attention_kernels. Make sure your installation is correct. Complete error: {e}"
+            f"Could not import attention kernels. Make sure your installation is correct. Complete error: {e}"
         )
 else:
     attention_kernels = None
@@ -233,8 +233,6 @@ def paged_reshape_and_cache(
 ):
 
     if SYSTEM == "cuda":
-        assert attention_kernels is not None
-
         kv_cache_dtype = "auto"
         if key_cache.dtype == torch.float8_e4m3fn:
             kv_cache_dtype = "fp8"
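
The pattern the diff applies is simple: perform the (potentially slow) kernel lookup once at import time, guarded by a try/except, and let the hot functions only use the resulting module-level handle. The sketch below is a minimal, self-contained illustration of that shape; `load_kernel` here is a dummy stub and `SYSTEM` is a hard-coded constant standing in for text-generation-inference's real helpers, so it runs without any GPU or kernel repo.

# Minimal sketch of the hoisting pattern. `load_kernel` is a stub and
# `SYSTEM` is hard-coded; only the structure mirrors the diff above.
from types import SimpleNamespace
from typing import Optional

SYSTEM = "cuda"  # stand-in for the real platform detection


def load_kernel(module: str, repo_id: str) -> SimpleNamespace:
    # Stub: pretend to fetch/compile/load a kernel module (the slow part).
    return SimpleNamespace(module=module, repo_id=repo_id)


# Hoisted: the load runs once at import time, not on every call.
if SYSTEM == "cuda":
    try:
        attention_kernels: Optional[SimpleNamespace] = load_kernel(
            module="attention", repo_id="kernels-community/attention"
        )
    except Exception as e:
        raise ImportError(
            f"Could not import attention kernels. Make sure your installation is correct. Complete error: {e}"
        )
else:
    attention_kernels = None


def paged_attention() -> SimpleNamespace:
    # Hot path: only dereferences the module-level handle.
    assert attention_kernels is not None
    return attention_kernels


if __name__ == "__main__":
    print(paged_attention().repo_id)  # kernels-community/attention

The second hunk of the first file is exactly this move: the per-call `load_kernel(...)` inside `paged_attention` disappears in favor of the guarded module-level block added near the top of the file.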