"""Backend dispatch for flash/paged attention kernels.

Selects the attention implementation matching the detected hardware
platform (CUDA, ROCm, or Intel IPEX) and re-exports a uniform public
API (`attention`, `paged_attention`, `SUPPORTS_WINDOWING`, plus the
KV-cache helpers and `Seqlen`). Importing this module raises
``ImportError`` when flash attention is explicitly disabled or the
platform is unsupported, so callers can fall back cleanly.
"""

import os

from text_generation_server.utils.import_utils import SYSTEM

from .common import Seqlen

# Operator kill-switch: setting USE_FLASH_ATTENTION=false disables the
# whole module at import time (callers treat ImportError as "unavailable").
if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
    raise ImportError("`USE_FLASH_ATTENTION` is false.")

# Each backend module exposes the same three names; exactly one branch runs.
if SYSTEM == "cuda":
    from .cuda import (
        SUPPORTS_WINDOWING,
        attention,
        paged_attention,
    )
elif SYSTEM == "rocm":
    from .rocm import (
        SUPPORTS_WINDOWING,
        attention,
        paged_attention,
    )
elif SYSTEM == "ipex":
    from .ipex import (
        SUPPORTS_WINDOWING,
        attention,
        paged_attention,
    )
else:
    raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")

# KVCache needs `reshape_and_cache`, so ensure that it is defined already.
from .kv_cache import KVCache, get_kv_scales

# Explicit public API of this package.
__all__ = [
    "attention",
    "get_kv_scales",
    "paged_attention",
    "SUPPORTS_WINDOWING",
    "KVCache",
    "Seqlen",
]