From 0ad78d20a57d87a12cec9dcad2f6ff8dea1895c2 Mon Sep 17 00:00:00 2001
From: Mohit Sharma
Date: Mon, 5 Aug 2024 10:12:46 +0000
Subject: [PATCH] style

---
 Dockerfile_amd                              |  2 +-
 server/Makefile-flash-att-v2                |  2 +-
 server/Makefile-vllm                        |  4 ++--
 .../layers/attention/rocm.py                | 16 +++++++++++-----
 .../custom_modeling/flash_llama_modeling.py |  4 ++++
 5 files changed, 19 insertions(+), 9 deletions(-)

diff --git a/Dockerfile_amd b/Dockerfile_amd
index 51231638..514891a8 100644
--- a/Dockerfile_amd
+++ b/Dockerfile_amd
@@ -92,7 +92,7 @@ RUN chmod +x ~/mambaforge.sh && \
 # Install flash-attention, torch dependencies
 RUN pip install numpy einops ninja --no-cache-dir
 
-RUN conda install intel::mkl-static intel::mkl-include
+RUN conda install mkl-static mkl-include
 RUN pip uninstall -y triton && \
     git clone --depth 1 --single-branch https://github.com/ROCm/triton.git && \
     cd triton/python && \
diff --git a/server/Makefile-flash-att-v2 b/server/Makefile-flash-att-v2
index dbddd0f4..03527329 100644
--- a/server/Makefile-flash-att-v2
+++ b/server/Makefile-flash-att-v2
@@ -1,5 +1,5 @@
 flash_att_v2_commit_cuda := v2.6.1
-flash_att_v2_commit_rocm := 2554f490101742ccdc56620a938f847f61754be6
+flash_att_v2_commit_rocm := d83c4129a92e4258081f92dfafd34345b3b06130
 
 build-flash-attention-v2-cuda:
 	pip install -U packaging wheel
diff --git a/server/Makefile-vllm b/server/Makefile-vllm
index f1f80529..bf4a1498 100644
--- a/server/Makefile-vllm
+++ b/server/Makefile-vllm
@@ -1,5 +1,5 @@
 commit_cuda := d243e9dc7e2c9c2e36a4150ec8e64809cb55c01b
-commit_rocm := c6ee53b1be97e3bbc791b95f22827501297f8921
+commit_rocm := c06ccbf90a213688a2c6a85d2e7af3da7bc4b41b
 build-vllm-cuda:
 	if [ ! -d 'vllm' ]; then \
 		pip install -U ninja packaging --no-cache-dir && \
@@ -13,7 +13,7 @@ install-vllm-cuda: build-vllm-cuda
 build-vllm-rocm:
 	if [ ! -d 'vllm' ]; then \
 		pip install -U ninja packaging --no-cache-dir && \
-		git clone https://github.com/fxmarty/rocm-vllm.git vllm; \
+		git clone https://github.com/mht-sharma/vllm.git vllm; \
 	fi
 	cd vllm && git fetch && git checkout $(commit_rocm) && \
 	PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build
diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py
index 69e64162..77ba4c92 100644
--- a/server/text_generation_server/layers/attention/rocm.py
+++ b/server/text_generation_server/layers/attention/rocm.py
@@ -14,7 +14,7 @@ use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true",
 ENGINE = "triton" if use_triton else "ck"
 
 try:
-    from vllm._C import cache_ops
+    import vllm._custom_ops as ops
 except Exception as e:
     raise ImportError(
         f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
@@ -33,9 +33,7 @@ def reshape_and_cache(
         key_cache.view(-1, shape[-2], shape[-1])[slots] = key
         value_cache.view(-1, shape[-2], shape[-1])[slots] = value
     else:
-        cache_ops.reshape_and_cache(
-            key, value, key_cache, value_cache, slots, "auto", 1.0
-        )
+        ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
 
 
 def paged_attention(
@@ -78,7 +76,7 @@
     # V1 to avoid the overhead of reduction. Also, if the number of
     # sequences or heads is large, we use V1 since there is enough work
     # to parallelize.
-    from vllm._C import ops
+    import vllm._custom_ops as ops
 
     use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
     if use_v1:
@@ -180,6 +178,7 @@ if ENGINE == "ck":
         softmax_scale,
         window_size_left=-1,
         causal=True,
+        softcap=0.0,
     ):
         if window_size_left <= 0 and window_size_left != -1:
             raise ValueError("`window_size_left` must be > 0 or -1")
@@ -194,12 +193,19 @@ if ENGINE == "ck":
             out,
             cu_seqlens,
             cu_seqlens,
+            None,
+            None,
+            None,
+            None,
             max_s,
             max_s,
             0.0,
             softmax_scale,
             False,
             causal,
+            window_size_left,
+            0,
+            softcap,
             False,
             None,
         )
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 9ea19a87..56d88956 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -313,11 +313,15 @@ class LlamaMLP(nn.Module):
         # TODO: This is a hotfix to be removed & properly refactored.
         self.quantize = config.quantize
 
+        self.hidden_size = config.hidden_size
+
     def forward(self, hidden_states, adapter_data):
         if (
             SYSTEM == "rocm"
             and self.hidden_act == "silu"
             and hidden_states.shape[0] == 1
+            and self.hidden_size
+            != 16384  # TODO: Temporary workaround for `LLMM_Silu` kernel not working with Llama 3.1 405B; needs refactoring once fixed.
             and not self.quantize
         ):
             out = torch.empty(
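
Note on the `vllm._C` -> `vllm._custom_ops` migration above: the argument order
of `reshape_and_cache` is unchanged, so both call sites in the patch only swap
the module path. A minimal sketch of the migrated call, assuming a ROCm vLLM
build that ships the `vllm._custom_ops` wrapper; the tensor shapes and the
paged KV-cache layout in the comments are assumptions based on vLLM conventions
of this era, not something this patch states:

    import torch
    import vllm._custom_ops as ops  # replaces `from vllm._C import cache_ops`

    num_tokens, num_heads, head_size = 4, 8, 128
    num_blocks, block_size = 16, 16

    # Freshly computed K/V for the new tokens.
    key = torch.randn(num_tokens, num_heads, head_size, device="cuda", dtype=torch.float16)
    value = torch.randn_like(key)

    # Assumed paged layouts: keys as (num_blocks, num_heads, head_size // x, block_size, x),
    # values as (num_blocks, num_heads, head_size, block_size).
    x = 16 // key.element_size()
    key_cache = torch.zeros(num_blocks, num_heads, head_size // x, block_size, x, device="cuda", dtype=torch.float16)
    value_cache = torch.zeros(num_blocks, num_heads, head_size, block_size, device="cuda", dtype=torch.float16)

    # Flat slot index in the cache for each new token; "auto" kv-cache dtype and
    # a 1.0 scale (i.e. no fp8 rescaling), exactly as passed in the patch.
    slots = torch.arange(num_tokens, device="cuda", dtype=torch.long)
    ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)

The extra arguments threaded into `flash_attn_2_cuda.varlen_fwd` (the `None`
placeholders, `window_size_left`, `0`, `softcap`) track the longer signature of
the newer ROCm flash-attention build pinned above; `softcap=0.0` keeps logit
softcapping disabled by default, so existing callers behave as before.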