text-generation-inference/server/text_generation_server/layers/attention/__init__.py
Daniël de Kok 59ea38cbca
Simplify the attention function (#2609)
* Simplify the `attention` function

- Use one definition rather than multiple.
- Add `key`/`value` arguments, so that we don't need the
  `PREFILL_IN_KVCACHE` constant.
- Make it kwargs-only (to avoid mixing up the various `Tensor` args).

* Fixup flashinfer support
2024-10-17 10:42:52 +02:00
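
A minimal sketch of how a call to the reworked, kwargs-only `attention` function might look after this change. Only the explicit `key`/`value` arguments and the kwargs-only convention are stated in the commit message above; every other parameter name (`query`, `kv_cache`, `seqlen`, `softmax_scale`) is an assumption for illustration, not a confirmed signature.

# Hypothetical usage sketch of the unified, kwargs-only `attention` call.
# Parameter names other than `key`/`value` are assumptions for illustration.
attn_output = attention(
    query=query,
    key=key,                      # passed explicitly, so PREFILL_IN_KVCACHE is no longer needed
    value=value,
    kv_cache=kv_cache,            # assumed: the paged KV cache for this layer
    seqlen=seqlen,                # assumed: Seqlen metadata from .common
    softmax_scale=softmax_scale,  # assumed
)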


import os
from text_generation_server.utils.import_utils import SYSTEM
from .common import Seqlen

if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
    raise ImportError("`USE_FLASH_ATTENTION` is false.")

if SYSTEM == "cuda":
    from .cuda import (
        SUPPORTS_WINDOWING,
        attention,
        paged_attention,
        reshape_and_cache,
    )
elif SYSTEM == "rocm":
    from .rocm import (
        SUPPORTS_WINDOWING,
        attention,
        paged_attention,
        reshape_and_cache,
    )
elif SYSTEM == "ipex":
    from .ipex import (
        SUPPORTS_WINDOWING,
        attention,
        paged_attention,
        reshape_and_cache,
    )
else:
    raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")

# KVCache needs `reshape_and_cache`, so ensure that it is defined already.
from .kv_cache import KVCache

__all__ = [
    "attention",
    "paged_attention",
    "reshape_and_cache",
    "SUPPORTS_WINDOWING",
    "KVCache",
    "Seqlen",
]
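
As a usage sketch (assuming the package layout shown above), downstream model code would import the backend-selected symbols from this package rather than from a specific `.cuda`/`.rocm`/`.ipex` module, letting the SYSTEM check pick the implementation:

# Usage sketch: callers import from the attention package; the SYSTEM
# dispatch in __init__.py selects the cuda/rocm/ipex implementation.
from text_generation_server.layers.attention import (
    attention,
    paged_attention,
    KVCache,
    Seqlen,
    SUPPORTS_WINDOWING,
)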