Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-04-21 14:52:20 +00:00
feat: add ruff and resolve issue

* fix: update client exports and adjust after rebase
* fix: adjust syntax to avoid circular import
* fix: adjust client ruff settings
* fix: lint and refactor import check and avoid model enum as global names
* fix: improve fbgemm_gpu check and lints
* fix: update lints
* fix: prefer comparing model enum over str
* fix: adjust lints and ignore specific rules
* fix: avoid unneeded quantize check
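Two of the bullets above ("avoid model enum as global names", "prefer comparing model enum over str") describe a common lint-driven pattern. A minimal sketch of the idea, where `ModelType` and `uses_rope` are hypothetical stand-ins rather than names from the PR:

from enum import Enum


class ModelType(Enum):
    # Hypothetical stand-in for the repository's model enum.
    LLAMA = "llama"
    GPT2 = "gpt2"


def uses_rope(model_type: ModelType) -> bool:
    # Compare against the enum member rather than its raw string value:
    # a typo fails at attribute lookup, and renames are caught by static
    # analysis, neither of which holds for `model_type == "llama"`.
    return model_type == ModelType.LLAMA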
25 lines · 755 B · Python
from text_generation_server.utils.import_utils import SYSTEM

import os

from .common import Seqlen

# Flash/paged attention can be disabled explicitly; importing this package
# then fails fast instead of silently falling back to another code path.
if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
    raise ImportError("`USE_FLASH_ATTENTION` is false.")

# Select the attention kernels that match the detected accelerator.
if SYSTEM == "cuda":
    from .cuda import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
elif SYSTEM == "rocm":
    from .rocm import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
elif SYSTEM == "ipex":
    from .ipex import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
else:
    raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")

__all__ = [
    "attention",
    "paged_attention",
    "reshape_and_cache",
    "SUPPORTS_WINDOWING",
    "Seqlen",
]
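Because the backend is chosen at import time, a disabled or unsupported backend surfaces as an ImportError the first time this package is imported. A minimal caller-side sketch (not code from the repository) of catching that to fall back to a non-flash path; `HAS_FLASH_ATTENTION` is an illustrative name:

import os

# Hypothetical fallback pattern: the package raises ImportError when
# USE_FLASH_ATTENTION is "false" or when SYSTEM is not cuda/rocm/ipex.
os.environ.setdefault("USE_FLASH_ATTENTION", "true")
try:
    from text_generation_server.layers.attention import (
        attention,
        paged_attention,
        SUPPORTS_WINDOWING,
    )

    HAS_FLASH_ATTENTION = True
except ImportError:
    # e.g. "`USE_FLASH_ATTENTION` is false." or an unsupported system;
    # the caller can route to a plain attention implementation instead.
    HAS_FLASH_ATTENTION = False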