diff --git a/Dockerfile b/Dockerfile
index 66e0091d..7ba4239c 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -105,6 +105,16 @@ WORKDIR /usr/src
 
 COPY server/custom_kernels/ .
 
 # Build specific version of transformers
 RUN python setup.py build
 
+# Build Flash Attention v2 CUDA kernels
+FROM kernel-builder as flash-att-v2-builder
+
+WORKDIR /usr/src
+
+COPY server/Makefile-flash-att-v2 Makefile
+
+# Build specific version of flash attention v2
+RUN make build-flash-attention-v2
+
@@ -146,8 +156,11 @@ COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cp
 COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
 COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
 
+# Copy build artifacts from flash attention v2 builder
+COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
+
 # Copy build artifacts from custom kernels builder
-COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39/custom_kernels /usr/src/custom-kernels/src/custom_kernels
+COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
 
 # Copy builds artifacts from vllm builder
 COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
diff --git a/server/Makefile b/server/Makefile
index d0086928..0dc0b5c9 100644
--- a/server/Makefile
+++ b/server/Makefile
@@ -1,4 +1,5 @@
 include Makefile-flash-att
+include Makefile-flash-att-v2
 include Makefile-vllm
 
 unit-tests:
diff --git a/server/Makefile-flash-att b/server/Makefile-flash-att
index b779482f..bc1d37ef 100644
--- a/server/Makefile-flash-att
+++ b/server/Makefile-flash-att
@@ -1,4 +1,4 @@
-flash_att_commit := 4f285b354796fb17df8636485b9a04df3ebbb7dc
+flash_att_commit := 3a9bfd076f98746c73362328958dbc68d145fbec
 
 flash-attention:
 	# Clone flash attention
diff --git a/server/Makefile-flash-att-v2 b/server/Makefile-flash-att-v2
new file mode 100644
index 00000000..5b12129c
--- /dev/null
+++ b/server/Makefile-flash-att-v2
@@ -0,0 +1,13 @@
+flash_att_v2_commit := 4f285b354796fb17df8636485b9a04df3ebbb7dc
+
+flash-attention-v2:
+	# Clone flash attention
+	pip install packaging
+	git clone https://github.com/HazyResearch/flash-attention.git flash-attention-v2
+
+build-flash-attention-v2: flash-attention-v2
+	cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit)
+	cd flash-attention-v2 && python setup.py build
+
+install-flash-attention-v2: build-flash-attention-v2
+	cd flash-attention-v2 && python setup.py install
\ No newline at end of file
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 8f843b0e..99238a07 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -42,46 +42,19 @@ __all__ = [
     "get_model",
 ]
 
-FLASH_ATT_ERROR_MESSAGE = (
-    "{} requires CUDA and Flash Attention kernels to be installed.\n"
-    "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
-    "or install flash attention with `cd server && make install install-flash-attention`"
-)
+FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
 
+FLASH_ATTENTION = True
 try:
-    if not os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
-        if not torch.cuda.is_available():
-            FLASH_ATT_ERROR_MESSAGE = (
-                "{} requires CUDA. No compatible CUDA devices found."
-            )
-            raise ImportError("CUDA is not available")
+    from text_generation_server.models.flash_rw import FlashRWSharded
+    from text_generation_server.models.flash_neox import FlashNeoXSharded
+    from text_generation_server.models.flash_llama import (
+        FlashLlama,
+    )
+    from text_generation_server.models.flash_santacoder import (
+        FlashSantacoderSharded,
+    )
 
-        major, minor = torch.cuda.get_device_capability()
-        is_sm8x = major == 8 and minor >= 0
-        is_sm90 = major == 9 and minor == 0
-
-        supported = is_sm8x or is_sm90
-        if not supported:
-            FLASH_ATT_ERROR_MESSAGE = (
-                "{} requires a CUDA device with capability > 8.0 or 9.0. "
-                "No compatible CUDA device found."
-            )
-            raise ImportError(
-                f"GPU with CUDA capability {major} {minor} is not supported"
-            )
-
-        from text_generation_server.models.flash_rw import FlashRWSharded
-        from text_generation_server.models.flash_neox import FlashNeoXSharded
-        from text_generation_server.models.flash_llama import (
-            FlashLlama,
-        )
-        from text_generation_server.models.flash_santacoder import (
-            FlashSantacoderSharded,
-        )
-
-        FLASH_ATTENTION = True
-    else:
-        FLASH_ATTENTION = False
 except ImportError:
     logger.opt(exception=True).warning(
         "Could not import Flash Attention enabled models"
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index c37a8c7b..d3c719df 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -26,13 +26,13 @@ from transformers.activations import ACT2FN
 from typing import Optional, List, Tuple
 
 # Flash attention imports
-import flash_attn_2_cuda
 import dropout_layer_norm
 
 # vllm imports
 import vllm_cache_ops
 import vllm_attention_ops
 
+from text_generation_server.utils.flash_attn import attention
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
@@ -164,21 +164,14 @@ class FlashLlamaAttention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            flash_attn_2_cuda.varlen_fwd(
+            attention(
                 qkv[:, 0],
                 qkv[:, 1],
                 qkv[:, 2],
                 attn_output,
                 cu_seqlen_prefill,
-                cu_seqlen_prefill,
                 max_s,
-                max_s,
-                0.0,
                 self.softmax_scale,
-                False,
-                True,
-                False,
-                None,
             )
         # Decode
         else:
diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
index 0a0bfce7..e7c8ced4 100644
--- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
@@ -27,13 +27,11 @@ from transformers.modeling_utils import PreTrainedModel
 from transformers.models.gpt_neox import GPTNeoXConfig
 from typing import Optional, List, Tuple
 
-# Flash attention imports
-import flash_attn_2_cuda
-
 # vllm imports
 import vllm_cache_ops
 import vllm_attention_ops
 
+from text_generation_server.utils.flash_attn import attention
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
@@ -153,21 +151,14 @@ class FlashNeoxAttention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            flash_attn_2_cuda.varlen_fwd(
+            attention(
                 qkv[:, 0],
                 qkv[:, 1],
                 qkv[:, 2],
                 attn_output,
                 cu_seqlen_prefill,
-                cu_seqlen_prefill,
                 max_s,
-                max_s,
-                0.0,
                 self.softmax_scale,
-                False,
-                True,
-                False,
-                None,
             )
         # Decode
         else:
diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
index eeac2b9e..27bba430 100644
--- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -6,13 +6,11 @@ from transformers.modeling_utils import PreTrainedModel
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
 
-# Flash attention imports
-import flash_attn_2_cuda
-
 # vllm imports
 import vllm_cache_ops
 import vllm_attention_ops
 
+from text_generation_server.utils.flash_attn import attention
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
@@ -183,21 +181,14 @@ class FlashRWAttention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            flash_attn_2_cuda.varlen_fwd(
+            attention(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
                 attn_output,
                 cu_seqlen_prefill,
-                cu_seqlen_prefill,
                 max_s,
-                max_s,
-                0.0,
-                self.softmax_scale,
-                False,
-                True,
-                False,
-                None,
+                self.softmax_scale
             )
         # Decode
         else:
@@ -310,21 +301,14 @@ class FlashRWLargeAttention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            flash_attn_2_cuda.varlen_fwd(
+            attention(
                 query,
                 torch.select(kv, dim=2, index=0),
                 torch.select(kv, dim=2, index=1),
                 attn_output,
                 cu_seqlen_prefill,
-                cu_seqlen_prefill,
                 max_s,
-                max_s,
-                0.0,
-                self.softmax_scale,
-                False,
-                True,
-                False,
-                None,
+                self.softmax_scale
             )
         # Decode
         else:
diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
index 76dcf1d5..6f5c60fc 100644
--- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -5,13 +5,11 @@ from torch import nn
 from transformers.activations import ACT2FN
 from typing import Optional, List, Tuple
 
-# Flash attention imports
-import flash_attn_2_cuda
-
 # vllm imports
 import vllm_cache_ops
 import vllm_attention_ops
 
+from text_generation_server.utils.flash_attn import attention
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
@@ -272,21 +270,14 @@ class FlashMQAttention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            flash_attn_2_cuda.varlen_fwd(
+            attention(
                 query,
                 torch.select(key_value, dim=1, index=0),
                 torch.select(key_value, dim=1, index=1),
                 attn_output,
                 cu_seqlen_prefill,
-                cu_seqlen_prefill,
                 max_s,
-                max_s,
-                0.0,
                 self.softmax_scale,
-                False,
-                True,
-                False,
-                None,
             )
         # Decode
         else:
diff --git a/server/text_generation_server/utils/flash_attn.py b/server/text_generation_server/utils/flash_attn.py
new file mode 100644
index 00000000..8447ea26
--- /dev/null
+++ b/server/text_generation_server/utils/flash_attn.py
@@ -0,0 +1,113 @@
+import os
+import torch
+
+if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
+    raise ImportError("`USE_FLASH_ATTENTION` is false.")
+
+if not torch.cuda.is_available():
+    raise ImportError("CUDA is not available")
+
+major, minor = torch.cuda.get_device_capability()
+is_sm75 = major == 7 and minor == 5
+is_sm8x = major == 8 and minor >= 0
+is_sm90 = major == 9 and minor == 0
+
+HAS_FLASH_ATTN = False
+HAS_FLASH_ATTN_V2 = False
+try:
+    try:
+        import flash_attn_2_cuda
+    except ImportError:
+        raise ImportError(
+            "Flash Attention V2 is not installed.\n"
+            "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
+            "or install flash attention v2 with `cd server && make install install-flash-attention-v2`"
+        )
+    if not (is_sm8x or is_sm90):
+        raise ImportError(
+            f"GPU with CUDA capability {major} {minor} is not supported for "
+            "Flash Attention V2"
+        )
+    HAS_FLASH_ATTN_V2 = True
+except ImportError as e:
+    try:
+        import flash_attn_cuda
+    except ImportError:
+        raise ImportError(
+            "Flash Attention is not installed.\n"
+            "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
+            "or install flash attention with `cd server && make install install-flash-attention`"
+        ) from e
+
+    if not (is_sm75 or is_sm8x or is_sm90):
+        raise ImportError(
+            f"GPU with CUDA capability {major} {minor} is not supported"
+        ) from e
+    HAS_FLASH_ATTN = True
+
+
+def attention(
+    q,
+    k,
+    v,
+    out,
+    cu_seqlens,
+    max_s,
+    softmax_scale,
+):
+    if HAS_FLASH_ATTN_V2:
+        return flash_attn_2_cuda.varlen_fwd(
+            q,
+            k,
+            v,
+            out,
+            cu_seqlens,
+            cu_seqlens,
+            max_s,
+            max_s,
+            0.0,
+            softmax_scale,
+            False,
+            True,
+            False,
+            None,
+        )
+
+    if HAS_FLASH_ATTN:
+        # Flash attention v1 requires q, k and v to have the same number of heads
+        if k.shape[1] != q.shape[1]:
+            # MQA expand
+            if k.shape[1] == 1:
+                k = k.expand(-1, q.shape[1], -1)
+            # Grouped attention reshape: repeat each KV head q_heads // kv_heads times
+            else:
+                original_shape = k.shape
+                k = (
+                    k.unsqueeze(2)
+                    .expand(-1, -1, q.shape[1] // k.shape[1], -1)
+                    .reshape(original_shape[0], -1, original_shape[2])
+                )
+        if v.shape[1] != q.shape[1]:
+            # MQA expand
+            if v.shape[1] == 1:
+                v = v.expand(-1, q.shape[1], -1)
+            # Grouped attention reshape: repeat each KV head q_heads // kv_heads times
+            else:
+                original_shape = v.shape
+                v = (
+                    v.unsqueeze(2)
+                    .expand(-1, -1, q.shape[1] // v.shape[1], -1)
+                    .reshape(original_shape[0], -1, original_shape[2])
+                )
+
+        return flash_attn_cuda.fwd(
+            q,
+            k,
+            v,
+            out,
+            cu_seqlens,
+            cu_seqlens,
+            max_s,
+            max_s,
+            0.0,
+            softmax_scale,
+            False,
+            True,
+            False,
+            0,
+            None,
+        )
+
+    raise NotImplementedError("flash attention is not installed")
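
For context on how the modeling files consume the new helper, the sketch below mirrors the prefill call sites rewritten in this diff (flash_llama, flash_neox, flash_santacoder, flash_rw). The prefill_attention wrapper and the [total_tokens, 3, num_heads, head_size] qkv layout are assumptions for illustration, not part of the patch, and running it requires either flash_attn_cuda or flash_attn_2_cuda to be importable.

import torch

from text_generation_server.utils.flash_attn import attention


def prefill_attention(qkv, cu_seqlen_prefill, max_s, softmax_scale):
    # qkv is assumed to be [total_tokens, 3, num_heads, head_size]
    attn_output = torch.empty_like(qkv[:, 0])
    attention(
        qkv[:, 0],          # queries
        qkv[:, 1],          # keys
        qkv[:, 2],          # values
        attn_output,        # filled in place by the kernel
        cu_seqlen_prefill,  # cumulative sequence lengths for the batch
        max_s,              # longest sequence length in the batch
        softmax_scale,
    )
    return attn_output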
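
The head broadcasting in the v1 fallback can be checked on its own with plain PyTorch, no CUDA kernels needed. A minimal sketch, assuming the [total_tokens, num_heads, head_size] layout used by the callers above; the sizes are illustrative only.

import torch

# Illustrative sizes: 5 tokens, 8 query heads, 2 KV heads, head size 64.
total_tokens, num_heads, num_kv_heads, head_size = 5, 8, 2, 64

q = torch.randn(total_tokens, num_heads, head_size)
k = torch.randn(total_tokens, num_kv_heads, head_size)

# Same expansion as the HAS_FLASH_ATTN branch: each KV head is repeated
# num_heads // num_kv_heads times so k matches q's head count.
if k.shape[1] != q.shape[1]:
    if k.shape[1] == 1:
        # MQA: broadcast the single KV head across all query heads.
        k = k.expand(-1, q.shape[1], -1)
    else:
        original_shape = k.shape
        k = (
            k.unsqueeze(2)
            .expand(-1, -1, q.shape[1] // k.shape[1], -1)
            .reshape(original_shape[0], -1, original_shape[2])
        )

assert k.shape == (total_tokens, num_heads, head_size)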