mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
black
This commit is contained in:
parent
0812e3bdc9
commit
265c76d328
@ -8,6 +8,7 @@ if SYSTEM == "rocm":
|
||||
except Exception as e:
|
||||
raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
|
||||
|
||||
|
||||
class FastLinear(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -197,7 +197,9 @@ class LlamaMLP(nn.Module):
|
||||
else lambda x: torch.nn.functional.gelu(
|
||||
x,
|
||||
approximate=(
|
||||
"tanh" if self.hidden_act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
|
||||
"tanh"
|
||||
if self.hidden_act in ["gelu_fast", "gelu_pytorch_tanh"]
|
||||
else "none"
|
||||
),
|
||||
)
|
||||
)
|
||||
@ -229,7 +231,11 @@ class LlamaMLP(nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
if SYSTEM == "rocm" and self.hidden_act == "silu" and hidden_states.shape[0] == 1:
|
||||
if (
|
||||
SYSTEM == "rocm"
|
||||
and self.hidden_act == "silu"
|
||||
and hidden_states.shape[0] == 1
|
||||
):
|
||||
out = torch.empty(
|
||||
hidden_states.shape[0],
|
||||
self.intermediate_size,
|
||||
|
@ -266,7 +266,9 @@ class MistralMLP(nn.Module):
|
||||
else lambda x: torch.nn.functional.gelu(
|
||||
x,
|
||||
approximate=(
|
||||
"tanh" if self.hidden_act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
|
||||
"tanh"
|
||||
if self.hidden_act in ["gelu_fast", "gelu_pytorch_tanh"]
|
||||
else "none"
|
||||
),
|
||||
)
|
||||
)
|
||||
@ -289,7 +291,11 @@ class MistralMLP(nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
if SYSTEM == "rocm" and self.hidden_act == "silu" and hidden_states.shape[0] == 1:
|
||||
if (
|
||||
SYSTEM == "rocm"
|
||||
and self.hidden_act == "silu"
|
||||
and hidden_states.shape[0] == 1
|
||||
):
|
||||
out = torch.empty(
|
||||
hidden_states.shape[0],
|
||||
self.intermediate_size,
|
||||
|
@ -64,6 +64,7 @@ elif SYSTEM == "rocm":
|
||||
else:
|
||||
raise RuntimeError(f"Unsupported system {SYSTEM}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseModelOutputWithPastImage(BaseModelOutputWithPast):
|
||||
image_hidden_states: Optional[torch.FloatTensor] = None
|
||||
|
@ -831,16 +831,26 @@ class FlashCausalLM(Model):
|
||||
torch.cuda.tunable.tuning_enable(True)
|
||||
|
||||
if os.environ.get("PYTORCH_TUNABLEOP_SEQLENS", False):
|
||||
tuning_sequences = [int(val) for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")]
|
||||
tuning_sequences = [
|
||||
int(val)
|
||||
for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")
|
||||
]
|
||||
else:
|
||||
tuning_sequences = [1, 2, 4, 8, 16, 32]
|
||||
|
||||
tunableop_filepath = os.path.join("/data", f"tunableop_{self.model_id.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv")
|
||||
tunableop_filepath = os.path.join(
|
||||
"/data",
|
||||
f"tunableop_{self.model_id.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv",
|
||||
)
|
||||
|
||||
logger.info(f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])} (typical decoding lengths). The picked GEMMs are saved in the file {tunableop_filepath}.")
|
||||
logger.info(
|
||||
f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])} (typical decoding lengths). The picked GEMMs are saved in the file {tunableop_filepath}."
|
||||
)
|
||||
|
||||
if os.path.isfile(tunableop_filepath):
|
||||
logger.info(f"The file {tunableop_filepath} already exists and will be reused.")
|
||||
logger.info(
|
||||
f"The file {tunableop_filepath} already exists and will be reused."
|
||||
)
|
||||
torch.cuda.tunable.read_file(tunableop_filepath)
|
||||
|
||||
os.makedirs("/data", exist_ok=True)
|
||||
@ -851,7 +861,9 @@ class FlashCausalLM(Model):
|
||||
torch.cuda.tunable.write_file(tunableop_filepath)
|
||||
torch.cuda.tunable.tuning_enable(False)
|
||||
else:
|
||||
logger.info("PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp.")
|
||||
logger.info(
|
||||
"PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp."
|
||||
)
|
||||
|
||||
if CUDA_GRAPHS:
|
||||
try:
|
||||
@ -877,7 +889,9 @@ class FlashCausalLM(Model):
|
||||
self.model.forward(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
cu_seqlen_prefill=torch.tensor([0, seqlen], device=self.device, dtype=torch.int32),
|
||||
cu_seqlen_prefill=torch.tensor(
|
||||
[0, seqlen], device=self.device, dtype=torch.int32
|
||||
),
|
||||
kv_cache=get_cache_manager().kv_cache,
|
||||
block_tables=None,
|
||||
input_lengths=None,
|
||||
|
@ -187,6 +187,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
||||
total_ns=time.time_ns() - start,
|
||||
)
|
||||
|
||||
|
||||
def serve(
|
||||
model_id: str,
|
||||
revision: Optional[str],
|
||||
|
@ -64,12 +64,17 @@ if SYSTEM in {"cuda", "rocm"}:
|
||||
is_sm94 = major == 9 and minor == 4
|
||||
|
||||
if SYSTEM == "rocm":
|
||||
if os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() == "true" or os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "0") == "1":
|
||||
if (
|
||||
os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() == "true"
|
||||
or os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "0") == "1"
|
||||
):
|
||||
ROCM_USE_FLASH_ATTN_V2_TRITON = True
|
||||
logger.info("ROCm: using Flash Attention 2 Triton implementation.")
|
||||
else:
|
||||
ROCM_USE_FLASH_ATTN_V2_CK = True
|
||||
logger.info("ROCm: using Flash Attention 2 Composable Kernel implementation.")
|
||||
logger.info(
|
||||
"ROCm: using Flash Attention 2 Composable Kernel implementation."
|
||||
)
|
||||
|
||||
try:
|
||||
try:
|
||||
@ -158,6 +163,7 @@ if HAS_FLASH_ATTN_V2_CUDA:
|
||||
)
|
||||
|
||||
elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_CK:
|
||||
|
||||
def attention(
|
||||
q,
|
||||
k,
|
||||
@ -192,6 +198,7 @@ elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_CK:
|
||||
False,
|
||||
None,
|
||||
)
|
||||
|
||||
elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_TRITON:
|
||||
|
||||
def attention(
|
||||
@ -217,6 +224,7 @@ elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_TRITON:
|
||||
softmax_scale,
|
||||
)
|
||||
return output
|
||||
|
||||
elif HAS_FLASH_ATTN:
|
||||
|
||||
def attention(
|
||||
|
@ -46,16 +46,16 @@ def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride):
|
||||
|
||||
@triton.jit
|
||||
def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride):
|
||||
rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n,
|
||||
stride).to(tl.uint32)
|
||||
rng_offsets = dropout_offsets(
|
||||
philox_seed, philox_offset, dropout_p, m, n, stride
|
||||
).to(tl.uint32)
|
||||
# TODO: use tl.randint for better performance
|
||||
return tl.rand(philox_seed, rng_offsets)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride):
|
||||
rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n,
|
||||
stride)
|
||||
rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride)
|
||||
rng_keep = rng_output > dropout_p
|
||||
return rng_keep
|
||||
|
||||
@ -133,9 +133,7 @@ def _attn_fwd_inner(
|
||||
# if not is_modulo_mn. last step might get wasted but that is okay.
|
||||
# check if this masking works for that case.
|
||||
if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0):
|
||||
boundary_m = tl.full([BLOCK_M],
|
||||
actual_seqlen_k,
|
||||
dtype=tl.int32)
|
||||
boundary_m = tl.full([BLOCK_M], actual_seqlen_k, dtype=tl.int32)
|
||||
size_n = start_n + OFFS_N[None, :]
|
||||
mask = size_n < boundary_m[:, None]
|
||||
qk = tl.where(mask, qk, float("-inf"))
|
||||
@ -146,8 +144,9 @@ def _attn_fwd_inner(
|
||||
# -- compute qk ----
|
||||
qk += tl.dot(q, k)
|
||||
if bias_ptr is not None:
|
||||
bias = load_fn(bias_ptr, False, MASK_STEPS
|
||||
and (n_extra_tokens != 0), "zero")
|
||||
bias = load_fn(
|
||||
bias_ptr, False, MASK_STEPS and (n_extra_tokens != 0), "zero"
|
||||
)
|
||||
# While bias is added after multiplying qk with sm_scale, our
|
||||
# optimization to use 2^x instead of e^x results in an additional
|
||||
# scale factor of log2(e) which we must also multiply the bias with.
|
||||
@ -159,9 +158,12 @@ def _attn_fwd_inner(
|
||||
# CAVEAT: Must update l_ij before applying dropout
|
||||
l_ij = tl.sum(p, 1)
|
||||
if ENABLE_DROPOUT:
|
||||
philox_offset = (batch_philox_offset +
|
||||
start_m * BLOCK_M * actual_seqlen_k + start_n -
|
||||
BLOCK_N)
|
||||
philox_offset = (
|
||||
batch_philox_offset
|
||||
+ start_m * BLOCK_M * actual_seqlen_k
|
||||
+ start_n
|
||||
- BLOCK_N
|
||||
)
|
||||
keep = dropout_mask(
|
||||
philox_seed,
|
||||
philox_offset,
|
||||
@ -173,8 +175,7 @@ def _attn_fwd_inner(
|
||||
if RETURN_ENCODED_SOFTMAX:
|
||||
tl.store(
|
||||
encoded_softmax_block_ptr,
|
||||
tl.where(keep, p,
|
||||
-p).to(encoded_softmax_block_ptr.type.element_ty),
|
||||
tl.where(keep, p, -p).to(encoded_softmax_block_ptr.type.element_ty),
|
||||
)
|
||||
p = tl.where(keep, p, 0.0)
|
||||
elif RETURN_ENCODED_SOFTMAX:
|
||||
@ -202,8 +203,9 @@ def _attn_fwd_inner(
|
||||
if bias_ptr is not None:
|
||||
bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N))
|
||||
if RETURN_ENCODED_SOFTMAX:
|
||||
encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr,
|
||||
(0, BLOCK_N))
|
||||
encoded_softmax_block_ptr = tl.advance(
|
||||
encoded_softmax_block_ptr, (0, BLOCK_N)
|
||||
)
|
||||
return acc, l_i, m_i
|
||||
|
||||
|
||||
@ -392,15 +394,17 @@ def attn_fwd(
|
||||
# This captures the decrease in n_blocks if we have a rectangular attn
|
||||
# matrix
|
||||
n_blocks_seqlen = cdiv_fn(
|
||||
(start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N)
|
||||
(start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N
|
||||
)
|
||||
# This is what adjusts the block_max for the current WG, only
|
||||
# if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks
|
||||
n_blocks = min(n_blocks, n_blocks_seqlen)
|
||||
# If we have no blocks after adjusting for seqlen deltas, this WG is
|
||||
# part of the blocks that are all 0. We exit early.
|
||||
if n_blocks <= 0:
|
||||
o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om +
|
||||
off_h_q * stride_oh)
|
||||
o_offset = (
|
||||
off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh
|
||||
)
|
||||
O_block_ptr = tl.make_block_ptr(
|
||||
base=Out + o_offset,
|
||||
shape=(seqlen_q, BLOCK_DMODEL),
|
||||
@ -436,11 +440,10 @@ def attn_fwd(
|
||||
n_extra_tokens = BLOCK_N - seqlen_k
|
||||
elif seqlen_k % BLOCK_N:
|
||||
n_extra_tokens = seqlen_k % BLOCK_N
|
||||
PADDED_HEAD:tl.constexpr = (ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL)
|
||||
PADDED_HEAD: tl.constexpr = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL
|
||||
|
||||
# Compute pointers for all the tensors used in this kernel.
|
||||
q_offset = (off_z * stride_qz + off_h_q * stride_qh +
|
||||
cu_seqlens_q_start * stride_qm)
|
||||
q_offset = off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm
|
||||
Q_block_ptr = tl.make_block_ptr(
|
||||
base=Q + q_offset,
|
||||
shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
|
||||
@ -449,8 +452,7 @@ def attn_fwd(
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
k_offset = (off_z * stride_kz + off_h_k * stride_kh +
|
||||
cu_seqlens_k_start * stride_kn)
|
||||
k_offset = off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn
|
||||
K_block_ptr = tl.make_block_ptr(
|
||||
base=K + k_offset,
|
||||
shape=(ACTUAL_BLOCK_DMODEL, seqlen_k),
|
||||
@ -459,8 +461,7 @@ def attn_fwd(
|
||||
block_shape=(BLOCK_DMODEL, BLOCK_N),
|
||||
order=(0, 1),
|
||||
)
|
||||
v_offset = (off_z * stride_vz + off_h_k * stride_vh +
|
||||
cu_seqlens_k_start * stride_vk)
|
||||
v_offset = off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk
|
||||
V_block_ptr = tl.make_block_ptr(
|
||||
base=V + v_offset,
|
||||
shape=(seqlen_k, ACTUAL_BLOCK_DMODEL),
|
||||
@ -481,9 +482,9 @@ def attn_fwd(
|
||||
else:
|
||||
bias_ptr = None
|
||||
if ENABLE_DROPOUT:
|
||||
batch_philox_offset = philox_offset_base \
|
||||
+ (off_z * HQ + off_h_q) \
|
||||
* seqlen_q * seqlen_k
|
||||
batch_philox_offset = (
|
||||
philox_offset_base + (off_z * HQ + off_h_q) * seqlen_q * seqlen_k
|
||||
)
|
||||
else:
|
||||
batch_philox_offset = 0
|
||||
# We can ask to return the dropout mask without actually doing any dropout.
|
||||
@ -578,8 +579,9 @@ def attn_fwd(
|
||||
if bias_ptr is not None:
|
||||
bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N))
|
||||
if RETURN_ENCODED_SOFTMAX:
|
||||
encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr,
|
||||
(0, n_full_blocks))
|
||||
encoded_softmax_block_ptr = tl.advance(
|
||||
encoded_softmax_block_ptr, (0, n_full_blocks)
|
||||
)
|
||||
acc, l_i, m_i = _attn_fwd_inner(
|
||||
acc,
|
||||
l_i,
|
||||
@ -626,12 +628,11 @@ def attn_fwd(
|
||||
acc = acc.to(Out.type.element_ty)
|
||||
if IS_CAUSAL: # noqa: SIM102
|
||||
if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx:
|
||||
out_mask_boundary = tl.full((BLOCK_DMODEL, ),
|
||||
causal_start_idx,
|
||||
dtype=tl.int32)
|
||||
out_mask_boundary = tl.full(
|
||||
(BLOCK_DMODEL,), causal_start_idx, dtype=tl.int32
|
||||
)
|
||||
mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M)
|
||||
out_ptrs_mask = (mask_m_offsets[:, None] >=
|
||||
out_mask_boundary[None, :])
|
||||
out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :]
|
||||
z = 0.0
|
||||
acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty))
|
||||
# write back LSE
|
||||
@ -649,8 +650,7 @@ def attn_fwd(
|
||||
# tl.store(l_ptrs, m_i + tl.math.log2(l_i))
|
||||
|
||||
# write back O
|
||||
o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om +
|
||||
off_h_q * stride_oh)
|
||||
o_offset = off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh
|
||||
O_block_ptr = tl.make_block_ptr(
|
||||
base=Out + o_offset,
|
||||
shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
|
||||
|
@ -10,7 +10,9 @@ else:
|
||||
from vllm._C import cache_ops
|
||||
from vllm._C import ops
|
||||
except Exception as e:
|
||||
raise ImportError(f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}")
|
||||
raise ImportError(
|
||||
f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
|
||||
)
|
||||
|
||||
|
||||
def reshape_and_cache(
|
||||
|
Loading…
Reference in New Issue
Block a user