#!/usr/bin/env python
"""
Fused Attention
===============

This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao
(https://tridao.me/publications/flash2/flash2.pdf)
Credits: OpenAI kernel team, AMD ML Frameworks Triton team

Features supported:

1) Fwd with causal masking
2) Any sequence lengths without padding (currently fwd kernel only)
3) Support for different sequence lengths for q and k
4) Nested tensor API currently does not support dropout or bias.

Not currently supported:

1) Non power of two head dims

"""

import torch
import triton
import triton.language as tl

torch_dtype: tl.constexpr = torch.float16


@triton.jit
def cdiv_fn(x, y):
    return (x + y - 1) // y


@triton.jit
def max_fn(x, y):
    return tl.math.max(x, y)


@triton.jit
def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride):
    ms = tl.arange(0, m)
    ns = tl.arange(0, n)
    return philox_offset + ms[:, None] * stride + ns[None, :]


@triton.jit
def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride):
    rng_offsets = dropout_offsets(
        philox_seed, philox_offset, dropout_p, m, n, stride
    ).to(tl.uint32)
    # TODO: use tl.randint for better performance
    return tl.rand(philox_seed, rng_offsets)


@triton.jit
def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride):
    rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride)
    rng_keep = rng_output > dropout_p
    return rng_keep


@triton.jit
def load_fn(block_ptr, first, second, pad):
    if first and second:
        tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad)
    elif first:
        tensor = tl.load(block_ptr, boundary_check=(0,), padding_option=pad)
    elif second:
        tensor = tl.load(block_ptr, boundary_check=(1,), padding_option=pad)
    else:
        tensor = tl.load(block_ptr)
    return tensor
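

# Host-side reference for the dropout-mask semantics above. This is a sketch
# for illustration only: the kernel derives its randomness from Philox seed /
# offset pairs, while this uses torch.rand_like, so the two masks will not
# match bit-for-bit. Elements are kept where the uniform sample exceeds
# dropout_p:
#
#     def dropout_mask_ref(scores: torch.Tensor, dropout_p: float):
#         return torch.rand_like(scores, dtype=torch.float32) > dropout_p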


@triton.jit
def _attn_fwd_inner(
    acc,
    l_i,
    m_i,
    q,
    K_block_ptr,
    V_block_ptr,
    start_m,
    actual_seqlen_k,
    dropout_p,
    philox_seed,
    batch_philox_offset,
    encoded_softmax_block_ptr,
    block_min,
    block_max,
    offs_n_causal,
    masked_blocks,
    n_extra_tokens,
    bias_ptr,
    IS_CAUSAL: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    OFFS_M: tl.constexpr,
    OFFS_N: tl.constexpr,
    PRE_LOAD_V: tl.constexpr,
    MASK_STEPS: tl.constexpr,
    ENABLE_DROPOUT: tl.constexpr,
    RETURN_ENCODED_SOFTMAX: tl.constexpr,
    PADDED_HEAD: tl.constexpr,
):
    # loop over k, v, and update accumulator
    for start_n in range(block_min, block_max, BLOCK_N):
        # For padded blocks, we will overrun the tensor size if
        # we load all BLOCK_N. For others, the blocks are all within range.
        k = load_fn(
            K_block_ptr,
            PADDED_HEAD,
            MASK_STEPS and (n_extra_tokens != 0),
            "zero",
        )
        if PRE_LOAD_V:
            v = load_fn(
                V_block_ptr,
                MASK_STEPS and (n_extra_tokens != 0),
                PADDED_HEAD,
                "zero",
            )
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        # We start from end of seqlen_k so only the first iteration would need
        # to be checked for padding if it is not a multiple of block_n
        # TODO: This can be optimized to only be true for the padded block.
        if MASK_STEPS:  # noqa: SIM102
            # If this is the last block / iteration, we want to
            # mask if the sequence length is not a multiple of block size
            # a solution is to always do BLOCK_M // BLOCK_N + 1 steps
            # if not is_modulo_mn. last step might get wasted but that is okay.
            # check if this masking works for that case.
            if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0):
                boundary_m = tl.full([BLOCK_M], actual_seqlen_k, dtype=tl.int32)
                size_n = start_n + OFFS_N[None, :]
                mask = size_n < boundary_m[:, None]
                qk = tl.where(mask, qk, float("-inf"))
        if IS_CAUSAL:
            causal_boundary = start_n + offs_n_causal
            causal_mask = OFFS_M[:, None] >= causal_boundary[None, :]
            qk = tl.where(causal_mask, qk, float("-inf"))
        # -- compute qk ----
        qk += tl.dot(q, k)
        if bias_ptr is not None:
            bias = load_fn(
                bias_ptr, False, MASK_STEPS and (n_extra_tokens != 0), "zero"
            )
            # While bias is added after multiplying qk with sm_scale, our
            # optimization to use 2^x instead of e^x results in an additional
            # scale factor of log2(e) which we must also multiply the bias with.
            qk += bias * 1.44269504089
        m_ij = tl.maximum(m_i, tl.max(qk, 1))
        qk = qk - m_ij[:, None]
        p = tl.math.exp2(qk)

        # CAVEAT: Must update l_ij before applying dropout
        l_ij = tl.sum(p, 1)
        if ENABLE_DROPOUT:
            philox_offset = (
                batch_philox_offset
                + start_m * BLOCK_M * actual_seqlen_k
                + start_n
                - BLOCK_N
            )
            keep = dropout_mask(
                philox_seed,
                philox_offset,
                dropout_p,
                BLOCK_M,
                BLOCK_N,
                actual_seqlen_k,
            )
            if RETURN_ENCODED_SOFTMAX:
                tl.store(
                    encoded_softmax_block_ptr,
                    tl.where(keep, p, -p).to(encoded_softmax_block_ptr.type.element_ty),
                )
            p = tl.where(keep, p, 0.0)
        elif RETURN_ENCODED_SOFTMAX:
            tl.store(
                encoded_softmax_block_ptr,
                p.to(encoded_softmax_block_ptr.type.element_ty),
            )
        # -- update output accumulator --
        alpha = tl.math.exp2(m_i - m_ij)
        acc = acc * alpha[:, None]
        if not PRE_LOAD_V:
            v = load_fn(
                V_block_ptr,
                MASK_STEPS and (n_extra_tokens != 0),
                PADDED_HEAD,
                "zero",
            )
        # -- update m_i and l_i --
        l_i = l_i * alpha + l_ij
        m_i = m_ij
        acc += tl.dot(p.to(V_block_ptr.type.element_ty), v)
        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
        if bias_ptr is not None:
            bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N))
        if RETURN_ENCODED_SOFTMAX:
            encoded_softmax_block_ptr = tl.advance(
                encoded_softmax_block_ptr, (0, BLOCK_N)
            )
    return acc, l_i, m_i
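

# The loop above is the standard online-softmax recurrence, carried out in
# base 2 because q was pre-scaled by log2(e). With m_ij the new running row
# max of qk, each iteration computes
#
#     alpha = 2^(m_i - m_ij)              # rescales the stale accumulator
#     acc   = acc * alpha + p @ v         # where p = 2^(qk - m_ij)
#     l_i   = l_i * alpha + sum(p, 1)     # running softmax denominator
#     m_i   = m_ij
#
# so that acc / l_i at the end of the sweep equals softmax(qk) @ v.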


@triton.autotune(
    configs=[
        triton.Config(
            {
                "BLOCK_M": 256,
                "BLOCK_N": 64,
                "waves_per_eu": 2,
                "PRE_LOAD_V": False,
            },
            num_stages=1,
            num_warps=8,
        ),
        triton.Config(
            {
                "BLOCK_M": 128,
                "BLOCK_N": 128,
                "waves_per_eu": 2,
                "PRE_LOAD_V": False,
            },
            num_stages=1,
            num_warps=4,
        ),
        triton.Config(
            {
                "BLOCK_M": 256,
                "BLOCK_N": 128,
                "waves_per_eu": 2,
                "PRE_LOAD_V": False,
            },
            num_stages=1,
            num_warps=8,
        ),
        triton.Config(
            {
                "BLOCK_M": 128,
                "BLOCK_N": 64,
                "waves_per_eu": 3,
                "PRE_LOAD_V": True,
            },
            num_stages=1,
            num_warps=4,
        ),
        triton.Config(
            {
                "BLOCK_M": 128,
                "BLOCK_N": 64,
                "waves_per_eu": 3,
                "PRE_LOAD_V": False,
            },
            num_stages=1,
            num_warps=4,
        ),
        triton.Config(
            {
                "BLOCK_M": 64,
                "BLOCK_N": 64,
                "waves_per_eu": 4,
                "PRE_LOAD_V": False,
            },
            num_stages=1,
            num_warps=8,
        ),
        triton.Config(
            {
                "BLOCK_M": 32,
                "BLOCK_N": 32,
                "waves_per_eu": 4,
                "PRE_LOAD_V": False,
            },
            num_stages=1,
            num_warps=8,
        ),
        # TODO: This config fails with head_size not pow2 with data mismatches.
        # triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1,
        #                'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
        triton.Config(
            {
                "BLOCK_M": 16,
                "BLOCK_N": 16,
                "waves_per_eu": 1,
                "PRE_LOAD_V": False,
            },
            num_stages=1,
            num_warps=4,
        ),
        triton.Config(
            {
                "BLOCK_M": 128,
                "BLOCK_N": 64,
                "waves_per_eu": 1,
                "PRE_LOAD_V": False,
            },
            num_stages=1,
            num_warps=4,
        ),
    ],
    key=["IS_CAUSAL", "dropout_p", "BLOCK_DMODEL"],
)
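# Autotune caches the best configuration per unique value of the `key` fields
# above, so a fresh search over `configs` runs only when IS_CAUSAL, dropout_p,
# or the padded head dimension BLOCK_DMODEL changes.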
@triton.jit
def attn_fwd(
    Q,
    K,
    V,
    bias,
    sm_scale,
    L,
    Out,
    stride_qz,
    stride_qh,
    stride_qm,
    stride_qk,
    stride_kz,
    stride_kh,
    stride_kn,
    stride_kk,
    stride_vz,
    stride_vh,
    stride_vk,
    stride_vn,
    stride_oz,
    stride_oh,
    stride_om,
    stride_on,
    stride_bz,
    stride_bh,
    stride_bm,
    stride_bn,
    cu_seqlens_q,
    cu_seqlens_k,
    dropout_p,
    philox_seed,
    philox_offset_base,
    encoded_softmax,
    HQ: tl.constexpr,
    HK: tl.constexpr,
    ACTUAL_BLOCK_DMODEL: tl.constexpr,
    MAX_SEQLENS_Q: tl.constexpr,
    MAX_SEQLENS_K: tl.constexpr,
    VARLEN: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    PRE_LOAD_V: tl.constexpr,
    BIAS_TYPE: tl.constexpr,
    ENABLE_DROPOUT: tl.constexpr,
    RETURN_ENCODED_SOFTMAX: tl.constexpr,
):
    start_m = tl.program_id(0)
    off_h_q = tl.program_id(1)
    off_z = tl.program_id(2)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    if VARLEN:
        cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)
        cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)
        seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start
        # We have a one-size-fits-all grid in id(0). Some seqlens might be too
        # small for all start_m so for those we return early.
        if start_m * BLOCK_M > seqlen_q:
            return
        cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)
        cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)
        seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start
    else:
        cu_seqlens_q_start = 0
        cu_seqlens_k_start = 0
        seqlen_q = MAX_SEQLENS_Q
        seqlen_k = MAX_SEQLENS_K

    # Now we compute whether we need to exit early due to causal masking.
    # This is because for seqlen_q > seqlen_k, M rows of the attn scores
    # are completely masked, resulting in 0s written to the output, and
    # inf written to LSE. We don't need to do any GEMMs in this case.
    # This block of code determines what N is, and if this WG is operating
    # on those M rows.
    n_blocks = cdiv_fn(seqlen_k, BLOCK_N)
    if IS_CAUSAL:
        # If seqlen_q == seqlen_k, the attn scores are a square matrix.
        # If seqlen_q != seqlen_k, attn scores are rectangular which means
        # the causal mask boundary is bottom right aligned, and ends at either
        # the top edge (seqlen_q < seqlen_k) or left edge.
        # This captures the decrease in n_blocks if we have a rectangular attn
        # matrix.
        n_blocks_seqlen = cdiv_fn(
            (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N
        )
        # This is what adjusts the block_max for the current WG, only
        # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks
        n_blocks = min(n_blocks, n_blocks_seqlen)
        # If we have no blocks after adjusting for seqlen deltas, this WG is
        # part of the blocks that are all 0. We exit early.
        if n_blocks <= 0:
            o_offset = (
                off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh
            )
            O_block_ptr = tl.make_block_ptr(
                base=Out + o_offset,
                shape=(seqlen_q, BLOCK_DMODEL),
                strides=(stride_om, stride_on),
                offsets=(start_m * BLOCK_M, 0),
                block_shape=(BLOCK_M, BLOCK_DMODEL),
                order=(1, 0),
            )
            acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty)
            # We still need to write 0s to the result
            # tl.store(O_block_ptr,
            #          acc.to(Out.type.element_ty), boundary_check=(0,1))
            # l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q
            #          + offs_m
            # We store inf to LSE, not -inf because in the bwd pass,
            # we subtract this
            # from qk which makes it -inf, such that exp(qk - inf) = 0
            # for these masked blocks.
            # l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32)
            # tl.store(l_ptrs, l)
            # TODO: Should dropout and return encoded softmax be handled here?
            return
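
    # Worked example of the early exit above (illustrative numbers): with
    # seqlen_q = 512, seqlen_k = 128 and BLOCK_M = BLOCK_N = 64, the WG at
    # start_m = 0 gets n_blocks_seqlen = cdiv(64 + 128 - 512, 64) <= 0: all
    # of its score rows sit above the causal boundary, so it takes the
    # early-return path.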

    # If MQA / GQA, set the K and V head offsets appropriately.
    GROUP_SIZE: tl.constexpr = HQ // HK
    if GROUP_SIZE != 1:
        off_h_k = off_h_q // GROUP_SIZE
    else:
        off_h_k = off_h_q

    n_extra_tokens = 0
    if seqlen_k < BLOCK_N:
        n_extra_tokens = BLOCK_N - seqlen_k
    elif seqlen_k % BLOCK_N:
        n_extra_tokens = seqlen_k % BLOCK_N
    PADDED_HEAD: tl.constexpr = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL

    # Compute pointers for all the tensors used in this kernel.
    q_offset = off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm
    Q_block_ptr = tl.make_block_ptr(
        base=Q + q_offset,
        shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0),
    )
    k_offset = off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn
    K_block_ptr = tl.make_block_ptr(
        base=K + k_offset,
        shape=(ACTUAL_BLOCK_DMODEL, seqlen_k),
        strides=(stride_kk, stride_kn),
        offsets=(0, 0),
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1),
    )
    v_offset = off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk
    V_block_ptr = tl.make_block_ptr(
        base=V + v_offset,
        shape=(seqlen_k, ACTUAL_BLOCK_DMODEL),
        strides=(stride_vk, stride_vn),
        offsets=(0, 0),
        block_shape=(BLOCK_N, BLOCK_DMODEL),
        order=(1, 0),
    )
    if BIAS_TYPE != 0:
        bias_ptr = tl.make_block_ptr(
            base=bias + off_h_q * stride_bh,
            shape=(seqlen_q, seqlen_k),
            strides=(stride_bm, stride_bn),
            offsets=(start_m * BLOCK_M, 0),
            block_shape=(BLOCK_M, BLOCK_N),
            order=(1, 0),
        )
    else:
        bias_ptr = None
    if ENABLE_DROPOUT:
        batch_philox_offset = (
            philox_offset_base + (off_z * HQ + off_h_q) * seqlen_q * seqlen_k
        )
    else:
        batch_philox_offset = 0
    # We can ask to return the dropout mask without actually doing any dropout.
    # In this case, we return an invalid pointer to indicate the mask is not
    # valid.
    # TODO: Fix encoded softmax. It currently uses just h_q in the base offset.
    if RETURN_ENCODED_SOFTMAX:
        encoded_softmax_block_ptr = tl.make_block_ptr(
            base=encoded_softmax + off_h_q * seqlen_q * seqlen_k,
            shape=(seqlen_q, seqlen_k),
            strides=(seqlen_k, 1),
            offsets=(start_m * BLOCK_M, 0),
            block_shape=(BLOCK_M, BLOCK_N),
            order=(1, 0),
        )
    else:
        encoded_softmax_block_ptr = 0
    # initialize pointer to m and l
    m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    # scale sm_scale by log_2(e) and use 2^x in the loop as we do not
    # have native e^x support in HW.
    qk_scale = sm_scale * 1.44269504089
    # Q is loaded once at the beginning and shared by all N blocks.
    q = load_fn(Q_block_ptr, True, PADDED_HEAD, "zero")
    q = (q * qk_scale).to(Q_block_ptr.type.element_ty)
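
    # Note: exp(x) == exp2(x * log2(e)) and log2(e) ~= 1.44269504089, so
    # folding the constant into q once here lets every softmax step below use
    # the cheaper exp2 while computing the same values as a natural exp would.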

    # Here we compute how many full and masked blocks we have.
    padded_block_k = n_extra_tokens != 0
    is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0)
    if IS_CAUSAL:
        # There are always at least BLOCK_M // BLOCK_N masked blocks.
        # Additionally there might be one more due to dissimilar seqlens.
        masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn)
    else:
        # Padding on Q does not need to be masked in the FA loop.
        masked_blocks = padded_block_k
    # if IS_CAUSAL, not is_modulo_mn does not always result in an additional
    # block. In this case we might exceed n_blocks so pick the min.
    masked_blocks = min(masked_blocks, n_blocks)
    n_full_blocks = n_blocks - masked_blocks
    block_min = 0
    block_max = n_blocks * BLOCK_N
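
    # Worked example (illustrative numbers): causal, BLOCK_M = 128,
    # BLOCK_N = 64 and seqlen_q == seqlen_k == 256. For the last M block,
    # n_blocks == 4 and masked_blocks == 128 // 64 == 2 (is_modulo_mn holds),
    # so two K/V blocks run unmasked below and two apply the causal mask.
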
    # Compute for full blocks. Here we set causal to false regardless of its
    # value because there is no masking. Similarly we do not need padding.
    if n_full_blocks > 0:
        block_max = (n_blocks - masked_blocks) * BLOCK_N
        acc, l_i, m_i = _attn_fwd_inner(
            acc,
            l_i,
            m_i,
            q,
            K_block_ptr,
            V_block_ptr,
            start_m,
            seqlen_k,
            dropout_p,
            philox_seed,
            batch_philox_offset,
            encoded_softmax_block_ptr,
            # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _
            block_min,
            block_max,
            0,
            0,
            0,
            bias_ptr,
            # IS_CAUSAL, ....
            False,
            BLOCK_M,
            BLOCK_DMODEL,
            BLOCK_N,
            offs_m,
            offs_n,
            # _, MASK_STEPS, ...
            PRE_LOAD_V,
            False,
            ENABLE_DROPOUT,
            RETURN_ENCODED_SOFTMAX,
            PADDED_HEAD,
        )
        block_min = block_max
        block_max = n_blocks * BLOCK_N

    tl.debug_barrier()
    # Remaining blocks, if any, are masked blocks: they need the boundary
    # (MASK_STEPS) and causal checks enabled.
    if masked_blocks > 0:
        offs_n_causal = offs_n + (seqlen_q - seqlen_k) if IS_CAUSAL else 0
        K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N))
        V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0))
        if bias_ptr is not None:
            bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N))
        if RETURN_ENCODED_SOFTMAX:
            encoded_softmax_block_ptr = tl.advance(
                encoded_softmax_block_ptr, (0, n_full_blocks)
            )
        acc, l_i, m_i = _attn_fwd_inner(
            acc,
            l_i,
            m_i,
            q,
            K_block_ptr,
            V_block_ptr,
            start_m,
            seqlen_k,
            dropout_p,
            philox_seed,
            batch_philox_offset,
            encoded_softmax_block_ptr,
            block_min,
            block_max,
            offs_n_causal,
            masked_blocks,
            n_extra_tokens,
            bias_ptr,
            IS_CAUSAL,
            BLOCK_M,
            BLOCK_DMODEL,
            BLOCK_N,
            offs_m,
            offs_n,
            # _, MASK_STEPS, ...
            PRE_LOAD_V,
            True,
            ENABLE_DROPOUT,
            RETURN_ENCODED_SOFTMAX,
            PADDED_HEAD,
        )
    # epilogue
    acc = acc / l_i[:, None]
    if ENABLE_DROPOUT:
        acc = acc / (1 - dropout_p)
    # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M,
    # then we have one block with a row of all NaNs which come from computing
    # softmax over a row of all -infs (-inf - inf = NaN). We check for that here
    # and store 0s where there are NaNs as these rows should've been zeroed out.
    end_m_idx = (start_m + 1) * BLOCK_M
    start_m_idx = start_m * BLOCK_M
    causal_start_idx = seqlen_q - seqlen_k
    acc = acc.to(Out.type.element_ty)
    if IS_CAUSAL:  # noqa: SIM102
        if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx:
            out_mask_boundary = tl.full(
                (BLOCK_DMODEL,), causal_start_idx, dtype=tl.int32
            )
            mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M)
            out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :]
            z = 0.0
            acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty))
    # write back LSE
    # l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
    # If seqlen_q not multiple of BLOCK_M, we need to mask out the last
    # few rows. This is only true for the last M block. For others,
    # overflow_size will be -ve
    # overflow_size = end_m_idx - seqlen_q
    # if overflow_size > 0:
    #     boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32)
    #     # This is a > check because mask being 0 blocks the store.
    #     l_ptrs_mask = boundary > tl.arange(0, BLOCK_M)
    #     tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask)
    # else:
    #     tl.store(l_ptrs, m_i + tl.math.log2(l_i))

    # write back O
    o_offset = off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh
    O_block_ptr = tl.make_block_ptr(
        base=Out + o_offset,
        shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
        strides=(stride_om, stride_on),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0),
    )
    # The boundary check makes sure the padding from the Q and KV tensors in
    # both dims is not part of what we store back.
    # TODO: Do the boundary check optionally.
    tl.store(O_block_ptr, acc, boundary_check=(0, 1))
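

# Launch-geometry note (mirrors the grid lambda in _attention.forward below):
# attn_fwd runs one program instance per (query block, query head, batch
# entry):
#
#     grid = (triton.cdiv(max_seqlens_q, META["BLOCK_M"]), nheads_q, batch)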


def check_args(
    q,
    k,
    v,
    o,
    varlen=True,
    max_seqlens=None,
    cu_seqlens_q=None,
    cu_seqlens_k=None,
):
    assert q.dim() == k.dim() and q.dim() == v.dim()
    if varlen:
        assert q.dim() == 3
        total_q, nheads_q, head_size = q.shape
        total_k, nheads_k, _ = k.shape
        assert cu_seqlens_q is not None
        assert cu_seqlens_k is not None
        assert len(cu_seqlens_q) == len(cu_seqlens_k)
    else:
        assert q.dim() == 4
        batch, nheads_q, seqlen_q, head_size = q.shape
        _, nheads_k, seqlen_k, _ = k.shape
        assert max_seqlens > 0
    assert k.shape == v.shape
    assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1]
    # TODO: Change assert if we support qkl f8 and v f16
    assert q.dtype == k.dtype and q.dtype == v.dtype
    # TODO: Fix assert to check head size <= 256 once supported
    assert head_size <= 128
    assert o.shape == q.shape
    assert (nheads_q % nheads_k) == 0
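

# Shapes accepted by check_args in varlen mode (illustrative values):
#   q:    (total_q, nheads_q, head_size), e.g. (384, 8, 128)
#   k, v: (total_k, nheads_k, head_size) with nheads_q % nheads_k == 0
#   cu_seqlens_q / cu_seqlens_k: (batch + 1,) prefix sums of the per-sequence
#   lengths, e.g. [0, 128, 384] for two packed sequences of 128 and 256 tokens.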


class _attention(torch.autograd.Function):

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        o,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlens_q,
        max_seqlens_k,
        causal=False,
        sm_scale=1.0,
        bias=None,
    ):
        if o is None:
            o = torch.empty_like(q, dtype=v.dtype)

        check_args(
            q,
            k,
            v,
            o,
            varlen=True,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
        )
        if True:  # varlen
            total_q, nheads_q, head_size = q.shape
            total_k, nheads_k, _ = k.shape
            batch = len(cu_seqlens_q) - 1
            q_strides = (0, q.stride(1), q.stride(0), q.stride(2))
            k_strides = (0, k.stride(1), k.stride(0), k.stride(2))
            v_strides = (0, v.stride(1), v.stride(0), v.stride(2))
            o_strides = (0, o.stride(1), o.stride(0), o.stride(2))
        else:
            batch, seqlen_q, nheads_q, head_size = q.shape
            _, seqlen_k, nheads_k, _ = k.shape
            q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3))
            k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3))
            v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3))
            o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3))

        # Pad the head dim to the closest power of 2, with a floor of 16.
        padded_d_model = 1 << (head_size - 1).bit_length()
        padded_d_model = max(padded_d_model, 16)

        grid = lambda META: (
            triton.cdiv(max_seqlens_q, META["BLOCK_M"]),
            nheads_q,
            batch,
        )

        encoded_softmax = None

        # Seed the RNG so we get reproducible results for testing.
        philox_seed = 0x1BF52
        philox_offset = 0x1D4B42

        if bias is not None:
            bias_strides = (
                bias.stride(0),
                bias.stride(1),
                bias.stride(2),
                bias.stride(3),
            )
        else:
            bias_strides = (0, 0, 0, 0)

        attn_fwd[grid](
            q,
            k,
            v,
            bias,
            sm_scale,
            None,
            o,
            *q_strides,
            *k_strides,
            *v_strides,
            *o_strides,
            *bias_strides,
            cu_seqlens_q,
            cu_seqlens_k,
            dropout_p=0.0,
            philox_seed=philox_seed,
            philox_offset_base=philox_offset,
            encoded_softmax=encoded_softmax,
            HQ=nheads_q,
            HK=nheads_k,
            ACTUAL_BLOCK_DMODEL=head_size,
            MAX_SEQLENS_Q=max_seqlens_q,
            MAX_SEQLENS_K=max_seqlens_k,
            IS_CAUSAL=causal,
            VARLEN=True,
            BLOCK_DMODEL=padded_d_model,
            BIAS_TYPE=0 if bias is None else 1,
            ENABLE_DROPOUT=False,
            RETURN_ENCODED_SOFTMAX=False,
        )

        ctx.grid = grid
        ctx.sm_scale = sm_scale
        ctx.BLOCK_DMODEL = head_size
        ctx.causal = causal
        ctx.dropout_p = 0.0
        ctx.philox_seed = philox_seed
        ctx.philox_offset = philox_offset
        ctx.encoded_softmax = encoded_softmax
        ctx.return_encoded_softmax = False
        return o, encoded_softmax


triton_attention = _attention.apply
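

if __name__ == "__main__":
    # Minimal smoke test, a sketch only: assumes a Triton-capable GPU and
    # float16 inputs. Two sequences (lengths 128 and 256) packed varlen-style.
    cu_q = torch.tensor([0, 128, 384], device="cuda", dtype=torch.int32)
    cu_k = cu_q.clone()
    total_tokens, nheads, head_size = 384, 8, 128
    q = torch.randn(
        total_tokens, nheads, head_size, device="cuda", dtype=torch.float16
    )
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    out, _ = triton_attention(
        q, k, v, None, cu_q, cu_k, 256, 256, True, head_size**-0.5
    )
    print(out.shape)  # expected: torch.Size([384, 8, 128])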