From a50936061911be155cf498cb933190a675a44b3a Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Tue, 30 Apr 2024 16:17:37 +0000 Subject: [PATCH] trying to update to ROCm 6.1 --- Dockerfile_amd | 172 ++++++++++-------- server/Makefile-triton | 8 - server/Makefile-vllm | 1 - .../models/flash_causal_lm.py | 9 +- .../utils/flash_attn.py | 13 +- .../utils/flash_attn_triton.py | 41 ++--- 6 files changed, 128 insertions(+), 116 deletions(-) delete mode 100644 server/Makefile-triton diff --git a/Dockerfile_amd b/Dockerfile_amd index 987dca56..a261ae2a 100644 --- a/Dockerfile_amd +++ b/Dockerfile_amd @@ -36,7 +36,7 @@ COPY launcher launcher RUN cargo build --release # Text Generation Inference base image for RoCm -FROM rocm/dev-ubuntu-22.04:6.0.2 as base +FROM rocm/dev-ubuntu-22.04:6.1 as base RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ build-essential \ @@ -86,109 +86,125 @@ RUN chmod +x ~/mambaforge.sh && \ mamba init && \ rm ~/mambaforge.sh -RUN pip install torch numpy --index-url https://download.pytorch.org/whl/rocm6.0 +# Install flash-attention, torch dependencies +RUN pip install numpy einops ninja --no-cache-dir -FROM base AS kernel-builder +RUN conda install intel::mkl-static intel::mkl-include +RUN pip uninstall -y triton && \ + git clone --depth 1 --single-branch https://github.com/ROCm/triton.git && \ + cd triton/python && \ + pip install . -# Build Triton -FROM kernel-builder as triton-builder -WORKDIR /usr/src +# RUN git clone --depth 1 --recursive --single-branch --branch 2.3-patched https://github.com/fxmarty/pytorch.git pytorch && cd pytorch && git checkout d05863883b7b61eb5875abcb6cb6b32fa678beeb && pip install -r requirements.txt --no-cache-dir +RUN git clone --depth 1 --recursive --single-branch --branch release/2.3 https://github.com/pytorch/pytorch.git pytorch && cd pytorch && pip install -r requirements.txt --no-cache-dir -COPY server/Makefile-triton Makefile -RUN make build-triton-rocm +ARG _GLIBCXX_USE_CXX11_ABI="1" +ARG CMAKE_PREFIX_PATH="/opt/conda" +ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942" +ARG BUILD_CAFFE2="0" \ + BUILD_CAFFE2_OPS="0" \ + USE_CUDA="0" \ + USE_ROCM="1" \ + BUILD_TEST="0" \ + USE_FBGEMM="0" \ + USE_NNPACK="0" \ + USE_QNNPACK="0" \ + USE_XNNPACK="0" \ + USE_FLASH_ATTENTION="0" \ + USE_MEM_EFF_ATTENTION="0" -# Build vllm kernels -FROM kernel-builder AS vllm-builder -WORKDIR /usr/src +# RUN cd pytorch && python tools/amd_build/build_amd.py && python setup.py install -COPY server/Makefile-vllm Makefile +# FROM base AS kernel-builder -# Build specific version of vllm -RUN make build-vllm-rocm +# # Build vllm kernels +# FROM kernel-builder AS vllm-builder +# WORKDIR /usr/src -# Build Flash Attention v2 kernels -FROM kernel-builder AS flash-att-v2-builder -WORKDIR /usr/src +# COPY server/Makefile-vllm Makefile -COPY server/Makefile-flash-att-v2 Makefile +# # Build specific version of vllm +# RUN make build-vllm-rocm -# Build specific version of flash attention v2 -RUN make build-flash-attention-v2-rocm +# # Build Flash Attention v2 kernels +# FROM kernel-builder AS flash-att-v2-builder +# WORKDIR /usr/src -# Build Transformers CUDA kernels (gpt-neox and bloom) -FROM kernel-builder as custom-kernels-builder -WORKDIR /usr/src -COPY server/custom_kernels/ . 
-RUN PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build +# COPY server/Makefile-flash-att-v2 Makefile -# Build exllama kernels -FROM kernel-builder as exllama-kernels-builder -WORKDIR /usr/src -COPY server/exllama_kernels/ . +# # Build specific version of flash attention v2 +# RUN make build-flash-attention-v2-rocm -RUN PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build +# # Build Transformers CUDA kernels (gpt-neox and bloom) +# FROM kernel-builder as custom-kernels-builder +# WORKDIR /usr/src +# COPY server/custom_kernels/ . +# RUN python setup.py build -# Build exllama v2 kernels -FROM kernel-builder as exllamav2-kernels-builder -WORKDIR /usr/src -COPY server/exllamav2_kernels/ . +# # Build exllama kernels +# FROM kernel-builder as exllama-kernels-builder +# WORKDIR /usr/src +# COPY server/exllama_kernels/ . -RUN PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build +# RUN python setup.py build -FROM base as base-copy +# # Build exllama v2 kernels +# FROM kernel-builder as exllamav2-kernels-builder +# WORKDIR /usr/src +# COPY server/exllamav2_kernels/ . -# Text Generation Inference base env -ENV HUGGINGFACE_HUB_CACHE=/data \ - HF_HUB_ENABLE_HF_TRANSFER=1 \ - PORT=80 \ - HIP_FORCE_DEV_KERNARG=1 +# RUN python setup.py build -# Copy builds artifacts from triton builder -COPY --from=triton-builder /usr/src/triton/python/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages +# FROM base as base-copy -# Copy builds artifacts from vllm builder -COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages +# # Text Generation Inference base env +# ENV HUGGINGFACE_HUB_CACHE=/data \ +# HF_HUB_ENABLE_HF_TRANSFER=1 \ +# PORT=80 \ +# HIP_FORCE_DEV_KERNARG=1 -# Copy build artifacts from flash attention v2 builder -COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages +# # Copy builds artifacts from vllm builder +# COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages -# Copy build artifacts from custom kernels builder -COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages +# # Copy build artifacts from flash attention v2 builder +# COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages -# Copy build artifacts from exllama kernels builder -COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages +# # Copy build artifacts from custom kernels builder +# COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages -# Copy build artifacts from exllamav2 kernels builder -COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages +# # Copy build artifacts from exllama kernels builder +# COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages -# Install flash-attention dependencies -RUN pip install einops --no-cache-dir +# # Copy build artifacts from exllamav2 kernels builder +# COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages -# Install server -COPY proto proto -COPY server server -COPY server/Makefile server/Makefile -RUN cd 
server && \ - make gen-server && \ - pip install -r requirements_rocm.txt && \ - pip install ".[accelerate, peft, outlines]" --no-cache-dir +# # Install server +# COPY proto proto +# COPY server server +# COPY server/Makefile server/Makefile +# # pip install -r requirements_rocm.txt && \ +# #pip install ".[accelerate, peft, outlines]" --no-cache-dir -# Install benchmarker -COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark -# Install router -COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router -# Install launcher -COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher +# # Install benchmarker +# COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark +# # Install router +# COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router +# # Install launcher +# COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher -# AWS Sagemaker compatible image -FROM base-copy as sagemaker -COPY sagemaker-entrypoint.sh entrypoint.sh -RUN chmod +x entrypoint.sh +# RUN cd server && \ +# make gen-server && \ +# pip install -r requirements_rocm.txt -ENTRYPOINT ["./entrypoint.sh"] +# # AWS Sagemaker compatible image +# FROM base-copy as sagemaker +# COPY sagemaker-entrypoint.sh entrypoint.sh +# RUN chmod +x entrypoint.sh -# Final image -FROM base-copy +# ENTRYPOINT ["./entrypoint.sh"] -ENTRYPOINT ["text-generation-launcher"] -CMD ["--json-output"] +# # Final image +# FROM base-copy + +# # ENTRYPOINT ["text-generation-launcher"] +# # CMD ["--json-output"] diff --git a/server/Makefile-triton b/server/Makefile-triton deleted file mode 100644 index 2ab3b69b..00000000 --- a/server/Makefile-triton +++ /dev/null @@ -1,8 +0,0 @@ -triton-rocm: - pip uninstall -y triton - pip install -U ninja cmake wheel packaging --no-cache-dir - git clone https://github.com/ROCm/triton.git triton - -build-triton-rocm: triton-rocm - cd triton && git fetch && git checkout b9e5290de8bf3a79c4e91ceed7e61b3c8d041b30 - cd triton/python && python setup.py build diff --git a/server/Makefile-vllm b/server/Makefile-vllm index cfb659df..af8966e2 100644 --- a/server/Makefile-vllm +++ b/server/Makefile-vllm @@ -18,7 +18,6 @@ vllm-rocm: build-vllm-rocm: vllm-rocm cd vllm && git fetch && git checkout ca6913b3c2ffacdcb7d15e914dc34adbc6c89479 - cd vllm && patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h ./rocm_patch/rocm_bf16.patch cd vllm && PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py install install-vllm-rocm: build-vllm-rocm diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 6e14de81..443feb37 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -768,7 +768,8 @@ class FlashCausalLM(Model): max_s = max_bt * get_cache_manager().block_size if IS_ROCM_SYSTEM and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False): - logger.info("PyTorch TunableOp (https://github.com/pytorch/pytorch/tree/v2.3.0/aten/src/ATen/cuda/tunable) is enabled. 
The warmup may take several minutes.") + torch.cuda.tunable.tuning_enable(False) + _, batch, _ = self.generate_token(batch) except torch.cuda.OutOfMemoryError as e: raise RuntimeError( @@ -824,10 +825,16 @@ class FlashCausalLM(Model): logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).") if IS_ROCM_SYSTEM and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False): + if os.environ.get("PYTORCH_TUNABLEOP_TUNING", "1"): + torch.cuda.tunable.tuning_enable(True) + + logger.info("PyTorch TunableOp (https://github.com/pytorch/pytorch/tree/v2.3.0/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes.") total_seqlens = list(range(2)) for seqlen in total_seqlens: logger.info(f"Warming up TunableOp for seqlen={seqlen}") self.tunableop_warmup(seqlen, max_s, max_bt) + torch.cuda.tunable.write_file() + torch.cuda.tunable.tuning_enable(False) return int(num_blocks * BLOCK_SIZE) diff --git a/server/text_generation_server/utils/flash_attn.py b/server/text_generation_server/utils/flash_attn.py index c4a70ce5..ed128fd3 100644 --- a/server/text_generation_server/utils/flash_attn.py +++ b/server/text_generation_server/utils/flash_attn.py @@ -211,12 +211,13 @@ def attention( # NOTE: The Triton kernel silently outputs wrong results when using MQA/GQA and not # repeating. # TODO: just a sketch. Kind of need to abstract this `attention` function to enable some customization and pass those - let's sync with Nicolas for which implem he'd like - num_heads = q.shape[1] - num_kv_heads = k.shape[1] - if num_kv_heads != num_heads: - # Interleave for MQA workaround. - k = repeat_kv(k, num_heads // num_kv_heads) - v = repeat_kv(v, num_heads // num_kv_heads) + + # num_heads = q.shape[1] + # num_kv_heads = k.shape[1] + # if num_kv_heads != num_heads: + # # Interleave for MQA workaround. + # k = repeat_kv(k, num_heads // num_kv_heads) + # v = repeat_kv(v, num_heads // num_kv_heads) output, _ = triton_attention( q, diff --git a/server/text_generation_server/utils/flash_attn_triton.py b/server/text_generation_server/utils/flash_attn_triton.py index df757be1..8378c06e 100644 --- a/server/text_generation_server/utils/flash_attn_triton.py +++ b/server/text_generation_server/utils/flash_attn_triton.py @@ -293,7 +293,7 @@ def _attn_fwd_inner( num_warps=4, ), ], - key=["hq", "hk", "IS_CAUSAL", "dropout_p", "BLOCK_DMODEL"], + key=["IS_CAUSAL", "dropout_p", "BLOCK_DMODEL"], ) @triton.jit def attn_fwd( @@ -330,8 +330,8 @@ def attn_fwd( philox_seed, philox_offset_base, encoded_softmax, - hq, - hk, + HQ: tl.constexpr, + HK:tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr, MAX_SEQLENS_K: tl.constexpr, @@ -414,14 +414,19 @@ def attn_fwd( # TODO: Should dropout and return encoded softmax be handled here? return - is_mqa = hq != hk - off_h_k = off_h_q % hk if is_mqa else off_h_q + # If MQA / GQA, set the K and V head offsets appropriately. + GROUP_SIZE: tl.constexpr = HQ // HK + if GROUP_SIZE != 1: + off_h_k = off_h_q // GROUP_SIZE + else: + off_h_k = off_h_q + n_extra_tokens = 0 if seqlen_k < BLOCK_N: n_extra_tokens = BLOCK_N - seqlen_k elif seqlen_k % BLOCK_N: n_extra_tokens = seqlen_k % BLOCK_N - padded_head = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL + PADDED_HEAD:tl.constexpr = (ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL) # Compute pointers for all the tensors used in this kernel. 
q_offset = (off_z * stride_qz + off_h_q * stride_qh + @@ -467,7 +472,7 @@ def attn_fwd( bias_ptr = None if ENABLE_DROPOUT: batch_philox_offset = philox_offset_base \ - + (off_z * hq + off_h_q) \ + + (off_z * HQ + off_h_q) \ * seqlen_q * seqlen_k else: batch_philox_offset = 0 @@ -494,7 +499,7 @@ def attn_fwd( # have native e^x support in HW. qk_scale = sm_scale * 1.44269504089 # Q is loaded once at the beginning and shared by all N blocks. - q = load_fn(Q_block_ptr, True, padded_head, "zero") + q = load_fn(Q_block_ptr, True, PADDED_HEAD, "zero") q = (q * qk_scale).to(Q_block_ptr.type.element_ty) # Here we compute how many full and masked blocks we have. @@ -549,7 +554,7 @@ def attn_fwd( False, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, - padded_head, + PADDED_HEAD, ) block_min = block_max block_max = n_blocks * BLOCK_N @@ -595,7 +600,7 @@ def attn_fwd( True, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, - padded_head, + PADDED_HEAD, ) # epilogue acc = acc / l_i[:, None] @@ -729,16 +734,8 @@ class _attention(torch.autograd.Function): o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3)) # Get closest power of 2 over or equal to 32. - unpadded_head_dims = {32, 64, 128} - if head_size not in unpadded_head_dims: - padded_d_model = None - for i in unpadded_head_dims: - if i > head_size: - padded_d_model = i - break - assert padded_d_model is not None - else: - padded_d_model = head_size + padded_d_model = 1 << (head_size - 1).bit_length() + padded_d_model = max(padded_d_model, 16) grid = lambda META: ( triton.cdiv(max_seqlens_q, META["BLOCK_M"]), @@ -781,8 +778,8 @@ class _attention(torch.autograd.Function): philox_seed=philox_seed, philox_offset_base=philox_offset, encoded_softmax=encoded_softmax, - hq=nheads_q, - hk=nheads_k, + HQ=nheads_q, + HK=nheads_k, ACTUAL_BLOCK_DMODEL=head_size, MAX_SEQLENS_Q=max_seqlens_q, MAX_SEQLENS_K=max_seqlens_k,
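
The flash_causal_lm.py hunks above disable TunableOp tuning during the initial memory-probing generate_token call, re-enable it only for the dedicated warmup loop, then persist and freeze the tuned results. A minimal sketch of that flow, assuming the torch.cuda.tunable API from PyTorch 2.3 referenced in the diff; the plain GEMM is only a stand-in for self.tunableop_warmup(seqlen, max_s, max_bt), and the environment checks compare against "1" explicitly here since os.environ.get returns strings:

import os
import torch

if torch.cuda.is_available() and os.environ.get("PYTORCH_TUNABLEOP_ENABLED") == "1":
    # Keep tuning disabled for the first, memory-probing forward pass.
    torch.cuda.tunable.tuning_enable(False)

    # Re-enable tuning only for the explicit warmup loop (opt-out via env var).
    if os.environ.get("PYTORCH_TUNABLEOP_TUNING", "1") == "1":
        torch.cuda.tunable.tuning_enable(True)

    for seqlen in range(2):
        # Stand-in workload; the real code calls self.tunableop_warmup(...).
        a = torch.randn(seqlen + 1, 4096, device="cuda", dtype=torch.float16)
        b = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
        _ = a @ b

    # Persist the tuned GEMM selections, then freeze them for serving.
    torch.cuda.tunable.write_file()
    torch.cuda.tunable.tuning_enable(False)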
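
The flash_attn.py hunk drops the host-side workaround that repeated K and V heads before calling the Triton kernel and instead relies on the kernel-side mapping off_h_k = off_h_q // GROUP_SIZE added to attn_fwd. A small sketch of why the two are equivalent, assuming a [tokens, heads, head_dim] layout as in the surrounding code; repeat_kv below is a stand-in helper, not the exact TGI implementation:

import torch

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # [tokens, num_kv_heads, head_dim] -> [tokens, num_kv_heads * n_rep, head_dim],
    # with the copies of each KV head stored contiguously.
    if n_rep == 1:
        return x
    tokens, num_kv_heads, head_dim = x.shape
    return (
        x[:, :, None, :]
        .expand(tokens, num_kv_heads, n_rep, head_dim)
        .reshape(tokens, num_kv_heads * n_rep, head_dim)
    )

q = torch.randn(8, 32, 128)  # 32 query heads
k = torch.randn(8, 8, 128)   # 8 KV heads (GQA)
group_size = q.shape[1] // k.shape[1]

# Old workaround: materialize the repeated K (and V) on the host.
k_rep = repeat_kv(k, group_size)

# New approach: each query head reads its KV head directly inside the kernel,
# mirroring `off_h_k = off_h_q // GROUP_SIZE` in attn_fwd.
off_h_q = 13
off_h_k = off_h_q // group_size if group_size != 1 else off_h_q
assert torch.equal(k_rep[:, off_h_q], k[:, off_h_k])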
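
The flash_attn_triton.py hunk that computes padded_d_model replaces the fixed {32, 64, 128} head-size lookup with a generic rule: pad the head dimension up to the next power of two, with a floor of 16. A standalone restatement of that rule, mirroring the two lines added to _attention.forward:

def padded_head_dim(head_size: int) -> int:
    padded = 1 << (head_size - 1).bit_length()  # next power of two >= head_size
    return max(padded, 16)                      # floored at 16, matching max(..., 16) in the patch

assert padded_head_dim(64) == 64    # already a power of two
assert padded_head_dim(96) == 128   # padded up
assert padded_head_dim(8) == 16     # floored at 16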