wip fa2 triton & fix cudagraph bug

fxmarty 2024-04-19 10:11:39 +00:00
parent b503b3de60
commit 47e522a66a
5 changed files with 52 additions and 29 deletions

@@ -1379,6 +1379,7 @@ fn main() -> Result<(), LauncherError> {
let cuda_graphs = match (&args.cuda_graphs, &args.quantize) {
(Some(cuda_graphs), Some(_q)) => cuda_graphs.clone(),
(Some(cuda_graphs), None) => cuda_graphs.clone(),
#[allow(deprecated)]
(
None,

@@ -13,6 +13,7 @@ from opentelemetry import trace
from transformers import PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Dict
from text_generation_server.utils.import_utils import IS_ROCM_SYSTEM
from text_generation_server.models import Model
from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.utils.speculate import get_speculate
@@ -807,7 +808,6 @@ class FlashCausalLM(Model):
self.device,
)
logger.info("CUDA_GRAPHS", CUDA_GRAPHS)
if CUDA_GRAPHS:
try:
logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}")
@@ -817,8 +817,11 @@ class FlashCausalLM(Model):
self.cuda_graph_warmup(bs, max_s, max_bt)
except torch.cuda.OutOfMemoryError:
logger.exception(f"Decode cuda graph warmup failed")
else:
logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")
if IS_ROCM_SYSTEM and TUNABLEOP:
# TODO: fix
if IS_ROCM_SYSTEM and False:
total_seqlens = list(range(16))
for seqlen in total_seqlens:
self.tunableop_warmup(seqlen, max_s, max_bt)
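For context, the warmup guarded above relies on PyTorch's CUDA graph capture/replay. Below is a minimal, standalone sketch of that pattern, not TGI's actual cuda_graph_warmup: decode_step is a hypothetical stand-in for the decode forward pass, and a CUDA-capable torch build is assumed.

import torch

def decode_step(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical placeholder for the real decode forward pass.
    return x * 2

static_input = torch.zeros(4, 8, device="cuda")

# A few eager warmup iterations on a side stream, as recommended before capture.
side = torch.cuda.Stream()
side.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side):
    for _ in range(3):
        decode_step(static_input)
torch.cuda.current_stream().wait_stream(side)

# Capture the decode step once; later, copy new data into the static input
# buffer and replay. Capture can fail with OOM, hence the try/except above.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output = decode_step(static_input)

static_input.copy_(torch.randn(4, 8, device="cuda"))
graph.replay()
print(static_output)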

@@ -4,11 +4,14 @@ import os
MEM_POOL = torch.cuda.graph_pool_handle()
# This is overridden by the cli
cuda_graphs = os.getenv("CUDA_GRAPHS")
if cuda_graphs is not None:
if cuda_graphs is not None and cuda_graphs != "0":
try:
cuda_graphs = [int(item) for item in cuda_graphs.split(",")]
except Exception as e:
raise RuntimeError(
f"Could not parse cuda graphs {cuda_graphs}, expected comma separated list for batch sizes to run on: {e}"
)
else:
cuda_graphs = None
CUDA_GRAPHS = cuda_graphs
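As a quick illustration of the parsing behaviour above, here is a standalone sketch (a hypothetical parse_cuda_graphs helper that mirrors the module-level logic rather than importing it): an unset value or "0" disables CUDA graphs, while a comma-separated list gives the batch sizes to capture.

def parse_cuda_graphs(value):
    # Unset or "0" means CUDA graphs are disabled (None).
    if value is None or value == "0":
        return None
    try:
        # Otherwise the value is a comma-separated list of batch sizes.
        return [int(item) for item in value.split(",")]
    except Exception as e:
        raise RuntimeError(
            f"Could not parse cuda graphs {value}, expected comma separated list for batch sizes to run on: {e}"
        )

assert parse_cuda_graphs(None) is None
assert parse_cuda_graphs("0") is None
assert parse_cuda_graphs("1,2,4,8") == [1, 2, 4, 8]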

@@ -474,6 +474,8 @@ class Mamba(Model):
self.cuda_graph_warmup(bs)
except Exception:
logger.exception(f"Decode cuda graph warmup failed")
else:
logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS='{CUDA_GRAPHS}').")
return None

@@ -20,6 +20,18 @@ is_sm94 = major == 9 and minor == 4
HAS_FLASH_ATTN = False
HAS_FLASH_ATTN_V2_CUDA = False
HAS_FLASH_ATTN_V2_ROCM = False
ROCM_USE_FLASH_ATTN_V2_CK = False
ROCM_USE_FLASH_ATTN_V2_TRITON = False
if IS_ROCM_SYSTEM:
if os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() == "true":
ROCM_USE_FLASH_ATTN_V2_TRITON = True
logger.info("ROCm: using Flash Attention 2 Triton implementaion.")
else:
ROCM_USE_FLASH_ATTN_V2_CK = True
logger.info("ROCm: using Flash Attention 2 Composable Kernel implementaion.")
try:
try:
import flash_attn_2_cuda
@@ -86,7 +98,7 @@ def attention(
if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1")
if HAS_FLASH_ATTN_V2_CUDA:
if IS_CUDA_SYSTEM and HAS_FLASH_ATTN_V2_CUDA:
return flash_attn_2_cuda.varlen_fwd(
q,
k,
@@ -108,30 +120,7 @@ def attention(
False,
None,
)
elif HAS_FLASH_ATTN_V2_ROCM:
if window_size_left != -1:
raise ValueError(
f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
)
# RoCm flash API does not take the window_size_left and window_size_right arguments.
return flash_attn_2_cuda.varlen_fwd(
q,
k,
v,
out,
cu_seqlens,
cu_seqlens,
max_s,
max_s,
0.0,
softmax_scale,
False,
True,
False,
None,
)
elif HAS_FLASH_ATTN:
elif IS_CUDA_SYSTEM and HAS_FLASH_ATTN:
if window_size_left != -1:
raise NotImplementedError(
"window_size_left is only available with flash attn v2"
@@ -180,5 +169,30 @@ def attention(
0,
None,
)
elif IS_ROCM_SYSTEM and HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_CK:
if window_size_left != -1:
raise ValueError(
f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
)
raise NotImplementedError("flash attention is not installed")
# RoCm flash API does not take the window_size_left and window_size_right arguments.
return flash_attn_2_cuda.varlen_fwd(
q,
k,
v,
out,
cu_seqlens,
cu_seqlens,
max_s,
max_s,
0.0,
softmax_scale,
False,
True,
False,
None,
)
elif IS_ROCM_SYSTEM and ROCM_USE_FLASH_ATTN_V2_TRITON:
raise NotImplementedError("TODO")
else:
raise NotImplementedError(f"Flash attention is not installed (IS_CUDA_SYSTEM={IS_CUDA_SYSTEM}, IS_ROCM_SYSTEM={IS_ROCM_SYSTEM}, HAS_FLASH_ATTN_V2_CUDA={HAS_FLASH_ATTN_V2_CUDA}, HAS_FLASH_ATTN_V2_ROCM={HAS_FLASH_ATTN_V2_ROCM})")