From 3016e1595f329616a912444aafcc2109b3e1e875 Mon Sep 17 00:00:00 2001
From: fxmarty <9808326+fxmarty@users.noreply.github.com>
Date: Thu, 18 Apr 2024 23:31:28 +0000
Subject: [PATCH] at last working!

---
 Dockerfile_amd                                | 43 ++++++++++++-------
 server/Makefile-flash-att-v2                  |  6 +--
 server/Makefile-vllm                          |  7 +--
 .../exllama_kernels/hip_compat.cuh            |  5 ++-
 .../custom_modeling/flash_cohere_modeling.py  |  4 +-
 .../custom_modeling/idefics_modeling.py       |  4 +-
 .../utils/flash_attn.py                       |  8 +++-
 server/text_generation_server/utils/layers.py | 13 +++---
 8 files changed, 55 insertions(+), 35 deletions(-)

diff --git a/Dockerfile_amd b/Dockerfile_amd
index fb820116..7d7e1913 100644
--- a/Dockerfile_amd
+++ b/Dockerfile_amd
@@ -36,7 +36,7 @@ COPY launcher launcher
 RUN cargo build --release
 
 # Text Generation Inference base image for RoCm
-FROM rocm/dev-ubuntu-22.04:5.7 as base
+FROM rocm/dev-ubuntu-22.04:6.0.2 as base
 
 RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
     build-essential \
@@ -50,13 +50,24 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
     # Needed to build VLLM & flash.
     rocthrust-dev \
     hipsparse-dev \
-    hipblas-dev && \
+    hipblas-dev \
+    hipblaslt-dev \
+    rocblas-dev \
+    hiprand-dev \
+    rocrand-dev \
+    miopen-hip-dev \
+    hipfft-dev \
+    hipcub-dev \
+    hipsolver-dev \
+    rccl-dev \
+    cmake \
+    python3-dev && \
     rm -rf /var/lib/apt/lists/*
 
 # Keep in sync with `server/pyproject.toml
 ARG MAMBA_VERSION=23.1.0-1
-ARG PYTORCH_VERSION='2.2.0.dev0'
-ARG ROCM_VERSION='5.7'
+ARG PYTORCH_VERSION='2.3.0'
+ARG ROCM_VERSION='6.0.2'
 ARG PYTHON_VERSION='3.10.10'
 # Automatically set by buildx
 ARG TARGETPLATFORM
@@ -75,8 +86,8 @@ RUN chmod +x ~/mambaforge.sh && \
     mamba init && \
     rm ~/mambaforge.sh
 
-# Install PyTorch 2.2 RC compiled against RoCm 5.7, as VLLM can not be compiled with RoCm 5.6.
-RUN pip install torch --index-url https://download.pytorch.org/whl/test/rocm5.7/
+# Install PyTorch 2.3 RC compiled against RoCm 6.0
+RUN pip install torch numpy --index-url https://download.pytorch.org/whl/test/rocm6.0
 
 FROM base AS kernel-builder
 
@@ -102,21 +113,21 @@ RUN make build-flash-attention-v2-rocm
 
 FROM kernel-builder as custom-kernels-builder
 WORKDIR /usr/src
 COPY server/custom_kernels/ .
-RUN PYTORCH_ROCM_ARCH=gfx90a python setup.py build
+RUN PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build
 
 # Build exllama kernels
 FROM kernel-builder as exllama-kernels-builder
 WORKDIR /usr/src
 COPY server/exllama_kernels/ .
-RUN PYTORCH_ROCM_ARCH="gfx90a" python setup.py build
+RUN PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build
 
 # Build exllama v2 kernels
 FROM kernel-builder as exllamav2-kernels-builder
 WORKDIR /usr/src
 COPY server/exllamav2_kernels/ .
-RUN PYTORCH_ROCM_ARCH="gfx90a" python setup.py build
+RUN PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build
 
 FROM base as base-copy
 
@@ -147,10 +158,8 @@ RUN pip install einops --no-cache-dir
 COPY proto proto
 COPY server server
 COPY server/Makefile server/Makefile
-RUN cd server && \
-    make gen-server && \
-    pip install -r requirements_rocm.txt && \
-    pip install ".[accelerate, peft, outlines]" --no-cache-dir
+    # pip install -r requirements_rocm.txt && \
+    #pip install ".[accelerate, peft, outlines]" --no-cache-dir
 
 # Install benchmarker
 COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark
@@ -159,6 +168,10 @@ COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bi
 # Install launcher
 COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher
 
+RUN cd server && \
+    make gen-server && \
+    pip install -r requirements_rocm.txt
+
 # AWS Sagemaker compatible image
 FROM base-copy as sagemaker
 COPY sagemaker-entrypoint.sh entrypoint.sh
@@ -169,5 +182,5 @@ ENTRYPOINT ["./entrypoint.sh"]
 # Final image
 FROM base-copy
 
-ENTRYPOINT ["text-generation-launcher"]
-CMD ["--json-output"]
+# ENTRYPOINT ["text-generation-launcher"]
+# CMD ["--json-output"]
diff --git a/server/Makefile-flash-att-v2 b/server/Makefile-flash-att-v2
index 803b3d1f..36ef576a 100644
--- a/server/Makefile-flash-att-v2
+++ b/server/Makefile-flash-att-v2
@@ -1,5 +1,5 @@
 flash_att_v2_commit_cuda := 23e8fa5a263d1c7122bc46a86ef32030ee7130f9
-flash_att_v2_commit_rocm := 8736558c287ff2ef28b24878e42828c595ac3e69
+flash_att_v2_commit_rocm := 2554f490101742ccdc56620a938f847f61754be6
 
 
 flash-attention-v2-cuda:
@@ -18,12 +18,12 @@ install-flash-attention-v2-cuda: build-flash-attention-v2-cuda
 flash-attention-v2-rocm:
 	# Clone flash attention
 	pip install -U packaging ninja --no-cache-dir
-	git clone https://github.com/fxmarty/flash-attention-rocm flash-attention-v2
+	git clone https://github.com/ROCm/flash-attention.git flash-attention-v2
 
 build-flash-attention-v2-rocm: flash-attention-v2-rocm
 	cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit_rocm)
 	cd flash-attention-v2 && git submodule update --init --recursive
-	cd flash-attention-v2 && PYTORCH_ROCM_ARCH=gfx90a python setup.py build
+	cd flash-attention-v2 && GPU_ARCHS="gfx90a;gfx942" PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build
 
 install-flash-attention-v2-rocm: build-flash-attention-v2-rocm
 	cd flash-attention-v2 && git submodule update --init --recursive && python setup.py install
diff --git a/server/Makefile-vllm b/server/Makefile-vllm
index ada484a6..cfb659df 100644
--- a/server/Makefile-vllm
+++ b/server/Makefile-vllm
@@ -14,11 +14,12 @@ install-vllm-cuda: build-vllm-cuda
 vllm-rocm:
 	# Clone vllm
 	pip install -U ninja packaging --no-cache-dir
-	git clone https://github.com/fxmarty/vllm-public.git vllm
+	git clone https://github.com/fxmarty/rocm-vllm.git vllm
 
 build-vllm-rocm: vllm-rocm
-	cd vllm && git fetch && git checkout ad9b7c4095ef54419a0533d254f2ad84bd2dfcae
-	cd vllm && python setup.py build
+	cd vllm && git fetch && git checkout ca6913b3c2ffacdcb7d15e914dc34adbc6c89479
+	cd vllm && patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h ./rocm_patch/rocm_bf16.patch
+	cd vllm && PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py install
 
 install-vllm-rocm: build-vllm-rocm
 	pip uninstall vllm -y || true
diff --git a/server/exllama_kernels/exllama_kernels/hip_compat.cuh b/server/exllama_kernels/exllama_kernels/hip_compat.cuh
index 5e698b1a..d8cbcc49 100644
--- a/server/exllama_kernels/exllama_kernels/hip_compat.cuh
+++ b/server/exllama_kernels/exllama_kernels/hip_compat.cuh
@@ -10,8 +10,9 @@ __device__ __forceinline__ __half __compat_hrcp(__half x) {
 }
 
 __device__ __forceinline__ __half2 __compat_h2rcp(__half2 x) {
-    return _Float16_2{static_cast<_Float16>(__builtin_amdgcn_rcph(x.x)),
-                      static_cast<_Float16>(__builtin_amdgcn_rcph(x.y))};
+    return _Float16_2{
+        _Float16_2{static_cast<_Float16>(1.0f),
+                   static_cast<_Float16>(1.0f)} / x.data};
 }
 
 #define hrcp __compat_hrcp
diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py
index 56d9a966..c6e55fd7 100644
--- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py
@@ -65,7 +65,7 @@ class CohereRotary(PositionRotaryEmbedding):
             rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
         elif IS_ROCM_SYSTEM:
-            from vllm import pos_encoding_ops
+            from vllm._C import ops
 
             # NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems.
             # Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773
@@ -73,7 +73,7 @@ class CohereRotary(PositionRotaryEmbedding):
             head_size = query.shape[-1]
 
             # Inplace operation, updating query and key.
-            pos_encoding_ops.rotary_embedding(query, key, head_size, cos, sin, False)
+            ops.rotary_embedding(query, key, head_size, cos, sin, False)
         else:
             raise ValueError(
                 "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
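
Note: the vLLM import changes in flash_cohere_modeling.py above and in idefics_modeling.py and layers.py below all follow the same rename: the vLLM build pinned in Makefile-vllm ships its custom kernels as a single compiled extension, vllm._C, whose ops namespace replaces the older vllm.pos_encoding_ops and vllm.layernorm_ops modules. The sketch below is illustrative only and is not part of this patch; the adapter name vllm_ops is invented, and it simply mirrors the call sites touched by these hunks so the same code can run against either layout.

    # Hypothetical compatibility shim (not part of this patch). Newer vLLM builds expose
    # the kernels under vllm._C.ops; older ones under vllm.pos_encoding_ops and
    # vllm.layernorm_ops.
    try:
        from vllm._C import ops as vllm_ops  # layout assumed by this patch
    except ImportError:
        from vllm import layernorm_ops, pos_encoding_ops  # pre-rename layout

        class vllm_ops:  # minimal adapter mirroring the new namespace
            rms_norm = staticmethod(layernorm_ops.rms_norm)
            rotary_embedding = staticmethod(pos_encoding_ops.rotary_embedding)

    # Call sites then look the same on both layouts, matching the hunks in this patch:
    #   vllm_ops.rotary_embedding(query, key, head_size, cos, sin, True)
    #   vllm_ops.rms_norm(out, hidden_states, weight, epsilon)
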
diff --git a/server/text_generation_server/models/custom_modeling/idefics_modeling.py b/server/text_generation_server/models/custom_modeling/idefics_modeling.py
index ee4cdb08..39addd27 100644
--- a/server/text_generation_server/models/custom_modeling/idefics_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/idefics_modeling.py
@@ -60,7 +60,7 @@ from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SY
 if IS_CUDA_SYSTEM:
     import dropout_layer_norm
 elif IS_ROCM_SYSTEM:
-    from vllm import layernorm_ops
+    from vllm._C import ops
 
 
 @dataclass
@@ -418,7 +418,7 @@ class IdeficsRMSNorm(nn.Module):
             hidden_states = hidden_states.reshape(-1, shape[-1])
             out = torch.empty_like(hidden_states)
 
-            layernorm_ops.rms_norm(
+            ops.rms_norm(
                 out,
                 hidden_states,
                 self.weight.data,
diff --git a/server/text_generation_server/utils/flash_attn.py b/server/text_generation_server/utils/flash_attn.py
index 45090c64..94073e7d 100644
--- a/server/text_generation_server/utils/flash_attn.py
+++ b/server/text_generation_server/utils/flash_attn.py
@@ -15,6 +15,7 @@ major, minor = torch.cuda.get_device_capability()
 is_sm75 = major == 7 and minor == 5
 is_sm8x = major == 8 and minor >= 0
 is_sm90 = major == 9 and minor == 0
+is_sm94 = major == 9 and minor == 4
 
 HAS_FLASH_ATTN = False
 HAS_FLASH_ATTN_V2_CUDA = False
@@ -33,11 +34,16 @@ try:
             "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
             f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
         )
-    if not (is_sm8x or is_sm90):
+    if IS_CUDA_SYSTEM and not (is_sm8x or is_sm90):
         raise ImportError(
             f"GPU with CUDA capability {major} {minor} is not supported for "
             "Flash Attention V2"
         )
+    elif IS_ROCM_SYSTEM and not (is_sm8x or is_sm90 or is_sm94):
+        raise ImportError(
+            f"AMD GPU with compute capability {major} {minor} is not supported for "
+            "Flash Attention V2"
+        )
     HAS_FLASH_ATTN_V2_CUDA = IS_CUDA_SYSTEM
     HAS_FLASH_ATTN_V2_ROCM = IS_ROCM_SYSTEM
 except ImportError as e:
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 9cf5c80f..44d593e1 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -793,7 +793,7 @@ try:
     if IS_CUDA_SYSTEM:
         import dropout_layer_norm
     elif IS_ROCM_SYSTEM:
-        from vllm import layernorm_ops
+        from vllm._C import ops
     else:
         dropout_layer_norm = None
 
@@ -895,7 +895,7 @@ try:
                 residual = hidden_states
 
                 out = torch.empty_like(hidden_states)
-                layernorm_ops.rms_norm(
+                ops.rms_norm(
                     out,
                     hidden_states,
                     self.weight.data,
@@ -915,7 +915,7 @@ try:
         from flash_attn.layers.rotary import RotaryEmbedding
         import rotary_emb
     elif IS_ROCM_SYSTEM:
-        from vllm import pos_encoding_ops
+        from vllm._C import ops
 
     def _create_inv_freq(dim, base, device):
         inv_freq = 1.0 / (
@@ -970,7 +970,7 @@ try:
                 head_size = query.shape[-1]
 
                 # Inplace operation, updating query and key.
-                pos_encoding_ops.rotary_embedding(query, key, head_size, cos, sin, True)
+                ops.rotary_embedding(query, key, head_size, cos, sin, True)
             else:
                 raise ValueError(
                     "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
@@ -1231,6 +1231,5 @@ try:
             freqs = torch.outer(t, self.inv_freq.to(device=t.device))
             self._cos_cached = (torch.cos(freqs) * self.mscale).to(dtype)
             self._sin_cached = (torch.sin(freqs) * self.mscale).to(dtype)
-
-except ImportError:
-    pass
+except ImportError as e:
+    raise e
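
Note on the flash_attn.py hunk: the new is_sm94 flag extends the Flash Attention V2 gate on ROCm to GPUs reporting compute capability (9, 4), which, given the gfx942 build flags added throughout this patch, corresponds to MI300-class parts; gfx90a devices keep passing through the existing (9, 0) branch. The snippet below is an illustrative self-check, not part of the patch: it mirrors the patch's gating logic and uses torch.version.hip as a stand-in for TGI's IS_ROCM_SYSTEM helper.

    # Hypothetical smoke test: does the local GPU pass the Flash Attention V2 gate
    # introduced in server/text_generation_server/utils/flash_attn.py?
    import torch

    major, minor = torch.cuda.get_device_capability()
    is_sm8x = major == 8 and minor >= 0
    is_sm90 = major == 9 and minor == 0
    is_sm94 = major == 9 and minor == 4  # gfx942 / MI300, newly allowed by this patch

    if torch.version.hip is not None:  # ROCm build of PyTorch
        supported = is_sm8x or is_sm90 or is_sm94
    else:  # CUDA build
        supported = is_sm8x or is_sm90
    print(f"capability {major}.{minor} -> Flash Attention V2 supported: {supported}")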