This commit is contained in:
Mohit Sharma 2024-08-05 10:12:46 +00:00 committed by AMD
parent 47447ef017
commit 0ad78d20a5
5 changed files with 19 additions and 9 deletions

View File

@@ -92,7 +92,7 @@ RUN chmod +x ~/mambaforge.sh && \
 # Install flash-attention, torch dependencies
 RUN pip install numpy einops ninja --no-cache-dir
-RUN conda install intel::mkl-static intel::mkl-include
+RUN conda install mkl-static mkl-include
 RUN pip uninstall -y triton && \
     git clone --depth 1 --single-branch https://github.com/ROCm/triton.git && \
     cd triton/python && \

View File

@@ -1,5 +1,5 @@
 flash_att_v2_commit_cuda := v2.6.1
-flash_att_v2_commit_rocm := 2554f490101742ccdc56620a938f847f61754be6
+flash_att_v2_commit_rocm := d83c4129a92e4258081f92dfafd34345b3b06130

 build-flash-attention-v2-cuda:
     pip install -U packaging wheel

View File

@@ -1,5 +1,5 @@
 commit_cuda := d243e9dc7e2c9c2e36a4150ec8e64809cb55c01b
-commit_rocm := c6ee53b1be97e3bbc791b95f22827501297f8921
+commit_rocm := c06ccbf90a213688a2c6a85d2e7af3da7bc4b41b
 build-vllm-cuda:
     if [ ! -d 'vllm' ]; then \
         pip install -U ninja packaging --no-cache-dir && \
@@ -13,7 +13,7 @@ install-vllm-cuda: build-vllm-cuda
 build-vllm-rocm:
     if [ ! -d 'vllm' ]; then \
         pip install -U ninja packaging --no-cache-dir && \
-        git clone https://github.com/fxmarty/rocm-vllm.git vllm; \
+        git clone https://github.com/mht-sharma/vllm.git vllm; \
     fi
     cd vllm && git fetch && git checkout $(commit_rocm) && \
         PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build
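
Note: the ROCm build above pins PYTORCH_ROCM_ARCH to gfx90a (MI200 series) and gfx942 (MI300 series). Below is a small sketch, not part of this commit, for checking that the local GPU is covered by that target list; it assumes a ROCm build of PyTorch recent enough to expose `gcnArchName` on the device properties.

import torch

# Sketch (not part of this commit): confirm the local GPU matches one of the
# architectures targeted by PYTORCH_ROCM_ARCH above. Assumes a ROCm build of
# PyTorch that exposes `gcnArchName` on the device properties.
TARGET_ARCHS = ("gfx90a", "gfx942")  # MI200 series, MI300 series

if torch.version.hip is None:
    print("Not a ROCm build of PyTorch; the CUDA targets apply instead.")
else:
    arch = torch.cuda.get_device_properties(0).gcnArchName
    supported = any(target in arch for target in TARGET_ARCHS)
    print(f"GPU arch {arch!r}; covered by this build: {supported}")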

View File

@@ -14,7 +14,7 @@ use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true",
 ENGINE = "triton" if use_triton else "ck"

 try:
-    from vllm._C import cache_ops
+    import vllm._custom_ops as ops
 except Exception as e:
     raise ImportError(
         f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
@@ -33,9 +33,7 @@
         key_cache.view(-1, shape[-2], shape[-1])[slots] = key
         value_cache.view(-1, shape[-2], shape[-1])[slots] = value
     else:
-        cache_ops.reshape_and_cache(
-            key, value, key_cache, value_cache, slots, "auto", 1.0
-        )
+        ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)


 def paged_attention(
@@ -78,7 +76,7 @@ def paged_attention(
     # V1 to avoid the overhead of reduction. Also, if the number of
     # sequences or heads is large, we use V1 since there is enough work
     # to parallelize.
-    from vllm._C import ops
+    import vllm._custom_ops as ops

     use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
     if use_v1:
@@ -180,6 +178,7 @@ if ENGINE == "ck":
         softmax_scale,
         window_size_left=-1,
         causal=True,
+        softcap=0.0,
     ):
         if window_size_left <= 0 and window_size_left != -1:
             raise ValueError("`window_size_left` must be > 0 or -1")
@@ -194,12 +193,19 @@ if ENGINE == "ck":
             out,
             cu_seqlens,
             cu_seqlens,
+            None,
+            None,
+            None,
+            None,
             max_s,
             max_s,
             0.0,
             softmax_scale,
             False,
             causal,
+            window_size_left,
+            0,
+            softcap,
             False,
             None,
         )
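
Note: the new `softcap` argument threaded into the CK `varlen_fwd` call above defaults to 0.0, i.e. disabled. As a reference for what a positive value does, here is a minimal plain-PyTorch sketch of logit soft-capping; the helper name is illustrative, and this shows the reference semantics rather than the fused kernel.

import torch

def softcap_logits(scores: torch.Tensor, softcap: float) -> torch.Tensor:
    # Reference semantics of `softcap` (illustrative helper, not part of the
    # commit): 0.0 disables capping; a positive value squashes the attention
    # logits into (-softcap, softcap) via tanh before the softmax is applied.
    if softcap == 0.0:
        return scores
    return softcap * torch.tanh(scores / softcap)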

View File

@@ -313,11 +313,15 @@ class LlamaMLP(nn.Module):
         # TODO: This is a hotfix to be removed & properly refactored.
         self.quantize = config.quantize

+        self.hidden_size = config.hidden_size
+
     def forward(self, hidden_states, adapter_data):
         if (
             SYSTEM == "rocm"
             and self.hidden_act == "silu"
             and hidden_states.shape[0] == 1
+            and self.hidden_size
+            != 16384  # TODO: Temporary workaround for `LLMM_Silu` kernel not working with LLama3.1 405B; needs refactoring once fixed.
             and not self.quantize
         ):
             out = torch.empty(
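
Note: the added `hidden_size != 16384` guard keeps Llama 3.1 405B off the fused `LLMM_Silu` fast path, so those calls fall through to the regular gated-SiLU MLP. Below is a minimal sketch of that unfused path, assuming the usual packed gate/up projection layout in `LlamaMLP`; the helper and parameter names are illustrative.

import torch
import torch.nn.functional as F

def gated_silu_mlp(hidden_states, gate_up_proj, down_proj, intermediate_size):
    # Unfused fallback taken when the LLMM_Silu kernel is skipped (e.g. the
    # hidden_size == 16384 case excluded above). Assumes gate_up_proj packs
    # the gate and up projections side by side, gate first.
    gate_up = gate_up_proj(hidden_states)
    gate, up = gate_up.split(intermediate_size, dim=-1)
    return down_proj(F.silu(gate) * up)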