Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-04-23 16:02:10 +00:00
wip fa2 triton & fix cudagraph bug
parent b503b3de60
commit 47e522a66a
@@ -1379,6 +1379,7 @@ fn main() -> Result<(), LauncherError> {
     let cuda_graphs = match (&args.cuda_graphs, &args.quantize) {
         (Some(cuda_graphs), Some(_q)) => cuda_graphs.clone(),
+        (Some(cuda_graphs), None) => cuda_graphs.clone(),
         #[allow(deprecated)]
         (
             None,
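For readers skimming the launcher change: the new match arm means an explicit --cuda-graphs value is now honored even when no quantization method is set, instead of falling through to the default arms. A rough Python rendering of that decision, as a sketch only (the function and parameter names are illustrative; the real code is the Rust match above, with further arms elided in this hunk):

def resolve_cuda_graphs(cli_cuda_graphs, quantize, default_cuda_graphs):
    # Sketch only. cli_cuda_graphs: batch sizes from --cuda-graphs, or None;
    # quantize: the selected quantization method, or None.
    if cli_cuda_graphs is not None:
        # With the added arm, the explicit CLI value wins whether or not
        # quantization is enabled.
        return list(cli_cuda_graphs)
    # Remaining arms (defaults, deprecated quantizers) are not shown in the hunk.
    return default_cuda_graphs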
@@ -13,6 +13,7 @@ from opentelemetry import trace
 from transformers import PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type, Dict
 
+from text_generation_server.utils.import_utils import IS_ROCM_SYSTEM
 from text_generation_server.models import Model
 from text_generation_server.utils.tokens import batch_top_tokens
 from text_generation_server.utils.speculate import get_speculate
@@ -807,7 +808,6 @@ class FlashCausalLM(Model):
             self.device,
         )
 
-        logger.info("CUDA_GRAPHS", CUDA_GRAPHS)
         if CUDA_GRAPHS:
             try:
                 logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}")
@@ -817,8 +817,11 @@ class FlashCausalLM(Model):
                     self.cuda_graph_warmup(bs, max_s, max_bt)
             except torch.cuda.OutOfMemoryError:
                 logger.exception(f"Decode cuda graph warmup failed")
+        else:
+            logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")
 
-        if IS_ROCM_SYSTEM and TUNABLEOP:
+        # TODO: fix
+        if IS_ROCM_SYSTEM and False:
             total_seqlens = list(range(16))
             for seqlen in total_seqlens:
                 self.tunableop_warmup(seqlen, max_s, max_bt)
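Putting the two FlashCausalLM hunks together, the warmup path now looks roughly like this standalone sketch (warmup_fn stands in for self.cuda_graph_warmup, and the logger import assumes loguru; both are illustrative, not part of the diff):

import torch
from loguru import logger  # assumption: the server logs via loguru

def warmup_cuda_graphs(cuda_graphs, warmup_fn):
    # Sketch of the control flow above: warm up one graph per configured batch
    # size, survive OOM during warmup, and say explicitly when graphs are off.
    if cuda_graphs:
        try:
            logger.info(f"Cuda Graphs are enabled for sizes {cuda_graphs}")
            for bs in cuda_graphs:
                warmup_fn(bs)
        except torch.cuda.OutOfMemoryError:
            logger.exception("Decode cuda graph warmup failed")
    else:
        logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={cuda_graphs}).")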
@@ -4,11 +4,14 @@ import os
 MEM_POOL = torch.cuda.graph_pool_handle()
 # This is overridden by the cli
 cuda_graphs = os.getenv("CUDA_GRAPHS")
-if cuda_graphs is not None:
+if cuda_graphs is not None and cuda_graphs != "0":
     try:
         cuda_graphs = [int(item) for item in cuda_graphs.split(",")]
     except Exception as e:
         raise RuntimeError(
             f"Could not parse cuda graphs {cuda_graphs}, expected comma separated list for batch sizes to run on: {e}"
         )
+else:
+    cuda_graphs = None
+
 CUDA_GRAPHS = cuda_graphs
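The practical effect of the new guard: CUDA_GRAPHS=0 now disables graphs (previously it parsed to [0]), and an unset variable still resolves to None. A self-contained sketch of the same parsing, with a hypothetical parse_cuda_graphs helper purely for illustration:

def parse_cuda_graphs(value):
    # Hypothetical helper mirroring the module-level logic above: unset or "0"
    # disables CUDA graphs, otherwise the value must be a comma separated list
    # of batch sizes.
    if value is not None and value != "0":
        try:
            return [int(item) for item in value.split(",")]
        except Exception as e:
            raise RuntimeError(
                f"Could not parse cuda graphs {value}, expected comma separated "
                f"list for batch sizes to run on: {e}"
            )
    return None

assert parse_cuda_graphs(None) is None           # CUDA_GRAPHS unset
assert parse_cuda_graphs("0") is None            # CUDA_GRAPHS=0 now means "disabled"
assert parse_cuda_graphs("1,2,4") == [1, 2, 4]   # CUDA_GRAPHS=1,2,4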
@@ -474,6 +474,8 @@ class Mamba(Model):
                     self.cuda_graph_warmup(bs)
             except Exception:
                 logger.exception(f"Decode cuda graph warmup failed")
+        else:
+            logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS='{CUDA_GRAPHS}').")
 
         return None
 
@@ -20,6 +20,18 @@ is_sm94 = major == 9 and minor == 4
 HAS_FLASH_ATTN = False
 HAS_FLASH_ATTN_V2_CUDA = False
 HAS_FLASH_ATTN_V2_ROCM = False
 
+ROCM_USE_FLASH_ATTN_V2_CK = False
+ROCM_USE_FLASH_ATTN_V2_TRITON = False
+
+if IS_ROCM_SYSTEM:
+    if os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() == "true":
+        ROCM_USE_FLASH_ATTN_V2_TRITON = True
+        logger.info("ROCm: using Flash Attention 2 Triton implementation.")
+    else:
+        ROCM_USE_FLASH_ATTN_V2_CK = True
+        logger.info("ROCm: using Flash Attention 2 Composable Kernel implementation.")
+
 try:
     try:
         import flash_attn_2_cuda
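A minimal sketch of the new ROCm backend toggle, factored into a function purely for illustration (the real code sets module-level flags at import time, as shown above):

import os

def select_rocm_fa2_backend(environ=None):
    # On ROCm, ROCM_USE_FLASH_ATTN_V2_TRITON=true opts into the work-in-progress
    # Triton kernels; any other value keeps the Composable Kernel
    # flash_attn_2_cuda path.
    environ = os.environ if environ is None else environ
    use_triton = environ.get("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() == "true"
    return "triton" if use_triton else "ck"

print(select_rocm_fa2_backend({"ROCM_USE_FLASH_ATTN_V2_TRITON": "true"}))  # triton
print(select_rocm_fa2_backend({}))                                         # ck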
@@ -86,7 +98,7 @@ def attention(
     if window_size_left <= 0 and window_size_left != -1:
         raise ValueError("`window_size_left` must be > 0 or -1")
 
-    if HAS_FLASH_ATTN_V2_CUDA:
+    if IS_CUDA_SYSTEM and HAS_FLASH_ATTN_V2_CUDA:
         return flash_attn_2_cuda.varlen_fwd(
             q,
             k,
@@ -108,30 +120,7 @@ def attention(
             False,
             None,
         )
-    elif HAS_FLASH_ATTN_V2_ROCM:
-        if window_size_left != -1:
-            raise ValueError(
-                f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
-            )
-
-        # RoCm flash API does not take the window_size_left and window_size_right arguments.
-        return flash_attn_2_cuda.varlen_fwd(
-            q,
-            k,
-            v,
-            out,
-            cu_seqlens,
-            cu_seqlens,
-            max_s,
-            max_s,
-            0.0,
-            softmax_scale,
-            False,
-            True,
-            False,
-            None,
-        )
-    elif HAS_FLASH_ATTN:
+    elif IS_CUDA_SYSTEM and HAS_FLASH_ATTN:
         if window_size_left != -1:
             raise NotImplementedError(
                 "window_size_left is only available with flash attn v2"
@@ -180,5 +169,30 @@ def attention(
             0,
             None,
         )
+    elif IS_ROCM_SYSTEM and HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_CK:
+        if window_size_left != -1:
+            raise ValueError(
+                f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
+            )
 
-    raise NotImplementedError("flash attention is not installed")
+        # RoCm flash API does not take the window_size_left and window_size_right arguments.
+        return flash_attn_2_cuda.varlen_fwd(
+            q,
+            k,
+            v,
+            out,
+            cu_seqlens,
+            cu_seqlens,
+            max_s,
+            max_s,
+            0.0,
+            softmax_scale,
+            False,
+            True,
+            False,
+            None,
+        )
+    elif IS_ROCM_SYSTEM and ROCM_USE_FLASH_ATTN_V2_TRITON:
+        raise NotImplementedError("TODO")
+    else:
+        raise NotImplementedError(f"Flash attention is not installed (IS_CUDA_SYSTEM={IS_CUDA_SYSTEM}, IS_ROCM_SYSTEM={IS_ROCM_SYSTEM}, HAS_FLASH_ATTN_V2_CUDA={HAS_FLASH_ATTN_V2_CUDA}, HAS_FLASH_ATTN_V2_ROCM={HAS_FLASH_ATTN_V2_ROCM})")
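Taken together, the branch order in attention() after this commit is: CUDA flash-attn v2, then CUDA flash-attn v1, then the ROCm Composable Kernel path, then the still unimplemented ROCm Triton path, then a descriptive error. A simplified, self-contained sketch of just that dispatch (flag names come from the diff but are passed as parameters here; the real function forwards q/k/v and calls varlen_fwd):

def pick_attention_backend(
    is_cuda_system,
    is_rocm_system,
    has_flash_attn_v2_cuda,
    has_flash_attn,
    has_flash_attn_v2_rocm,
    rocm_use_flash_attn_v2_ck,
    rocm_use_flash_attn_v2_triton,
):
    # Branch order only; tensor arguments and kernel calls are elided.
    if is_cuda_system and has_flash_attn_v2_cuda:
        return "flash-attn v2 (CUDA)"
    elif is_cuda_system and has_flash_attn:
        return "flash-attn v1 (CUDA)"
    elif is_rocm_system and has_flash_attn_v2_rocm and rocm_use_flash_attn_v2_ck:
        return "flash-attn v2 (ROCm, Composable Kernel)"
    elif is_rocm_system and rocm_use_flash_attn_v2_triton:
        raise NotImplementedError("TODO")  # Triton path is still wip in this commit
    else:
        raise NotImplementedError("Flash attention is not installed")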