From 47e522a66a6574f74da7eed83b7df9060d086486 Mon Sep 17 00:00:00 2001
From: fxmarty <9808326+fxmarty@users.noreply.github.com>
Date: Fri, 19 Apr 2024 10:11:39 +0000
Subject: [PATCH] wip fa2 triton & fix cudagraph bug

---
 launcher/src/main.rs                          |  1 +
 .../models/flash_causal_lm.py                 |  7 +-
 .../text_generation_server/models/globals.py  |  5 +-
 server/text_generation_server/models/mamba.py |  2 +
 .../utils/flash_attn.py                       | 66 +++++++++++--------
 5 files changed, 52 insertions(+), 29 deletions(-)

diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index d904f91b..40e7364f 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -1379,6 +1379,7 @@ fn main() -> Result<(), LauncherError> {
 
     let cuda_graphs = match (&args.cuda_graphs, &args.quantize) {
         (Some(cuda_graphs), Some(_q)) => cuda_graphs.clone(),
+        (Some(cuda_graphs), None) => cuda_graphs.clone(),
         #[allow(deprecated)]
         (
             None,
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index e474f9d6..9a3db958 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -13,6 +13,7 @@ from opentelemetry import trace
 from transformers import PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type, Dict
 
+from text_generation_server.utils.import_utils import IS_ROCM_SYSTEM
 from text_generation_server.models import Model
 from text_generation_server.utils.tokens import batch_top_tokens
 from text_generation_server.utils.speculate import get_speculate
@@ -807,7 +808,6 @@ class FlashCausalLM(Model):
             self.device,
         )
 
-        logger.info("CUDA_GRAPHS", CUDA_GRAPHS)
         if CUDA_GRAPHS:
             try:
                 logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}")
@@ -817,8 +817,11 @@ class FlashCausalLM(Model):
                     self.cuda_graph_warmup(bs, max_s, max_bt)
             except torch.cuda.OutOfMemoryError:
                 logger.exception(f"Decode cuda graph warmup failed")
+        else:
+            logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")
 
-        if IS_ROCM_SYSTEM and TUNABLEOP:
+        # TODO: fix
+        if IS_ROCM_SYSTEM and False:
             total_seqlens = list(range(16))
             for seqlen in total_seqlens:
                 self.tunableop_warmup(seqlen, max_s, max_bt)
diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py
index 6f554049..91b4225a 100644
--- a/server/text_generation_server/models/globals.py
+++ b/server/text_generation_server/models/globals.py
@@ -4,11 +4,14 @@ import os
 MEM_POOL = torch.cuda.graph_pool_handle()
 # This is overridden by the cli
 cuda_graphs = os.getenv("CUDA_GRAPHS")
-if cuda_graphs is not None:
+if cuda_graphs is not None and cuda_graphs != "0":
     try:
         cuda_graphs = [int(item) for item in cuda_graphs.split(",")]
     except Exception as e:
         raise RuntimeError(
             f"Could not parse cuda graphs {cuda_graphs}, expected comma separated list for batch sizes to run on: {e}"
         )
+else:
+    cuda_graphs = None
+
 CUDA_GRAPHS = cuda_graphs
diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py
index 07a81491..2aec4f95 100644
--- a/server/text_generation_server/models/mamba.py
+++ b/server/text_generation_server/models/mamba.py
@@ -474,6 +474,8 @@ class Mamba(Model):
                     self.cuda_graph_warmup(bs)
             except Exception:
                 logger.exception(f"Decode cuda graph warmup failed")
+        else:
+            logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS='{CUDA_GRAPHS}').")
 
         return None
 
diff --git a/server/text_generation_server/utils/flash_attn.py b/server/text_generation_server/utils/flash_attn.py
index 94073e7d..87a9311b 100644
--- a/server/text_generation_server/utils/flash_attn.py
+++ b/server/text_generation_server/utils/flash_attn.py
@@ -20,6 +20,18 @@ is_sm94 = major == 9 and minor == 4
 HAS_FLASH_ATTN = False
 HAS_FLASH_ATTN_V2_CUDA = False
 HAS_FLASH_ATTN_V2_ROCM = False
+
+ROCM_USE_FLASH_ATTN_V2_CK = False
+ROCM_USE_FLASH_ATTN_V2_TRITON = False
+
+if IS_ROCM_SYSTEM:
+    if os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() == "true":
+        ROCM_USE_FLASH_ATTN_V2_TRITON = True
+        logger.info("ROCm: using Flash Attention 2 Triton implementation.")
+    else:
+        ROCM_USE_FLASH_ATTN_V2_CK = True
+        logger.info("ROCm: using Flash Attention 2 Composable Kernel implementation.")
+
 try:
     try:
         import flash_attn_2_cuda
@@ -86,7 +98,7 @@ def attention(
     if window_size_left <= 0 and window_size_left != -1:
         raise ValueError("`window_size_left` must be > 0 or -1")
 
-    if HAS_FLASH_ATTN_V2_CUDA:
+    if IS_CUDA_SYSTEM and HAS_FLASH_ATTN_V2_CUDA:
         return flash_attn_2_cuda.varlen_fwd(
             q,
             k,
@@ -108,30 +120,7 @@ def attention(
             False,
             None,
         )
-    elif HAS_FLASH_ATTN_V2_ROCM:
-        if window_size_left != -1:
-            raise ValueError(
-                f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
-            )
-
-        # RoCm flash API does not take the window_size_left and window_size_right arguments.
-        return flash_attn_2_cuda.varlen_fwd(
-            q,
-            k,
-            v,
-            out,
-            cu_seqlens,
-            cu_seqlens,
-            max_s,
-            max_s,
-            0.0,
-            softmax_scale,
-            False,
-            True,
-            False,
-            None,
-        )
-    elif HAS_FLASH_ATTN:
+    elif IS_CUDA_SYSTEM and HAS_FLASH_ATTN:
         if window_size_left != -1:
             raise NotImplementedError(
                 "window_size_left is only available with flash attn v2"
@@ -180,5 +169,30 @@ def attention(
             0,
             None,
         )
+    elif IS_ROCM_SYSTEM and HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_CK:
+        if window_size_left != -1:
+            raise ValueError(
+                f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
+            )
 
-    raise NotImplementedError("flash attention is not installed")
+        # RoCm flash API does not take the window_size_left and window_size_right arguments.
+        return flash_attn_2_cuda.varlen_fwd(
+            q,
+            k,
+            v,
+            out,
+            cu_seqlens,
+            cu_seqlens,
+            max_s,
+            max_s,
+            0.0,
+            softmax_scale,
+            False,
+            True,
+            False,
+            None,
+        )
+    elif IS_ROCM_SYSTEM and ROCM_USE_FLASH_ATTN_V2_TRITON:
+        raise NotImplementedError("TODO")
+    else:
+        raise NotImplementedError(f"Flash attention is not installed (IS_CUDA_SYSTEM={IS_CUDA_SYSTEM}, IS_ROCM_SYSTEM={IS_ROCM_SYSTEM}, HAS_FLASH_ATTN_V2_CUDA={HAS_FLASH_ATTN_V2_CUDA}, HAS_FLASH_ATTN_V2_ROCM={HAS_FLASH_ATTN_V2_ROCM})")
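
Note (not part of the patch): the globals.py hunk changes how the CUDA_GRAPHS environment variable is interpreted, so that an unset variable and the literal value "0" both disable cuda graph capture, while any other value must be a comma-separated list of batch sizes. A minimal standalone sketch of that parsing behaviour; parse_cuda_graphs is a hypothetical helper used here only for illustration, not a function defined in the patch.

# Illustrative only -- mirrors the logic added to
# server/text_generation_server/models/globals.py in the hunk above.
def parse_cuda_graphs(value):
    # Unset or "0" now both disable cuda graphs; anything else must be a
    # comma-separated list of batch sizes to capture graphs for.
    if value is not None and value != "0":
        try:
            return [int(item) for item in value.split(",")]
        except Exception as e:
            raise RuntimeError(
                f"Could not parse cuda graphs {value}, expected comma separated list for batch sizes to run on: {e}"
            )
    return None

assert parse_cuda_graphs(None) is None
assert parse_cuda_graphs("0") is None
assert parse_cuda_graphs("1,2,4,8") == [1, 2, 4, 8]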
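
Note (not part of the patch): taken together, the flash_attn.py hunks reorder the attention() dispatch so that every branch is gated on the platform (IS_CUDA_SYSTEM / IS_ROCM_SYSTEM) as well as on which kernel is available. A condensed sketch of the resulting selection order, assuming the flag arguments carry the same meaning as the module-level globals in the patch; select_backend itself is hypothetical and only summarises the branches.

# Illustrative only -- summarises the branch order of attention() after this patch.
def select_backend(
    is_cuda_system,
    is_rocm_system,
    has_flash_attn_v2_cuda,
    has_flash_attn_v2_rocm,
    has_flash_attn,
    rocm_use_flash_attn_v2_ck,
    rocm_use_flash_attn_v2_triton,
):
    if is_cuda_system and has_flash_attn_v2_cuda:
        return "CUDA: flash-attn v2"
    elif is_cuda_system and has_flash_attn:
        return "CUDA: flash-attn v1 (no window attention)"
    elif is_rocm_system and has_flash_attn_v2_rocm and rocm_use_flash_attn_v2_ck:
        return "ROCm: flash-attn v2, Composable Kernel (no window attention)"
    elif is_rocm_system and rocm_use_flash_attn_v2_triton:
        # Still a stub in this patch: attention() raises NotImplementedError("TODO").
        return "ROCm: flash-attn v2, Triton (not implemented yet)"
    else:
        raise NotImplementedError("Flash attention is not installed")

# Example: a ROCm machine with the CK kernels built and the Triton path not
# opted into via ROCM_USE_FLASH_ATTN_V2_TRITON=true.
print(select_backend(False, True, False, True, False, True, False))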