fxmarty 2024-05-16 14:46:47 +00:00
parent 0812e3bdc9
commit 265c76d328
33 changed files with 123 additions and 84 deletions

View File

@@ -8,6 +8,7 @@ if SYSTEM == "rocm":
except Exception as e:
raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
class FastLinear(torch.nn.Module):
def __init__(
self,

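The hunk above only reflows the ROCm-specific guard around vLLM's optional _custom_C extension; the same try/except-and-re-raise pattern also appears in the paged attention hunk at the end of this commit. A minimal sketch of that guard, taken almost verbatim from the diff (the SYSTEM value is assumed to come from the server's hardware detection):

SYSTEM = "rocm"  # assumption: normally provided by the server's system detection

if SYSTEM == "rocm":
    try:
        from vllm import _custom_C  # ROCm-only fused kernels shipped with vLLM
    except Exception as e:
        # Re-raise as ImportError but keep the original failure for debugging.
        raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")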
View File

@@ -197,7 +197,9 @@ class LlamaMLP(nn.Module):
else lambda x: torch.nn.functional.gelu(
x,
approximate=(
"tanh" if self.hidden_act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
"tanh"
if self.hidden_act in ["gelu_fast", "gelu_pytorch_tanh"]
else "none"
),
)
)
@@ -229,7 +231,11 @@ class LlamaMLP(nn.Module):
)
def forward(self, hidden_states):
if SYSTEM == "rocm" and self.hidden_act == "silu" and hidden_states.shape[0] == 1:
if (
SYSTEM == "rocm"
and self.hidden_act == "silu"
and hidden_states.shape[0] == 1
):
out = torch.empty(
hidden_states.shape[0],
self.intermediate_size,

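For context, the condition reflowed above gates a ROCm-only fast path in LlamaMLP.forward: the fused kernel is only taken for the silu activation with a single input row (the decode step), and every other case falls back to the regular gate/up projection. A minimal sketch of that dispatch as a standalone function, with fused_silu_kernel as a hypothetical stand-in for the vLLM custom op the real code calls:

import torch
import torch.nn.functional as F

def llama_mlp_forward(hidden_states, gate_up_weight, down_weight, intermediate_size,
                      system="cuda", hidden_act="silu", fused_silu_kernel=None):
    # ROCm fast path: only for SiLU and a single token, mirroring the guard in the diff.
    if (
        system == "rocm"
        and hidden_act == "silu"
        and hidden_states.shape[0] == 1
        and fused_silu_kernel is not None
    ):
        out = torch.empty(
            hidden_states.shape[0],
            intermediate_size,
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )
        fused_silu_kernel(gate_up_weight, hidden_states, out)  # fused SiLU(gate) * up
        return out @ down_weight.t()

    # Generic path: one fused gate/up matmul, then SwiGLU, then the down projection.
    gate_up = hidden_states @ gate_up_weight.t()
    gate, up = gate_up.view(-1, 2, intermediate_size).unbind(dim=1)
    return (F.silu(gate) * up) @ down_weight.t()

The batch-size-1 restriction is what makes this a decode-time optimization: the fused kernel targets the single-token case, while prefill batches keep the generic path. The MistralMLP hunk below applies the same reformatting to an identical guard.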
View File

@@ -266,7 +266,9 @@ class MistralMLP(nn.Module):
else lambda x: torch.nn.functional.gelu(
x,
approximate=(
"tanh" if self.hidden_act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
"tanh"
if self.hidden_act in ["gelu_fast", "gelu_pytorch_tanh"]
else "none"
),
)
)
@@ -289,7 +291,11 @@ class MistralMLP(nn.Module):
)
def forward(self, hidden_states):
if SYSTEM == "rocm" and self.hidden_act == "silu" and hidden_states.shape[0] == 1:
if (
SYSTEM == "rocm"
and self.hidden_act == "silu"
and hidden_states.shape[0] == 1
):
out = torch.empty(
hidden_states.shape[0],
self.intermediate_size,

View File

@@ -64,6 +64,7 @@ elif SYSTEM == "rocm":
else:
raise RuntimeError(f"Unsupported system {SYSTEM}")
@dataclass
class BaseModelOutputWithPastImage(BaseModelOutputWithPast):
image_hidden_states: Optional[torch.FloatTensor] = None

View File

@@ -831,16 +831,26 @@ class FlashCausalLM(Model):
torch.cuda.tunable.tuning_enable(True)
if os.environ.get("PYTORCH_TUNABLEOP_SEQLENS", False):
tuning_sequences = [int(val) for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")]
tuning_sequences = [
int(val)
for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")
]
else:
tuning_sequences = [1, 2, 4, 8, 16, 32]
tunableop_filepath = os.path.join("/data", f"tunableop_{self.model_id.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv")
tunableop_filepath = os.path.join(
"/data",
f"tunableop_{self.model_id.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv",
)
logger.info(f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])} (typical decoding lengths). The picked GEMMs are saved in the file {tunableop_filepath}.")
logger.info(
f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])} (typical decoding lengths). The picked GEMMs are saved in the file {tunableop_filepath}."
)
if os.path.isfile(tunableop_filepath):
logger.info(f"The file {tunableop_filepath} already exists and will be reused.")
logger.info(
f"The file {tunableop_filepath} already exists and will be reused."
)
torch.cuda.tunable.read_file(tunableop_filepath)
os.makedirs("/data", exist_ok=True)
@@ -851,7 +861,9 @@ class FlashCausalLM(Model):
torch.cuda.tunable.write_file(tunableop_filepath)
torch.cuda.tunable.tuning_enable(False)
else:
logger.info("PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp.")
logger.info(
"PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp."
)
if CUDA_GRAPHS:
try:
@@ -877,7 +889,9 @@ class FlashCausalLM(Model):
self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=torch.tensor([0, seqlen], device=self.device, dtype=torch.int32),
cu_seqlen_prefill=torch.tensor(
[0, seqlen], device=self.device, dtype=torch.int32
),
kv_cache=get_cache_manager().kv_cache,
block_tables=None,
input_lengths=None,

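Taken together, the reflowed lines above implement the TunableOp warmup: tuning is switched on, a decode-shaped dummy forward is run for each target sequence length so TunableOp can benchmark candidate GEMM kernels, the selections are persisted to a per-model CSV under /data (and reused on the next start if the file already exists), and tuning is switched off again so regular requests only use the already-selected kernels. A condensed sketch of that flow, assuming a PyTorch build that exposes torch.cuda.tunable and using a hypothetical run_dummy_forward in place of the model warmup call:

import os
import torch

def tunableop_warmup(model_id, world_size, rank, run_dummy_forward):
    # Target decoding lengths; overridable via PYTORCH_TUNABLEOP_SEQLENS, as in the diff.
    if os.environ.get("PYTORCH_TUNABLEOP_SEQLENS"):
        tuning_sequences = [
            int(val) for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")
        ]
    else:
        tuning_sequences = [1, 2, 4, 8, 16, 32]

    # One CSV of picked GEMMs per model and tensor-parallel rank.
    tunableop_filepath = os.path.join(
        "/data",
        f"tunableop_{model_id.replace('/', '-')}_tp{world_size}_rank{rank}.csv",
    )

    torch.cuda.tunable.tuning_enable(True)
    if os.path.isfile(tunableop_filepath):
        # Reuse previously tuned kernels instead of re-benchmarking them.
        torch.cuda.tunable.read_file(tunableop_filepath)
    os.makedirs("/data", exist_ok=True)

    for seqlen in tuning_sequences:
        run_dummy_forward(seqlen)  # hypothetical stand-in for the warmup forward pass

    torch.cuda.tunable.write_file(tunableop_filepath)
    torch.cuda.tunable.tuning_enable(False)  # stop benchmarking, keep the selections

In the real model this whole block is only reached when TunableOp is enabled via PYTORCH_TUNABLEOP_ENABLED, as the two log messages above describe.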
View File

@@ -187,6 +187,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
total_ns=time.time_ns() - start,
)
def serve(
model_id: str,
revision: Optional[str],

View File

@@ -64,12 +64,17 @@ if SYSTEM in {"cuda", "rocm"}:
is_sm94 = major == 9 and minor == 4
if SYSTEM == "rocm":
if os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() == "true" or os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "0") == "1":
if (
os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() == "true"
or os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "0") == "1"
):
ROCM_USE_FLASH_ATTN_V2_TRITON = True
logger.info("ROCm: using Flash Attention 2 Triton implementation.")
else:
ROCM_USE_FLASH_ATTN_V2_CK = True
logger.info("ROCm: using Flash Attention 2 Composable Kernel implementation.")
logger.info(
"ROCm: using Flash Attention 2 Composable Kernel implementation."
)
try:
try:
@@ -158,6 +163,7 @@ if HAS_FLASH_ATTN_V2_CUDA:
)
elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_CK:
def attention(
q,
k,
@@ -192,6 +198,7 @@ elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_CK:
False,
None,
)
elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_TRITON:
def attention(
@@ -217,6 +224,7 @@ elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_TRITON:
softmax_scale,
)
return output
elif HAS_FLASH_ATTN:
def attention(

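The reflowed condition at the top of this file is the ROCm backend switch for Flash Attention 2: ROCM_USE_FLASH_ATTN_V2_TRITON set to "true" (case-insensitive) or "1" selects the Triton kernel, anything else keeps the Composable Kernel build, and the later hunks define the attention() wrapper against whichever flag ends up set. A small sketch of that selection logic, assuming it were factored into a helper:

import os

def rocm_flash_attn_backend():
    # Mirrors the env-var check reflowed above: "true"/"TRUE"/"1" selects Triton.
    raw = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "")
    if raw.lower() == "true" or raw == "1":
        return "triton"
    return "ck"  # default: Composable Kernel implementation of Flash Attention 2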
View File

@@ -46,16 +46,16 @@ def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride):
@triton.jit
def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride):
rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n,
stride).to(tl.uint32)
rng_offsets = dropout_offsets(
philox_seed, philox_offset, dropout_p, m, n, stride
).to(tl.uint32)
# TODO: use tl.randint for better performance
return tl.rand(philox_seed, rng_offsets)
@triton.jit
def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride):
rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n,
stride)
rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride)
rng_keep = rng_output > dropout_p
return rng_keep
@@ -65,9 +65,9 @@ def load_fn(block_ptr, first, second, pad):
if first and second:
tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad)
elif first:
tensor = tl.load(block_ptr, boundary_check=(0, ), padding_option=pad)
tensor = tl.load(block_ptr, boundary_check=(0,), padding_option=pad)
elif second:
tensor = tl.load(block_ptr, boundary_check=(1, ), padding_option=pad)
tensor = tl.load(block_ptr, boundary_check=(1,), padding_option=pad)
else:
tensor = tl.load(block_ptr)
return tensor
@@ -133,9 +133,7 @@ def _attn_fwd_inner(
# if not is_modulo_mn. last step might get wasted but that is okay.
# check if this masking works for that case.
if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0):
boundary_m = tl.full([BLOCK_M],
actual_seqlen_k,
dtype=tl.int32)
boundary_m = tl.full([BLOCK_M], actual_seqlen_k, dtype=tl.int32)
size_n = start_n + OFFS_N[None, :]
mask = size_n < boundary_m[:, None]
qk = tl.where(mask, qk, float("-inf"))
@@ -146,8 +144,9 @@ def _attn_fwd_inner(
# -- compute qk ----
qk += tl.dot(q, k)
if bias_ptr is not None:
bias = load_fn(bias_ptr, False, MASK_STEPS
and (n_extra_tokens != 0), "zero")
bias = load_fn(
bias_ptr, False, MASK_STEPS and (n_extra_tokens != 0), "zero"
)
# While bias is added after multiplying qk with sm_scale, our
# optimization to use 2^x instead of e^x results in an additional
# scale factor of log2(e) which we must also multiply the bias with.
@@ -159,9 +158,12 @@ def _attn_fwd_inner(
# CAVEAT: Must update l_ij before applying dropout
l_ij = tl.sum(p, 1)
if ENABLE_DROPOUT:
philox_offset = (batch_philox_offset +
start_m * BLOCK_M * actual_seqlen_k + start_n -
BLOCK_N)
philox_offset = (
batch_philox_offset
+ start_m * BLOCK_M * actual_seqlen_k
+ start_n
- BLOCK_N
)
keep = dropout_mask(
philox_seed,
philox_offset,
@@ -173,8 +175,7 @@ def _attn_fwd_inner(
if RETURN_ENCODED_SOFTMAX:
tl.store(
encoded_softmax_block_ptr,
tl.where(keep, p,
-p).to(encoded_softmax_block_ptr.type.element_ty),
tl.where(keep, p, -p).to(encoded_softmax_block_ptr.type.element_ty),
)
p = tl.where(keep, p, 0.0)
elif RETURN_ENCODED_SOFTMAX:
@@ -202,8 +203,9 @@ def _attn_fwd_inner(
if bias_ptr is not None:
bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N))
if RETURN_ENCODED_SOFTMAX:
encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr,
(0, BLOCK_N))
encoded_softmax_block_ptr = tl.advance(
encoded_softmax_block_ptr, (0, BLOCK_N)
)
return acc, l_i, m_i
@@ -341,7 +343,7 @@ def attn_fwd(
philox_offset_base,
encoded_softmax,
HQ: tl.constexpr,
HK:tl.constexpr,
HK: tl.constexpr,
ACTUAL_BLOCK_DMODEL: tl.constexpr,
MAX_SEQLENS_Q: tl.constexpr,
MAX_SEQLENS_K: tl.constexpr,
@@ -392,15 +394,17 @@ def attn_fwd(
# This captures the decrease in n_blocks if we have a rectangular attn
# matrix
n_blocks_seqlen = cdiv_fn(
(start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N)
(start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N
)
# This is what adjusts the block_max for the current WG, only
# if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks
n_blocks = min(n_blocks, n_blocks_seqlen)
# If we have no blocks after adjusting for seqlen deltas, this WG is
# part of the blocks that are all 0. We exit early.
if n_blocks <= 0:
o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om +
off_h_q * stride_oh)
o_offset = (
off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh
)
O_block_ptr = tl.make_block_ptr(
base=Out + o_offset,
shape=(seqlen_q, BLOCK_DMODEL),
@@ -436,11 +440,10 @@ def attn_fwd(
n_extra_tokens = BLOCK_N - seqlen_k
elif seqlen_k % BLOCK_N:
n_extra_tokens = seqlen_k % BLOCK_N
PADDED_HEAD:tl.constexpr = (ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL)
PADDED_HEAD: tl.constexpr = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL
# Compute pointers for all the tensors used in this kernel.
q_offset = (off_z * stride_qz + off_h_q * stride_qh +
cu_seqlens_q_start * stride_qm)
q_offset = off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm
Q_block_ptr = tl.make_block_ptr(
base=Q + q_offset,
shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
@@ -449,8 +452,7 @@ def attn_fwd(
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
k_offset = (off_z * stride_kz + off_h_k * stride_kh +
cu_seqlens_k_start * stride_kn)
k_offset = off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn
K_block_ptr = tl.make_block_ptr(
base=K + k_offset,
shape=(ACTUAL_BLOCK_DMODEL, seqlen_k),
@@ -459,8 +461,7 @@ def attn_fwd(
block_shape=(BLOCK_DMODEL, BLOCK_N),
order=(0, 1),
)
v_offset = (off_z * stride_vz + off_h_k * stride_vh +
cu_seqlens_k_start * stride_vk)
v_offset = off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk
V_block_ptr = tl.make_block_ptr(
base=V + v_offset,
shape=(seqlen_k, ACTUAL_BLOCK_DMODEL),
@@ -481,9 +482,9 @@ def attn_fwd(
else:
bias_ptr = None
if ENABLE_DROPOUT:
batch_philox_offset = philox_offset_base \
+ (off_z * HQ + off_h_q) \
* seqlen_q * seqlen_k
batch_philox_offset = (
philox_offset_base + (off_z * HQ + off_h_q) * seqlen_q * seqlen_k
)
else:
batch_philox_offset = 0
# We can ask to return the dropout mask without actually doing any dropout.
@@ -578,8 +579,9 @@ def attn_fwd(
if bias_ptr is not None:
bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N))
if RETURN_ENCODED_SOFTMAX:
encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr,
(0, n_full_blocks))
encoded_softmax_block_ptr = tl.advance(
encoded_softmax_block_ptr, (0, n_full_blocks)
)
acc, l_i, m_i = _attn_fwd_inner(
acc,
l_i,
@@ -626,12 +628,11 @@ def attn_fwd(
acc = acc.to(Out.type.element_ty)
if IS_CAUSAL: # noqa: SIM102
if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx:
out_mask_boundary = tl.full((BLOCK_DMODEL, ),
causal_start_idx,
dtype=tl.int32)
out_mask_boundary = tl.full(
(BLOCK_DMODEL,), causal_start_idx, dtype=tl.int32
)
mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M)
out_ptrs_mask = (mask_m_offsets[:, None] >=
out_mask_boundary[None, :])
out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :]
z = 0.0
acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty))
# write back LSE
@@ -649,8 +650,7 @@ def attn_fwd(
# tl.store(l_ptrs, m_i + tl.math.log2(l_i))
# write back O
o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om +
off_h_q * stride_oh)
o_offset = off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh
O_block_ptr = tl.make_block_ptr(
base=Out + o_offset,
shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),

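One expression this file reflows several times is the dropout RNG offset math. Transcribed into plain Python (not Triton) from the two hunks above, the layout is: each (batch, query-head) pair owns a contiguous slab of seqlen_q * seqlen_k Philox offsets, and each tile of the attention matrix indexes into that slab, so every score gets its own reproducible random number. A sketch using the kernel's own names, intended only to make the arithmetic easier to read:

def dropout_philox_offset(
    philox_offset_base,   # base offset passed into attn_fwd
    off_z, off_h_q, HQ,   # batch index, query-head index, number of query heads
    seqlen_q, seqlen_k,   # sequence lengths used to size the per-(batch, head) slab
    actual_seqlen_k,      # actual key length used for the per-tile stride
    start_m, BLOCK_M,     # row-block index and row-block size
    start_n, BLOCK_N,     # current column position and column-block size
):
    # Per-(batch, head) slab, as computed in attn_fwd when ENABLE_DROPOUT is set.
    batch_philox_offset = (
        philox_offset_base + (off_z * HQ + off_h_q) * seqlen_q * seqlen_k
    )
    # Per-tile offset, as computed in _attn_fwd_inner before calling dropout_mask.
    return batch_philox_offset + start_m * BLOCK_M * actual_seqlen_k + start_n - BLOCK_N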
View File

@@ -10,7 +10,9 @@ else:
from vllm._C import cache_ops
from vllm._C import ops
except Exception as e:
raise ImportError(f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}")
raise ImportError(
f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
)
def reshape_and_cache(