text-generation-inference/server/text_generation_server/layers/attention/__init__.py
Commit b64c70c9e7 ("Cpu tgi", #1936): add CPU TGI support and IPEX distributed ops support.
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Co-authored-by: Funtowicz Morgan <mfuntowicz@users.noreply.github.com>
2024-06-25 12:21:29 +02:00

import os

from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL

# Allow users to opt out of flash/paged attention entirely via env var.
if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
    raise ImportError("`USE_FLASH_ATTENTION` is false.")

# Dispatch to the backend matching the detected system: CUDA, ROCm, or an
# IPEX-capable device. Each backend module exports the same symbols, so
# callers stay backend-agnostic.
if SYSTEM == "cuda":
    from .cuda import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
elif SYSTEM == "rocm":
    from .rocm import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
elif IPEX_AVAIL:
    from .xpu import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
else:
    raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")
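
For context, a minimal sketch of how downstream code might consume this dispatcher. It assumes it runs inside the text_generation_server package; the per-backend call signatures of attention/paged_attention differ, so only the backend-agnostic import and the USE_FLASH_ATTENTION toggle are shown, and the soft-fallback behavior is an assumption, not something this module prescribes.

# Hypothetical consumer sketch: model code imports the re-exported
# symbols from this package and never touches .cuda/.rocm/.xpu directly.
import os

# Setting USE_FLASH_ATTENTION=false before import makes this package
# raise ImportError; a caller that wants a soft fallback (assumption)
# can catch it instead of letting the process die with a bare traceback.
os.environ.setdefault("USE_FLASH_ATTENTION", "true")

try:
    from text_generation_server.layers.attention import (
        attention,
        paged_attention,
        reshape_and_cache,
        SUPPORTS_WINDOWING,
    )
except ImportError as err:
    # Either no supported backend was detected or flash attention was
    # explicitly disabled; surface a clear message for the operator.
    raise RuntimeError(f"attention backend unavailable: {err}") from err

Because every backend module exports the same four names, adding a new backend only requires a new elif branch in this file; no call sites change.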