commit 265c76d328
parent 0812e3bdc9
author fxmarty, 2024-05-16 14:46:47 +00:00
33 changed files with 123 additions and 84 deletions

View File

@@ -425,4 +425,4 @@ jobs:
      - name: Run tests
        run: |
          export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
          pytest -s -vv integration-tests

View File

@@ -35,4 +35,4 @@ By default, as its performances have experimentally been better, Triton implemen
The following features are currently not supported in the ROCm version of TGI, and support may be extended in the future:
* Loading [AWQ](https://huggingface.co/docs/transformers/quantization#awq) checkpoints.
* Kernel for sliding window attention (Mistral)

View File

@@ -11,7 +11,7 @@ __device__ __forceinline__ __half __compat_hrcp(__half x) {
__device__ __forceinline__ __half2 __compat_h2rcp(__half2 x) {
    return _Float16_2{
        _Float16_2{static_cast<_Float16>(1.0f),
                   static_cast<_Float16>(1.0f)} / x.data};
}

View File

@@ -8,6 +8,7 @@ if SYSTEM == "rocm":
    except Exception as e:
        raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")


class FastLinear(torch.nn.Module):
    def __init__(
        self,

View File

@@ -47,7 +47,7 @@ class BLOOMSharded(CausalLM):
        trust_remote_code: bool = False,
    ):
        self.model_id = model_id
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
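The sharded model classes touched by this commit all share the pattern shown above: set up the distributed process group, then pin each rank to its own GPU. A minimal sketch of that pattern with plain torch.distributed (TGI's initialize_torch_distributed helper is not shown in this diff; the function name and env-var handling below are assumptions for illustration, and torchrun or a similar launcher is assumed to set RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT, with one GPU per rank on a single node):

import os
import torch
import torch.distributed as dist

def initialize_torch_distributed_sketch():
    rank = int(os.getenv("RANK", "0"))
    world_size = int(os.getenv("WORLD_SIZE", "1"))
    # NCCL for GPU collectives, Gloo as a CPU fallback.
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
    device = (
        torch.device(f"cuda:{rank}")
        if torch.cuda.is_available()
        else torch.device("cpu")
    )
    return dist.group.WORLD, rank, world_size, device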

View File

@@ -197,7 +197,9 @@ class LlamaMLP(nn.Module):
            else lambda x: torch.nn.functional.gelu(
                x,
                approximate=(
-                    "tanh" if self.hidden_act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+                    "tanh"
+                    if self.hidden_act in ["gelu_fast", "gelu_pytorch_tanh"]
+                    else "none"
                ),
            )
        )
@ -229,7 +231,11 @@ class LlamaMLP(nn.Module):
) )
def forward(self, hidden_states): def forward(self, hidden_states):
if SYSTEM == "rocm" and self.hidden_act == "silu" and hidden_states.shape[0] == 1: if (
SYSTEM == "rocm"
and self.hidden_act == "silu"
and hidden_states.shape[0] == 1
):
out = torch.empty( out = torch.empty(
hidden_states.shape[0], hidden_states.shape[0],
self.intermediate_size, self.intermediate_size,
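The reflow above does not change behaviour: the activation still maps "gelu_fast" and "gelu_pytorch_tanh" onto PyTorch's tanh-approximated GELU and everything else onto the exact form. A small standalone illustration of that mapping (the helper name is ours, not TGI's):

import torch
import torch.nn.functional as F

def make_gelu(hidden_act: str):
    # "gelu_fast" and "gelu_pytorch_tanh" select the tanh approximation;
    # any other gelu variant falls back to the exact formulation.
    approximate = (
        "tanh" if hidden_act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
    )
    return lambda x: F.gelu(x, approximate=approximate)

x = torch.randn(4, 8)
assert torch.allclose(make_gelu("gelu")(x), F.gelu(x))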

View File

@@ -266,7 +266,9 @@ class MistralMLP(nn.Module):
            else lambda x: torch.nn.functional.gelu(
                x,
                approximate=(
-                    "tanh" if self.hidden_act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+                    "tanh"
+                    if self.hidden_act in ["gelu_fast", "gelu_pytorch_tanh"]
+                    else "none"
                ),
            )
        )
@@ -289,7 +291,11 @@ class MistralMLP(nn.Module):
        )

    def forward(self, hidden_states):
-        if SYSTEM == "rocm" and self.hidden_act == "silu" and hidden_states.shape[0] == 1:
+        if (
+            SYSTEM == "rocm"
+            and self.hidden_act == "silu"
+            and hidden_states.shape[0] == 1
+        ):
            out = torch.empty(
                hidden_states.shape[0],
                self.intermediate_size,
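For reference, the decode-time branch being reflowed here only applies on ROCm with SiLU and a batch of one; otherwise the MLP computes the usual gated form. A sketch of that generic computation, with weights passed explicitly for illustration (the fused ROCm kernel that writes into the preallocated out buffer is not reproduced here):

import torch
import torch.nn.functional as F

def gated_mlp(hidden_states, gate_up_weight, down_weight):
    # gate_up_weight packs the gate and up projections: [2 * intermediate, hidden].
    gate_up = hidden_states @ gate_up_weight.T
    gate, up = gate_up.chunk(2, dim=-1)
    # SiLU(gate) * up, then project back down to the hidden size.
    return (F.silu(gate) * up) @ down_weight.T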

View File

@@ -64,6 +64,7 @@ elif SYSTEM == "rocm":
else:
    raise RuntimeError(f"Unsupported system {SYSTEM}")


@dataclass
class BaseModelOutputWithPastImage(BaseModelOutputWithPast):
    image_hidden_states: Optional[torch.FloatTensor] = None

View File

@@ -831,28 +831,40 @@ class FlashCausalLM(Model):
            torch.cuda.tunable.tuning_enable(True)

            if os.environ.get("PYTORCH_TUNABLEOP_SEQLENS", False):
-                tuning_sequences = [int(val) for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")]
+                tuning_sequences = [
+                    int(val)
+                    for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")
+                ]
            else:
                tuning_sequences = [1, 2, 4, 8, 16, 32]

-            tunableop_filepath = os.path.join("/data", f"tunableop_{self.model_id.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv")
-            logger.info(f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])} (typical decoding lengths). The picked GEMMs are saved in the file {tunableop_filepath}.")
+            tunableop_filepath = os.path.join(
+                "/data",
+                f"tunableop_{self.model_id.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv",
+            )
+            logger.info(
+                f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])} (typical decoding lengths). The picked GEMMs are saved in the file {tunableop_filepath}."
+            )

            if os.path.isfile(tunableop_filepath):
-                logger.info(f"The file {tunableop_filepath} already exists and will be reused.")
+                logger.info(
+                    f"The file {tunableop_filepath} already exists and will be reused."
+                )
                torch.cuda.tunable.read_file(tunableop_filepath)

            os.makedirs("/data", exist_ok=True)

            for seqlen in tuning_sequences:
                logger.info(f"Warming up TunableOp for seqlen={seqlen}")
                self.tunableop_warmup(seqlen)
            torch.cuda.tunable.write_file(tunableop_filepath)
            torch.cuda.tunable.tuning_enable(False)
        else:
-            logger.info("PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp.")
+            logger.info(
+                "PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp."
+            )

        if CUDA_GRAPHS:
            try:
                logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}")
@@ -877,7 +889,9 @@ class FlashCausalLM(Model):
        self.model.forward(
            input_ids=input_ids,
            position_ids=position_ids,
-            cu_seqlen_prefill=torch.tensor([0, seqlen], device=self.device, dtype=torch.int32),
+            cu_seqlen_prefill=torch.tensor(
+                [0, seqlen], device=self.device, dtype=torch.int32
+            ),
            kv_cache=get_cache_manager().kv_cache,
            block_tables=None,
            input_lengths=None,
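The TunableOp warmup above is driven entirely by environment variables and the torch.cuda.tunable API (available in PyTorch builds that ship TunableOp, 2.3+). A condensed sketch of the same flow outside of TGI; the file path and default sequence lengths are illustrative, and run_forward stands in for a model forward pass at the given sequence length:

import os
import torch

def tunableop_warmup(run_forward, filepath="tunableop.csv"):
    # PYTORCH_TUNABLEOP_SEQLENS overrides the default target sequence lengths.
    if os.environ.get("PYTORCH_TUNABLEOP_SEQLENS"):
        seqlens = [int(v) for v in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")]
    else:
        seqlens = [1, 2, 4, 8, 16, 32]

    torch.cuda.tunable.enable(True)
    torch.cuda.tunable.tuning_enable(True)
    if os.path.isfile(filepath):
        # Reuse previously tuned GEMM selections instead of re-tuning from scratch.
        torch.cuda.tunable.read_file(filepath)
    for seqlen in seqlens:
        run_forward(seqlen)
    torch.cuda.tunable.write_file(filepath)
    torch.cuda.tunable.tuning_enable(False)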

View File

@@ -29,7 +29,7 @@ class FlashCohere(FlashCausalLM):
        trust_remote_code: bool = False,
    ):
        self.model_id = model_id
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")

View File

@@ -31,7 +31,7 @@ class FlashDbrx(FlashCausalLM):
        trust_remote_code: bool = False,
    ):
        self.model_id = model_id
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")

View File

@@ -29,7 +29,7 @@ class FlashGemma(FlashCausalLM):
        trust_remote_code: bool = False,
    ):
        self.model_id = model_id
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")

View File

@@ -32,7 +32,7 @@ class FlashLlama(FlashCausalLM):
        trust_remote_code: bool = False,
    ):
        self.model_id = model_id
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")

View File

@@ -319,7 +319,7 @@ class BaseFlashMistral(FlashCausalLM):
        tokenizer_class=AutoTokenizer,
    ):
        self.model_id = model_id
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")

View File

@@ -30,7 +30,7 @@ class FlashNeoXSharded(FlashCausalLM):
        trust_remote_code: bool = False,
    ):
        self.model_id = model_id
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")

View File

@@ -30,7 +30,7 @@ class FlashPhi(FlashCausalLM):
        trust_remote_code: bool = False,
    ):
        self.model_id = model_id
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")

View File

@@ -35,7 +35,7 @@ class FlashQwen2(BaseFlashMistral):
        trust_remote_code: bool = False,
    ):
        self.model_id = model_id
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")

View File

@@ -31,7 +31,7 @@ class FlashRWSharded(FlashCausalLM):
        trust_remote_code: bool = False,
    ):
        self.model_id = model_id
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")

View File

@@ -34,7 +34,7 @@ class FlashSantacoderSharded(FlashCausalLM):
        trust_remote_code: bool = False,
    ):
        self.model_id = model_id
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")

View File

@@ -34,7 +34,7 @@ class FlashStarcoder2(BaseFlashMistral):
        trust_remote_code: bool = False,
    ):
        self.model_id = model_id
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")

View File

@@ -172,7 +172,7 @@ class GalacticaSharded(CausalLM):
        trust_remote_code: bool = False,
    ):
        self.model_id = model_id
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")

View File

@@ -29,7 +29,7 @@ class GPTNeoxSharded(CausalLM):
        trust_remote_code: bool = False,
    ):
        self.model_id = model_id
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")

View File

@@ -36,7 +36,7 @@ class IDEFICSSharded(IdeficsCausalLM):
        trust_remote_code: bool = False,
    ):
        self.model_id = model_id
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")

View File

@@ -413,7 +413,7 @@ class Mamba(Model):
        trust_remote_code: bool = False,
    ):
        self.model_id = model_id
        self.process_group, _rank, world_size = initialize_torch_distributed()
        if world_size > 1:
            raise RuntimeError("Mamba does not support Tensor Parallelism (TP)")

View File

@@ -48,7 +48,7 @@ class MPTSharded(CausalLM):
        trust_remote_code: bool = False,
    ):
        self.model_id = model_id
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")

View File

@@ -27,7 +27,7 @@ class OPTSharded(CausalLM):
        trust_remote_code: bool = False,
    ):
        self.model_id = model_id
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")

View File

@@ -27,7 +27,7 @@ class Phi(CausalLM):
        trust_remote_code: bool = False,
    ):
        self.model_id = model_id
        self.process_group, _rank, _world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device("cuda")

View File

@@ -24,7 +24,7 @@ class SantaCoder(CausalLM):
        trust_remote_code: bool = False,
    ):
        self.model_id = model_id
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.float16 if dtype is None else dtype

View File

@@ -30,7 +30,7 @@ class T5Sharded(Seq2SeqLM):
        trust_remote_code: bool = False,
    ):
        self.model_id = model_id
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")

View File

@@ -187,6 +187,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
            total_ns=time.time_ns() - start,
        )


def serve(
    model_id: str,
    revision: Optional[str],

View File

@@ -64,12 +64,17 @@ if SYSTEM in {"cuda", "rocm"}:
    is_sm94 = major == 9 and minor == 4

    if SYSTEM == "rocm":
-        if os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() == "true" or os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "0") == "1":
+        if (
+            os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() == "true"
+            or os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "0") == "1"
+        ):
            ROCM_USE_FLASH_ATTN_V2_TRITON = True
            logger.info("ROCm: using Flash Attention 2 Triton implementation.")
        else:
            ROCM_USE_FLASH_ATTN_V2_CK = True
-            logger.info("ROCm: using Flash Attention 2 Composable Kernel implementation.")
+            logger.info(
+                "ROCm: using Flash Attention 2 Composable Kernel implementation."
+            )

    try:
        try:
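The reflowed condition above accepts either "true" or "1" as a way to switch ROCM_USE_FLASH_ATTN_V2_TRITON on. A small helper expressing the same kind of boolean env-var check (the helper itself is illustrative, not part of this diff):

import os

def env_flag(name: str, default: bool = False) -> bool:
    # Treat "1" and "true" (any casing) as enabled, similar to the check above.
    value = os.getenv(name)
    if value is None:
        return default
    return value.strip().lower() in {"1", "true"}

ROCM_USE_FLASH_ATTN_V2_TRITON = env_flag("ROCM_USE_FLASH_ATTN_V2_TRITON")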
@@ -158,6 +163,7 @@ if HAS_FLASH_ATTN_V2_CUDA:
        )

elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_CK:

    def attention(
        q,
        k,
@@ -192,8 +198,9 @@ elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_CK:
            False,
            None,
        )

elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_TRITON:

    def attention(
        q,
        k,
@@ -217,6 +224,7 @@ elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_TRITON:
            softmax_scale,
        )
        return output

elif HAS_FLASH_ATTN:

    def attention(

View File

@@ -46,16 +46,16 @@ def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride):
@triton.jit
def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride):
-    rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n,
-                                  stride).to(tl.uint32)
+    rng_offsets = dropout_offsets(
+        philox_seed, philox_offset, dropout_p, m, n, stride
+    ).to(tl.uint32)
    # TODO: use tl.randint for better performance
    return tl.rand(philox_seed, rng_offsets)


@triton.jit
def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride):
-    rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n,
-                             stride)
+    rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride)
    rng_keep = rng_output > dropout_p
    return rng_keep
@@ -65,9 +65,9 @@ def load_fn(block_ptr, first, second, pad):
    if first and second:
        tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad)
    elif first:
-        tensor = tl.load(block_ptr, boundary_check=(0, ), padding_option=pad)
+        tensor = tl.load(block_ptr, boundary_check=(0,), padding_option=pad)
    elif second:
-        tensor = tl.load(block_ptr, boundary_check=(1, ), padding_option=pad)
+        tensor = tl.load(block_ptr, boundary_check=(1,), padding_option=pad)
    else:
        tensor = tl.load(block_ptr)
    return tensor
@@ -133,9 +133,7 @@ def _attn_fwd_inner(
        # if not is_modulo_mn. last step might get wasted but that is okay.
        # check if this masking works for that case.
        if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0):
-            boundary_m = tl.full([BLOCK_M],
-                                 actual_seqlen_k,
-                                 dtype=tl.int32)
+            boundary_m = tl.full([BLOCK_M], actual_seqlen_k, dtype=tl.int32)
            size_n = start_n + OFFS_N[None, :]
            mask = size_n < boundary_m[:, None]
            qk = tl.where(mask, qk, float("-inf"))
@@ -146,8 +144,9 @@ def _attn_fwd_inner(
        # -- compute qk ----
        qk += tl.dot(q, k)
        if bias_ptr is not None:
-            bias = load_fn(bias_ptr, False, MASK_STEPS
-                           and (n_extra_tokens != 0), "zero")
+            bias = load_fn(
+                bias_ptr, False, MASK_STEPS and (n_extra_tokens != 0), "zero"
+            )
            # While bias is added after multiplying qk with sm_scale, our
            # optimization to use 2^x instead of e^x results in an additional
            # scale factor of log2(e) which we must also multiply the bias with.
@@ -159,9 +158,12 @@ def _attn_fwd_inner(
        # CAVEAT: Must update l_ij before applying dropout
        l_ij = tl.sum(p, 1)
        if ENABLE_DROPOUT:
-            philox_offset = (batch_philox_offset +
-                             start_m * BLOCK_M * actual_seqlen_k + start_n -
-                             BLOCK_N)
+            philox_offset = (
+                batch_philox_offset
+                + start_m * BLOCK_M * actual_seqlen_k
+                + start_n
+                - BLOCK_N
+            )
            keep = dropout_mask(
                philox_seed,
                philox_offset,
@@ -173,8 +175,7 @@ def _attn_fwd_inner(
            if RETURN_ENCODED_SOFTMAX:
                tl.store(
                    encoded_softmax_block_ptr,
-                    tl.where(keep, p,
-                             -p).to(encoded_softmax_block_ptr.type.element_ty),
+                    tl.where(keep, p, -p).to(encoded_softmax_block_ptr.type.element_ty),
                )
            p = tl.where(keep, p, 0.0)
        elif RETURN_ENCODED_SOFTMAX:
@@ -202,8 +203,9 @@ def _attn_fwd_inner(
        if bias_ptr is not None:
            bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N))
        if RETURN_ENCODED_SOFTMAX:
-            encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr,
-                                                   (0, BLOCK_N))
+            encoded_softmax_block_ptr = tl.advance(
+                encoded_softmax_block_ptr, (0, BLOCK_N)
+            )
    return acc, l_i, m_i
@@ -341,7 +343,7 @@ def attn_fwd(
    philox_offset_base,
    encoded_softmax,
    HQ: tl.constexpr,
-    HK:tl.constexpr,
+    HK: tl.constexpr,
    ACTUAL_BLOCK_DMODEL: tl.constexpr,
    MAX_SEQLENS_Q: tl.constexpr,
    MAX_SEQLENS_K: tl.constexpr,
@@ -392,15 +394,17 @@ def attn_fwd(
        # This captures the decrease in n_blocks if we have a rectangular attn
        # matrix
        n_blocks_seqlen = cdiv_fn(
-            (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N)
+            (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N
+        )
        # This is what adjusts the block_max for the current WG, only
        # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks
        n_blocks = min(n_blocks, n_blocks_seqlen)
        # If we have no blocks after adjusting for seqlen deltas, this WG is
        # part of the blocks that are all 0. We exit early.
        if n_blocks <= 0:
-            o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om +
-                        off_h_q * stride_oh)
+            o_offset = (
+                off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh
+            )
            O_block_ptr = tl.make_block_ptr(
                base=Out + o_offset,
                shape=(seqlen_q, BLOCK_DMODEL),
@@ -436,11 +440,10 @@ def attn_fwd(
        n_extra_tokens = BLOCK_N - seqlen_k
    elif seqlen_k % BLOCK_N:
        n_extra_tokens = seqlen_k % BLOCK_N
-    PADDED_HEAD:tl.constexpr = (ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL)
+    PADDED_HEAD: tl.constexpr = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL
    # Compute pointers for all the tensors used in this kernel.
-    q_offset = (off_z * stride_qz + off_h_q * stride_qh +
-                cu_seqlens_q_start * stride_qm)
+    q_offset = off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm
    Q_block_ptr = tl.make_block_ptr(
        base=Q + q_offset,
        shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
@@ -449,8 +452,7 @@ def attn_fwd(
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0),
    )
-    k_offset = (off_z * stride_kz + off_h_k * stride_kh +
-                cu_seqlens_k_start * stride_kn)
+    k_offset = off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn
    K_block_ptr = tl.make_block_ptr(
        base=K + k_offset,
        shape=(ACTUAL_BLOCK_DMODEL, seqlen_k),
@@ -459,8 +461,7 @@ def attn_fwd(
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1),
    )
-    v_offset = (off_z * stride_vz + off_h_k * stride_vh +
-                cu_seqlens_k_start * stride_vk)
+    v_offset = off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk
    V_block_ptr = tl.make_block_ptr(
        base=V + v_offset,
        shape=(seqlen_k, ACTUAL_BLOCK_DMODEL),
@@ -481,9 +482,9 @@ def attn_fwd(
    else:
        bias_ptr = None
    if ENABLE_DROPOUT:
-        batch_philox_offset = philox_offset_base \
-            + (off_z * HQ + off_h_q) \
-            * seqlen_q * seqlen_k
+        batch_philox_offset = (
+            philox_offset_base + (off_z * HQ + off_h_q) * seqlen_q * seqlen_k
+        )
    else:
        batch_philox_offset = 0
    # We can ask to return the dropout mask without actually doing any dropout.
@@ -578,8 +579,9 @@ def attn_fwd(
        if bias_ptr is not None:
            bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N))
        if RETURN_ENCODED_SOFTMAX:
-            encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr,
-                                                   (0, n_full_blocks))
+            encoded_softmax_block_ptr = tl.advance(
+                encoded_softmax_block_ptr, (0, n_full_blocks)
+            )
        acc, l_i, m_i = _attn_fwd_inner(
            acc,
            l_i,
@@ -626,12 +628,11 @@ def attn_fwd(
    acc = acc.to(Out.type.element_ty)
    if IS_CAUSAL:  # noqa: SIM102
        if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx:
-            out_mask_boundary = tl.full((BLOCK_DMODEL, ),
-                                        causal_start_idx,
-                                        dtype=tl.int32)
+            out_mask_boundary = tl.full(
+                (BLOCK_DMODEL,), causal_start_idx, dtype=tl.int32
+            )
            mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M)
-            out_ptrs_mask = (mask_m_offsets[:, None] >=
-                             out_mask_boundary[None, :])
+            out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :]
            z = 0.0
            acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty))
    # write back LSE
@@ -649,8 +650,7 @@ def attn_fwd(
    # tl.store(l_ptrs, m_i + tl.math.log2(l_i))
    # write back O
-    o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om +
-                off_h_q * stride_oh)
+    o_offset = off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh
    O_block_ptr = tl.make_block_ptr(
        base=Out + o_offset,
        shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
@@ -813,4 +813,4 @@ class _attention(torch.autograd.Function):
        return o, encoded_softmax


triton_attention = _attention.apply
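As the last hunk shows, the Triton kernel is exposed through a torch.autograd.Function whose apply is re-exported as triton_attention. A minimal, self-contained sketch of that wrapper pattern, with a trivial identity forward standing in for the kernel launch:

import torch

class _identity(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # A real implementation would launch attn_fwd here and stash tensors for backward.
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output

identity = _identity.apply
y = identity(torch.randn(2, 3, requires_grad=True))
y.sum().backward()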

View File

@@ -10,7 +10,9 @@ else:
        from vllm._C import cache_ops
        from vllm._C import ops
    except Exception as e:
-        raise ImportError(f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}")
+        raise ImportError(
+            f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
+        )


def reshape_and_cache(
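Both this hunk and the earlier vllm._custom_C one guard an optional compiled extension behind a try/except that re-raises with a clearer message. A generic sketch of that pattern (the helper name is ours, not part of the diff):

import importlib

def require_extension(module_name: str):
    # Import an optional compiled extension, preserving the original failure details.
    try:
        return importlib.import_module(module_name)
    except Exception as e:
        raise ImportError(
            f"Could not import `{module_name}`. Make sure your installation is correct. "
            f"Complete error: {e}"
        ) from e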