This commit is contained in:
Mohit Sharma 2024-08-05 10:12:46 +00:00 committed by AMD
parent 47447ef017
commit 0ad78d20a5
5 changed files with 19 additions and 9 deletions

View File

@ -92,7 +92,7 @@ RUN chmod +x ~/mambaforge.sh && \
# Install flash-attention, torch dependencies
RUN pip install numpy einops ninja --no-cache-dir
RUN conda install intel::mkl-static intel::mkl-include
RUN conda install mkl-static mkl-include
RUN pip uninstall -y triton && \
git clone --depth 1 --single-branch https://github.com/ROCm/triton.git && \
cd triton/python && \

View File

@ -1,5 +1,5 @@
flash_att_v2_commit_cuda := v2.6.1
flash_att_v2_commit_rocm := 2554f490101742ccdc56620a938f847f61754be6
flash_att_v2_commit_rocm := d83c4129a92e4258081f92dfafd34345b3b06130
build-flash-attention-v2-cuda:
pip install -U packaging wheel

View File

@ -1,5 +1,5 @@
commit_cuda := d243e9dc7e2c9c2e36a4150ec8e64809cb55c01b
commit_rocm := c6ee53b1be97e3bbc791b95f22827501297f8921
commit_rocm := c06ccbf90a213688a2c6a85d2e7af3da7bc4b41b
build-vllm-cuda:
if [ ! -d 'vllm' ]; then \
pip install -U ninja packaging --no-cache-dir && \
@ -13,7 +13,7 @@ install-vllm-cuda: build-vllm-cuda
build-vllm-rocm:
if [ ! -d 'vllm' ]; then \
pip install -U ninja packaging --no-cache-dir && \
git clone https://github.com/fxmarty/rocm-vllm.git vllm; \
git clone https://github.com/mht-sharma/vllm.git vllm; \
fi
cd vllm && git fetch && git checkout $(commit_rocm) && \
PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build

View File

@ -14,7 +14,7 @@ use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true",
ENGINE = "triton" if use_triton else "ck"
try:
from vllm._C import cache_ops
import vllm._custom_ops as ops
except Exception as e:
raise ImportError(
f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
@ -33,9 +33,7 @@ def reshape_and_cache(
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
else:
cache_ops.reshape_and_cache(
key, value, key_cache, value_cache, slots, "auto", 1.0
)
ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
def paged_attention(
@ -78,7 +76,7 @@ def paged_attention(
# V1 to avoid the overhead of reduction. Also, if the number of
# sequences or heads is large, we use V1 since there is enough work
# to parallelize.
from vllm._C import ops
import vllm._custom_ops as ops
use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
if use_v1:
@ -180,6 +178,7 @@ if ENGINE == "ck":
softmax_scale,
window_size_left=-1,
causal=True,
softcap=0.0,
):
if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1")
@ -194,12 +193,19 @@ if ENGINE == "ck":
out,
cu_seqlens,
cu_seqlens,
None,
None,
None,
None,
max_s,
max_s,
0.0,
softmax_scale,
False,
causal,
window_size_left,
0,
softcap,
False,
None,
)

View File

@ -313,11 +313,15 @@ class LlamaMLP(nn.Module):
# TODO: This is a hotfix to be removed & properly refactored.
self.quantize = config.quantize
self.hidden_size = config.hidden_size
def forward(self, hidden_states, adapter_data):
if (
SYSTEM == "rocm"
and self.hidden_act == "silu"
and hidden_states.shape[0] == 1
and self.hidden_size
!= 16384 # TODO: Temporary workaround for `LLMM_Silu` kernel not working with LLama3.1 405B; needs refactoring once fixed.
and not self.quantize
):
out = torch.empty(