fix import

This commit is contained in:
Mohit Sharma 2024-09-27 12:32:17 +00:00
parent 47c81d2924
commit 816d4b67b2

View File

@ -11,6 +11,7 @@ if SYSTEM == "cuda":
paged_attention,
reshape_and_cache,
SUPPORTS_WINDOWING,
PREFILL_IN_KV_CACHE,
)
elif SYSTEM == "rocm":
from .rocm import (
@ -18,6 +19,7 @@ elif SYSTEM == "rocm":
paged_attention,
reshape_and_cache,
SUPPORTS_WINDOWING,
PREFILL_IN_KV_CACHE,
)
elif SYSTEM == "ipex":
from .ipex import (
@ -25,6 +27,7 @@ elif SYSTEM == "ipex":
paged_attention,
reshape_and_cache,
SUPPORTS_WINDOWING,
PREFILL_IN_KV_CACHE,
)
else:
raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")
@ -35,5 +38,6 @@ __all__ = [
"paged_attention",
"reshape_and_cache",
"SUPPORTS_WINDOWING",
"PREFILL_IN_KV_CACHE",
"Seqlen",
]