diff --git a/server/text_generation_server/layers/attention/__init__.py b/server/text_generation_server/layers/attention/__init__.py index 2134d857..a2f97700 100644 --- a/server/text_generation_server/layers/attention/__init__.py +++ b/server/text_generation_server/layers/attention/__init__.py @@ -11,6 +11,7 @@ if SYSTEM == "cuda": paged_attention, reshape_and_cache, SUPPORTS_WINDOWING, + PREFILL_IN_KV_CACHE, ) elif SYSTEM == "rocm": from .rocm import ( @@ -18,6 +19,7 @@ elif SYSTEM == "rocm": paged_attention, reshape_and_cache, SUPPORTS_WINDOWING, + PREFILL_IN_KV_CACHE, ) elif SYSTEM == "ipex": from .ipex import ( @@ -25,6 +27,7 @@ elif SYSTEM == "ipex": paged_attention, reshape_and_cache, SUPPORTS_WINDOWING, + PREFILL_IN_KV_CACHE, ) else: raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention") @@ -35,5 +38,6 @@ __all__ = [ "paged_attention", "reshape_and_cache", "SUPPORTS_WINDOWING", + "PREFILL_IN_KV_CACHE", "Seqlen", ]