diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py
index 2e2b7ba9..2d3601c8 100644
--- a/server/text_generation_server/layers/attention/rocm.py
+++ b/server/text_generation_server/layers/attention/rocm.py
@@ -9,7 +9,6 @@ _PARTITION_SIZE = 512
 
 use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"}
 ENGINE = "triton" if use_triton else "ck"
-from .flash_attn_triton import triton_attention
 
 try:
     from vllm._C import cache_ops
@@ -122,44 +121,42 @@ def paged_attention(
     )
 
 
-try:
-    import flash_attn_2_cuda
-
-    if ENGINE == "triton":
-        logger.info("ROCm: using Flash Attention 2 Triton implementation.")
-    elif ENGINE == "ck":
-        logger.info("ROCm: using Flash Attention 2 Composable Kernel implementation.")
-except ImportError:
+if ENGINE != "triton":
     try:
-        import flash_attn_cuda
+        import flash_attn_2_cuda
 
-        ENGINE = "v1"
-        logger.info("ROCm: using Flash Attention 1")
-    except ImportError as e:
-        if major >= 8:
-            architecture_suffix = f"-{SYSTEM}"
-            raise ImportError(
-                "Flash Attention V2 is not installed.\n"
-                "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
-                f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
-            )
-        elif is_sm75:
-            raise ImportError(
-                "Flash Attention is not installed.\n"
-                "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
-                "or install flash attention with `cd server && make install install-flash-attention`"
-            ) from e
-        else:
+        logger.info("ROCm: using Flash Attention 2 Composable Kernel implementation.")
+    except ImportError:
+        try:
+            import flash_attn_cuda
 
-            for idx in range(torch.cuda.device_count()):
-                name = torch.cuda.get_device_name(idx)
-                if "MI210" not in name and "MI250" not in name:
-                    raise ImportError(
-                        f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
-                    )
-            raise ImportError(
-                f"AMD GPU with Rocm capability {major} {minor} is not supported"
-            ) from e
+            ENGINE = "v1"
+            logger.info("ROCm: using Flash Attention 1")
+        except ImportError as e:
+            if major >= 8:
+                architecture_suffix = f"-{SYSTEM}"
+                raise ImportError(
+                    "Flash Attention V2 is not installed.\n"
+                    "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
+                    f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
+                )
+            elif is_sm75:
+                raise ImportError(
+                    "Flash Attention is not installed.\n"
+                    "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
+                    "or install flash attention with `cd server && make install install-flash-attention`"
+                ) from e
+            else:
+
+                for idx in range(torch.cuda.device_count()):
+                    name = torch.cuda.get_device_name(idx)
+                    if "MI210" not in name and "MI250" not in name:
+                        raise ImportError(
+                            f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
+                        )
+                raise ImportError(
+                    f"AMD GPU with ROCm capability {major} {minor} is not supported"
+                ) from e
 
 SUPPORTS_WINDOWING = ENGINE != "v1"
 
@@ -180,7 +177,7 @@ if ENGINE == "ck":
             raise ValueError("`window_size_left` must be > 0 or -1")
         if window_size_left != -1:
             raise ValueError(
-                f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
+                f"ROCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
             )
         return flash_attn_2_cuda.varlen_fwd(
             q,
@@ -205,6 +202,7 @@ if ENGINE == "ck":
         )
 
 elif ENGINE == "triton":
+    from .flash_attn_triton import triton_attention
 
     def attention(
         q,
diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
index ab6cf02a..d489c3ba 100644
--- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -198,9 +198,7 @@ class FlashRWAttention(torch.nn.Module):
         # Inplace rotary
         self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
 
-        paged_attention.reshape_and_cache(
-            kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
-        )
+        reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
 
         # output
         attn_output = torch.empty_like(query)
@@ -208,7 +206,7 @@ class FlashRWAttention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            flash_attn.attention(
+            attention(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
@@ -219,7 +217,7 @@ class FlashRWAttention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention.attention(
+            paged_attention(
                 attn_output,
                 query,
                 kv_cache[0],