From 8ec57558cd5b7b2ad3eaacdd6295a3db0c9092d0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?=
Date: Thu, 17 Oct 2024 14:54:22 +0200
Subject: [PATCH] Break cycle between the attention implementations and KV cache (#2627)

---
 .../layers/attention/__init__.py |  4 --
 .../layers/attention/cuda.py     | 25 ------------
 .../layers/attention/ipex.py     | 13 -------
 .../layers/attention/kv_cache.py | 39 ++++++++++++++++++-
 .../layers/attention/rocm.py     | 24 ------------
 5 files changed, 37 insertions(+), 68 deletions(-)

diff --git a/server/text_generation_server/layers/attention/__init__.py b/server/text_generation_server/layers/attention/__init__.py
index b7ca36f1..b1d7b864 100644
--- a/server/text_generation_server/layers/attention/__init__.py
+++ b/server/text_generation_server/layers/attention/__init__.py
@@ -11,21 +11,18 @@ if SYSTEM == "cuda":
         SUPPORTS_WINDOWING,
         attention,
         paged_attention,
-        reshape_and_cache,
     )
 elif SYSTEM == "rocm":
     from .rocm import (
         SUPPORTS_WINDOWING,
         attention,
         paged_attention,
-        reshape_and_cache,
     )
 elif SYSTEM == "ipex":
     from .ipex import (
         SUPPORTS_WINDOWING,
         attention,
         paged_attention,
-        reshape_and_cache,
     )
 else:
     raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")
@@ -36,7 +33,6 @@ from .kv_cache import KVCache
 __all__ = [
     "attention",
     "paged_attention",
-    "reshape_and_cache",
     "SUPPORTS_WINDOWING",
     "KVCache",
     "Seqlen",
diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py
index 5846bfe5..23f3404c 100644
--- a/server/text_generation_server/layers/attention/cuda.py
+++ b/server/text_generation_server/layers/attention/cuda.py
@@ -12,30 +12,6 @@ major, minor = torch.cuda.get_device_capability()
 is_sm75 = major == 7 and minor == 5
 _PARTITION_SIZE = 512
 
-try:
-    from vllm._C import cache_ops
-except Exception as e:
-    raise ImportError(
-        f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
-    )
-
-
-def reshape_and_cache(
-    key: torch.Tensor,
-    value: torch.Tensor,
-    key_cache: torch.Tensor,
-    value_cache: torch.Tensor,
-    slots: torch.Tensor,
-):
-    if ATTENTION in {"flashdecoding", "flashinfer"}:
-        shape = key_cache.shape
-        key_cache.view(-1, shape[-2], shape[-1])[slots] = key
-        value_cache.view(-1, shape[-2], shape[-1])[slots] = value
-    else:
-        cache_ops.reshape_and_cache(
-            key, value, key_cache, value_cache, slots, "auto", 1.0
-        )
-
 
 def paged_attention(
     query: torch.Tensor,
@@ -346,5 +322,4 @@ __all__ = [
     "SUPPORTS_WINDOWING",
     "attention",
     "paged_attention",
-    "reshape_and_cache",
 ]
diff --git a/server/text_generation_server/layers/attention/ipex.py b/server/text_generation_server/layers/attention/ipex.py
index 5d159796..e76bb1f4 100644
--- a/server/text_generation_server/layers/attention/ipex.py
+++ b/server/text_generation_server/layers/attention/ipex.py
@@ -47,18 +47,6 @@ def attention(
     return out
 
 
-def reshape_and_cache(
-    key: torch.Tensor,
-    value: torch.Tensor,
-    key_cache: torch.Tensor,
-    value_cache: torch.Tensor,
-    slots: torch.Tensor,
-):
-    ipex.llm.modules.PagedAttention.reshape_and_cache(
-        key, value, key_cache, value_cache, slots
-    )
-
-
 def paged_attention(
     query: torch.Tensor,
     kv_cache: KVCache,
@@ -94,5 +82,4 @@ __all__ = [
     "SUPPORTS_WINDOWING",
     "attention",
     "paged_attention",
-    "reshape_and_cache",
 ]
diff --git a/server/text_generation_server/layers/attention/kv_cache.py b/server/text_generation_server/layers/attention/kv_cache.py
index e6091a5f..d64302c6 100644
--- a/server/text_generation_server/layers/attention/kv_cache.py
+++ b/server/text_generation_server/layers/attention/kv_cache.py
@@ -115,6 +115,41 @@ class KVCache:
             key_cache.view(-1, shape[-2], shape[-1])[slots] = key
             value_cache.view(-1, shape[-2], shape[-1])[slots] = value
         else:
-            from text_generation_server.layers.attention import reshape_and_cache
+            paged_reshape_and_cache(key, value, key_cache, value_cache, slots)
 
-            reshape_and_cache(key, value, key_cache, value_cache, slots)
+
+def paged_reshape_and_cache(
+    key: torch.Tensor,
+    value: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    slots: torch.Tensor,
+):
+    if SYSTEM == "cuda":
+        try:
+            from vllm._C import cache_ops
+        except Exception as e:
+            raise ImportError(
+                f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
+            )
+        cache_ops.reshape_and_cache(
+            key, value, key_cache, value_cache, slots, "auto", 1.0
+        )
+    elif SYSTEM == "rocm":
+        try:
+            import vllm._custom_ops as ops
+        except Exception as e:
+            raise ImportError(
+                f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
+            )
+        ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
+    elif SYSTEM == "ipex":
+        import intel_extension_for_pytorch as ipex
+
+        ipex.llm.modules.PagedAttention.reshape_and_cache(
+            key, value, key_cache, value_cache, slots
+        )
+    else:
+        raise NotImplementedError(
+            f"Cannot reshape and cache for paged attention, system '{SYSTEM}' not supported"
+        )
diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py
index 986b16e8..47bf5539 100644
--- a/server/text_generation_server/layers/attention/rocm.py
+++ b/server/text_generation_server/layers/attention/rocm.py
@@ -3,7 +3,6 @@ from typing import Optional
 import torch
 from text_generation_server.layers.attention.kv_cache import KVCache
 from text_generation_server.utils.import_utils import SYSTEM
-from text_generation_server.models.globals import ATTENTION
 from text_generation_server.layers.attention import Seqlen
 from text_generation_server.utils.log import log_master
 from loguru import logger
@@ -28,28 +27,6 @@ except ImportError as e:
     )
     use_rocm_custom_paged_attn = False
 
-try:
-    import vllm._custom_ops as ops
-except Exception as e:
-    raise ImportError(
-        f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
-    )
-
-
-def reshape_and_cache(
-    key: torch.Tensor,
-    value: torch.Tensor,
-    key_cache: torch.Tensor,
-    value_cache: torch.Tensor,
-    slots: torch.Tensor,
-):
-    if ATTENTION == "flashdecoding":
-        shape = key_cache.shape
-        key_cache.view(-1, shape[-2], shape[-1])[slots] = key
-        value_cache.view(-1, shape[-2], shape[-1])[slots] = value
-    else:
-        ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
-
 
 def paged_attention(
     query: torch.Tensor,
@@ -305,5 +282,4 @@ __all__ = [
     "SUPPORTS_WINDOWING",
     "attention",
     "paged_attention",
-    "reshape_and_cache",
 ]