From ff0505e7f967b063de3babd70d94a278e7dfed49 Mon Sep 17 00:00:00 2001 From: Mohit Sharma Date: Wed, 4 Sep 2024 05:46:28 +0000 Subject: [PATCH] added custom PA --- Dockerfile_amd | 46 +++++++--- server/Makefile-flash-att-v2 | 2 +- server/Makefile-vllm | 2 +- .../layers/attention/cuda.py | 1 + .../layers/attention/ipex.py | 1 + .../layers/attention/rocm.py | 87 ++++++++++++++----- .../custom_modeling/flash_cohere_modeling.py | 1 + .../custom_modeling/flash_dbrx_modeling.py | 1 + .../flash_deepseek_v2_modeling.py | 2 + .../custom_modeling/flash_gemma2_modeling.py | 1 + .../custom_modeling/flash_gemma_modeling.py | 1 + .../custom_modeling/flash_gpt2_modeling.py | 1 + .../custom_modeling/flash_llama_modeling.py | 3 + .../custom_modeling/flash_mistral_modeling.py | 2 + .../custom_modeling/flash_mixtral_modeling.py | 1 + .../custom_modeling/flash_neox_modeling.py | 1 + .../custom_modeling/flash_phi_modeling.py | 1 + .../custom_modeling/flash_qwen2_modeling.py | 1 + .../custom_modeling/flash_rw_modeling.py | 2 + .../flash_santacoder_modeling.py | 1 + .../flash_starcoder2_modeling.py | 1 + .../models/flash_causal_lm.py | 4 +- 22 files changed, 128 insertions(+), 35 deletions(-) diff --git a/Dockerfile_amd b/Dockerfile_amd index d8bb8c47..fd612af5 100644 --- a/Dockerfile_amd +++ b/Dockerfile_amd @@ -39,7 +39,7 @@ COPY launcher launcher RUN cargo build --profile release-opt # Text Generation Inference base image for RoCm -FROM rocm/dev-ubuntu-22.04:6.1.1_hip_update AS base +FROM rocm/dev-ubuntu-22.04:6.2 AS base RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ build-essential \ @@ -48,23 +48,25 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins curl \ git \ make \ + libmsgpack-dev \ libssl-dev \ + llvm-dev \ g++ \ # Needed to build VLLM & flash. rocthrust-dev \ hipsparse-dev \ hipblas-dev \ - hipblaslt-dev \ + hipcub-dev \ rocblas-dev \ hiprand-dev \ + hipfft-dev \ rocrand-dev \ miopen-hip-dev \ - hipfft-dev \ - hipcub-dev \ hipsolver-dev \ rccl-dev \ cmake \ - python3-dev && \ + python3-dev \ + python3-venv && \ rm -rf /var/lib/apt/lists/* # Keep in sync with `server/pyproject.toml @@ -74,7 +76,30 @@ ARG ROCM_VERSION='6.0.2' ARG PYTHON_VERSION='3.10.10' # Automatically set by buildx ARG TARGETPLATFORM -ENV PATH /opt/conda/bin:$PATH +ENV PATH=/opt/conda/bin:$PATH + +ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942" + +RUN curl -fsSL -v -o cmake-3.30.2-linux-x86_64.sh https://github.com/Kitware/CMake/releases/download/v3.30.2/cmake-3.30.2-linux-x86_64.sh \ + && chmod +x cmake-3.30.2-linux-x86_64.sh \ + && ./cmake-3.30.2-linux-x86_64.sh --skip-license --prefix=/usr/local \ + && rm cmake-3.30.2-linux-x86_64.sh + +RUN pip install joblib msgpack + +# Install HIPBLASLt +ARG HIPBLASLT_BRANCH="6f65c6e" +RUN git clone https://github.com/ROCm/hipBLASLt \ + && cd hipBLASLt \ + && git checkout ${HIPBLASLT_BRANCH} \ + && SCCACHE_IDLE_TIMEOUT=1800 ./install.sh --architecture ${PYTORCH_ROCM_ARCH} \ + && cd build/release \ + && make package +RUN dpkg -i hipBLASLt/build/release/*.deb \ + && sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \ + && sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status; + # && cd .. \ + # && rm -rf hipBLASLt # TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda. 
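# (Reviewer note, not part of the original patch: the hipBLASLt stanza above builds
# branch 6f65c6e from source for the gfx90a/gfx942 architectures and installs the
# resulting .deb packages; the two sed edits on /var/lib/dpkg/status appear to strip
# the distro hipblaslt / hipblaslt-dev entries from other packages' dependency lists
# so they do not conflict with the locally built package, since hipblaslt-dev was
# dropped from the apt-get install list earlier in this hunk.)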
# Install mamba @@ -98,15 +123,15 @@ RUN pip uninstall -y triton && \ cd triton/python && \ pip install . +ARG PYTORCH_COMMIT="da320214e66b5af0f7db8fd18a64dbb519d17b27" RUN git clone --depth 1 --recursive --single-branch --branch main https://github.com/pytorch/pytorch.git pytorch && \ - cd pytorch && git fetch --depth 1 origin da320214e66b5af0f7db8fd18a64dbb519d17b27 && \ - git checkout da320214e66b5af0f7db8fd18a64dbb519d17b27 && \ + cd pytorch && git fetch --depth 1 origin ${PYTORCH_COMMIT} && \ + git checkout ${PYTORCH_COMMIT} && \ + git submodule update --init --recursive && \ pip install -r requirements.txt --no-cache-dir - ARG _GLIBCXX_USE_CXX11_ABI="1" ARG CMAKE_PREFIX_PATH="/opt/conda" -ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942" ARG BUILD_CAFFE2="0" \ BUILD_CAFFE2_OPS="0" \ USE_CUDA="0" \ @@ -221,4 +246,3 @@ COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh RUN chmod +x /tgi-entrypoint.sh ENTRYPOINT ["/tgi-entrypoint.sh"] -CMD ["--json-output"] diff --git a/server/Makefile-flash-att-v2 b/server/Makefile-flash-att-v2 index 03527329..74293d9c 100644 --- a/server/Makefile-flash-att-v2 +++ b/server/Makefile-flash-att-v2 @@ -1,5 +1,5 @@ flash_att_v2_commit_cuda := v2.6.1 -flash_att_v2_commit_rocm := d83c4129a92e4258081f92dfafd34345b3b06130 +flash_att_v2_commit_rocm := 3cea2fb6ee54fb7e1aad9db6ac6c9331184b8647 # (Aug28) build-flash-attention-v2-cuda: pip install -U packaging wheel diff --git a/server/Makefile-vllm b/server/Makefile-vllm index bf4a1498..18dcc4a0 100644 --- a/server/Makefile-vllm +++ b/server/Makefile-vllm @@ -1,5 +1,5 @@ commit_cuda := d243e9dc7e2c9c2e36a4150ec8e64809cb55c01b -commit_rocm := c06ccbf90a213688a2c6a85d2e7af3da7bc4b41b +commit_rocm := 4e0929e6e4fa0a3d09d358715c288020ea9dc247 build-vllm-cuda: if [ ! -d 'vllm' ]; then \ pip install -U ninja packaging --no-cache-dir && \ diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py index dff742dc..c623b7f9 100644 --- a/server/text_generation_server/layers/attention/cuda.py +++ b/server/text_generation_server/layers/attention/cuda.py @@ -42,6 +42,7 @@ def paged_attention( block_tables: torch.Tensor, seqlen: Seqlen, max_s: int, + num_kv_heads: int, softcap: Optional[float] = None, ): # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py diff --git a/server/text_generation_server/layers/attention/ipex.py b/server/text_generation_server/layers/attention/ipex.py index e0956b26..f7aada34 100644 --- a/server/text_generation_server/layers/attention/ipex.py +++ b/server/text_generation_server/layers/attention/ipex.py @@ -58,6 +58,7 @@ def paged_attention( block_tables: torch.Tensor, seqlen: Seqlen, max_s: int, + num_kv_heads: int, ): out = torch.empty_like(query) ipex.llm.modules.PagedAttention.single_query_cached_kv_attention( diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py index da8a4bcd..bd033017 100644 --- a/server/text_generation_server/layers/attention/rocm.py +++ b/server/text_generation_server/layers/attention/rocm.py @@ -8,11 +8,17 @@ from loguru import logger major, minor = torch.cuda.get_device_capability() is_sm75 = major == 7 and minor == 5 -_PARTITION_SIZE = 512 + +_PARTITION_SIZE_V1V2 = 512 +_PARTITION_SIZE_CUSTOM = 256 use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"} ENGINE = "triton" if use_triton else "ck" +custom_attn_available = 
os.getenv("VLLM_USE_ROCM_CUSTOM_PAGED_ATTN", "1") != "0" +if custom_attn_available: + from vllm._custom_C import paged_attention_custom + try: import vllm._custom_ops as ops except Exception as e: @@ -45,6 +51,7 @@ def paged_attention( block_tables: torch.Tensor, input_lengths: Seqlen, max_s: int, + num_kv_heads: int, ): # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py # Copyright 2023 The vLLM team. All rights @@ -66,6 +73,22 @@ def paged_attention( # value_cache => [num_blocks, num_heads, head_size, block_size] block_size = value_cache.shape[3] num_seqs, num_heads, head_size = query.shape + + gqa_ratio = num_heads // num_kv_heads + use_custom = ( + custom_attn_available + and (query.dtype == torch.half or query.dtype == torch.bfloat16) + and (head_size == 128 or head_size == 64) + and (block_size == 16 or block_size == 32) + and (gqa_ratio >= 1 and gqa_ratio <= 16) + and max_s <= 32768 + ) + + if not use_custom: + _PARTITION_SIZE = _PARTITION_SIZE_V1V2 + else: + _PARTITION_SIZE = _PARTITION_SIZE_CUSTOM + max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE input_lengths = input_lengths.input_lengths @@ -78,7 +101,11 @@ def paged_attention( # to parallelize. import vllm._custom_ops as ops - use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512) + use_v1 = ( + max_s <= 8192 + and (max_num_partitions == 1 or num_seqs * num_heads > 512) + and not use_custom + ) if use_v1: ops.paged_attention_v1( out, @@ -110,24 +137,44 @@ def paged_attention( ) max_logits = torch.empty_like(exp_sums) - ops.paged_attention_v2( - out, - exp_sums, - max_logits, - tmp_output, - query, - key_cache, - value_cache, - kv_head_mapping, - softmax_scale, - block_tables, - input_lengths, - block_size, - max_s, - None, - "auto", - 1.0, - ) + if not use_custom: + ops.paged_attention_v2( + out, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + kv_head_mapping, + softmax_scale, + block_tables, + input_lengths, + block_size, + max_s, + None, + "auto", + 1.0, + ) + else: + paged_attention_custom( + out, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + softmax_scale, + block_tables, + input_lengths, + block_size, + max_s, + None, + "auto", + ) + return out diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index e02a31d9..46022854 100644 --- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -313,6 +313,7 @@ class FlashCohereAttention(torch.nn.Module): block_tables, input_lengths, max_s, + self.num_key_value_heads, ) return self.o_proj( diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index d3d1d1ef..bc9e8f15 100644 --- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -352,6 +352,7 @@ class DbrxAttention(torch.nn.Module): block_tables, input_lengths, max_s, + self.num_key_value_heads, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py 
b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py index 0905d3c2..d02d6cd3 100644 --- a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py @@ -380,6 +380,7 @@ class DeepseekV2Attention(torch.nn.Module): block_tables, input_lengths, max_s, + self.num_key_value_heads, ) # Remove padding. @@ -424,6 +425,7 @@ class DeepseekV2MLP(nn.Module): def forward(self, hidden_states: torch.Tensor, reduce: bool = True): if ( SYSTEM == "rocm" + and hidden_states.dtype == torch.float16 and self.hidden_act == "silu" and hidden_states.shape[0] == 1 and not self.quantize diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py index de86f514..76d8e8ba 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py @@ -256,6 +256,7 @@ class FlashGemma2Attention(torch.nn.Module): block_tables, input_lengths, max_s, + self.num_key_value_heads, softcap=self.softcap, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index 178efadb..42ae24f3 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -248,6 +248,7 @@ class FlashGemmaAttention(torch.nn.Module): block_tables, input_lengths, max_s, + self.num_key_value_heads, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py index a19cff8c..8c9dc6d6 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py @@ -247,6 +247,7 @@ class FlashGPT2Attention(torch.nn.Module): block_tables, input_lengths, max_s, + self.num_key_value_heads, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 56d88956..9c4cb64a 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -235,6 +235,7 @@ class FlashLlamaAttention(torch.nn.Module): block_tables, input_lengths, max_s, + self.num_key_value_heads, ) return self.o_proj( @@ -318,6 +319,7 @@ class LlamaMLP(nn.Module): def forward(self, hidden_states, adapter_data): if ( SYSTEM == "rocm" + and hidden_states.dtype == torch.float16 and self.hidden_act == "silu" and hidden_states.shape[0] == 1 and self.hidden_size @@ -557,6 +559,7 @@ class FlashLlamaForCausalLM(torch.nn.Module): adapter_data: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: inputs_embeds = self.embed_tokens(input_ids) + hidden_states = self.model( inputs_embeds, position_ids, diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py 
b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index dda53ff3..e4ba0295 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -235,6 +235,7 @@ class MistralAttention(torch.nn.Module): block_tables, input_lengths, max_s, + self.num_key_value_heads, ) return self.o_proj( @@ -300,6 +301,7 @@ class MistralMLP(nn.Module): def forward(self, hidden_states, adapter_data): if ( SYSTEM == "rocm" + and hidden_states.dtype == torch.float16 and self.hidden_act == "silu" and hidden_states.shape[0] == 1 and not self.quantize diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index 85431c6c..c84c99c6 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -292,6 +292,7 @@ class MixtralAttention(torch.nn.Module): block_tables, input_lengths, max_s, + self.num_key_value_heads, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index b1b03ad7..103c84a7 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -180,6 +180,7 @@ class FlashNeoxAttention(torch.nn.Module): block_tables, input_lengths, max_s, + self.num_key_value_heads, ) return self.dense(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index a9e18348..e067e1d8 100644 --- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -209,6 +209,7 @@ class FlashPhiAttention(torch.nn.Module): block_tables, input_lengths, max_s, + self.num_key_value_heads, ) return self.dense(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index 865cc85d..8b3b8322 100644 --- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -153,6 +153,7 @@ class Qwen2Attention(torch.nn.Module): block_tables, input_lengths, max_s, + self.num_key_value_heads, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index 708641e7..db4c7e7e 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -223,6 +223,7 @@ class FlashRWAttention(torch.nn.Module): block_tables, input_lengths, max_s, + self.num_key_value_heads, ) return self.dense(attn_output.view(-1, self.num_heads * self.head_size)) @@ -341,6 +342,7 @@ class 
FlashRWLargeAttention(torch.nn.Module): block_tables, input_lengths, max_s, + self.num_key_value_heads, ) return self.dense( diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index c2676782..2d92f3ff 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -308,6 +308,7 @@ class FlashMQAttention(torch.nn.Module): block_tables, input_lengths, max_s, + self.num_key_value_heads, ) return self.c_proj(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index e562eb89..534a3792 100644 --- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -258,6 +258,7 @@ class Starcoder2Attention(torch.nn.Module): block_tables, input_lengths, max_s, + self.num_key_value_heads, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 174bba65..35388f49 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -1007,12 +1007,12 @@ class FlashCausalLM(Model): else: self.kv_cache = [ ( - torch.empty( + torch.zeros( (num_blocks, num_heads, head_size // x, BLOCK_SIZE, x), dtype=dtype, device=device, ), - torch.empty( + torch.zeros( (num_blocks, num_heads, head_size, BLOCK_SIZE), dtype=dtype, device=device,
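
Reviewer note (not part of the patch): the core of this change is the kernel-dispatch
logic added to server/text_generation_server/layers/attention/rocm.py. The standalone
Python sketch below restates only that selection logic, assuming the thresholds shown
in the diff are the complete eligibility rules; the helper name
select_paged_attention_kernel and the __main__ demo are hypothetical, and the real code
calls ops.paged_attention_v1 / ops.paged_attention_v2 or
vllm._custom_C.paged_attention_custom rather than returning a string.

# Sketch of the ROCm paged-attention kernel selection introduced by this patch.
# It mirrors the thresholds from rocm.py but does not call any kernels itself.
import os

import torch

_PARTITION_SIZE_V1V2 = 512
_PARTITION_SIZE_CUSTOM = 256


def select_paged_attention_kernel(
    query: torch.Tensor,  # [num_seqs, num_heads, head_size]
    block_size: int,      # KV-cache block size (value_cache.shape[3] in the patch)
    num_kv_heads: int,    # the new argument threaded through the model code
    max_s: int,           # max sequence length in the batch
) -> str:
    num_seqs, num_heads, head_size = query.shape
    gqa_ratio = num_heads // num_kv_heads

    custom_attn_available = (
        os.getenv("VLLM_USE_ROCM_CUSTOM_PAGED_ATTN", "1") != "0"
    )
    # The custom kernel only handles fp16/bf16, head sizes 64/128, block sizes
    # 16/32, GQA ratios 1..16 and sequences up to 32768 tokens.
    use_custom = (
        custom_attn_available
        and query.dtype in (torch.half, torch.bfloat16)
        and head_size in (64, 128)
        and block_size in (16, 32)
        and 1 <= gqa_ratio <= 16
        and max_s <= 32768
    )

    partition_size = (
        _PARTITION_SIZE_CUSTOM if use_custom else _PARTITION_SIZE_V1V2
    )
    max_num_partitions = (max_s + partition_size - 1) // partition_size

    if use_custom:
        return "paged_attention_custom"
    # v1 skips the reduction pass for short sequences or when there is not
    # enough work to split across partitions (same condition as the patch,
    # with "not use_custom" implied by the early return above).
    use_v1 = max_s <= 8192 and (
        max_num_partitions == 1 or num_seqs * num_heads > 512
    )
    return "paged_attention_v1" if use_v1 else "paged_attention_v2"


if __name__ == "__main__":
    q = torch.zeros((1, 32, 128), dtype=torch.half)
    print(select_paged_attention_kernel(q, block_size=16, num_kv_heads=8, max_s=4096))

The new num_kv_heads parameter threaded through every flash_*_modeling.py attention
module exists so the gqa_ratio check above can run inside paged_attention; the cuda.py
and ipex.py signatures gain the same parameter, apparently only to keep the three
backends call-compatible.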