text-generation-inference/server/text_generation_server/layers/attention/__init__.py
drbh bab02ff2bc
feat: add ruff and resolve issue (#2262)
* feat: add ruff and resolve issue

* fix: update client exports and adjust after rebase

* fix: adjust syntax to avoid circular import

* fix: adjust client ruff settings

* fix: lint and refactor import check and avoid model enum as global names

* fix: improve fbgemm_gpu check and lints

* fix: update lints

* fix: prefer comparing model enum over str

* fix: adjust lints and ignore specific rules

* fix: avoid unneeded quantize check
2024-07-26 10:29:09 -04:00

25 lines
755 B
Python

from text_generation_server.utils.import_utils import SYSTEM
import os
from .common import Seqlen
if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
raise ImportError("`USE_FLASH_ATTENTION` is false.")
if SYSTEM == "cuda":
from .cuda import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
elif SYSTEM == "rocm":
from .rocm import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
elif SYSTEM == "ipex":
from .ipex import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
else:
raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")
__all__ = [
"attention",
"paged_attention",
"reshape_and_cache",
"SUPPORTS_WINDOWING",
"Seqlen",
]