diff --git a/server/Makefile-flash-att-v2 b/server/Makefile-flash-att-v2
index 5b12129c..a7d63356 100644
--- a/server/Makefile-flash-att-v2
+++ b/server/Makefile-flash-att-v2
@@ -1,4 +1,4 @@
-flash_att_commit := 4f285b354796fb17df8636485b9a04df3ebbb7dc
+flash_att_v2_commit := 4f285b354796fb17df8636485b9a04df3ebbb7dc
 
 flash-attention-v2:
 	# Clone flash attention
@@ -6,7 +6,7 @@ flash-attention-v2:
 	git clone https://github.com/HazyResearch/flash-attention.git flash-attention-v2
 
 build-flash-attention-v2: flash-attention-v2
-	cd flash-attention-v2 && git fetch && git checkout $(flash_att_commit)
+	cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit)
 	cd flash-attention-v2 && python setup.py build
 
 install-flash-attention-v2: build-flash-attention-v2
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 99238a07..ffc224cc 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -55,10 +55,8 @@ try:
         FlashSantacoderSharded,
     )
 
-except ImportError:
-    logger.opt(exception=True).warning(
-        "Could not import Flash Attention enabled models"
-    )
+except ImportError as e:
+    logger.warning(f"Could not import Flash Attention enabled models: {e}")
     FLASH_ATTENTION = False
 
 if FLASH_ATTENTION:
diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
index 27bba430..1e9539c4 100644
--- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -188,7 +188,7 @@ class FlashRWAttention(torch.nn.Module):
                 attn_output,
                 cu_seqlen_prefill,
                 max_s,
-                self.softmax_scale
+                self.softmax_scale,
             )
         # Decode
         else:
@@ -308,7 +308,7 @@ class FlashRWLargeAttention(torch.nn.Module):
                 attn_output,
                 cu_seqlen_prefill,
                 max_s,
-                self.softmax_scale
+                self.softmax_scale,
             )
         # Decode
         else:
diff --git a/server/text_generation_server/utils/flash_attn.py b/server/text_generation_server/utils/flash_attn.py
index 8447ea26..c472d1fc 100644
--- a/server/text_generation_server/utils/flash_attn.py
+++ b/server/text_generation_server/utils/flash_attn.py
@@ -1,6 +1,8 @@
 import os
 import torch
 
+from loguru import logger
+
 if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
     raise ImportError("`USE_FLASH_ATTENTION` is false.")
 
@@ -18,10 +20,11 @@ try:
     try:
         import flash_attn_2_cuda
     except ImportError:
-        raise ImportError("Flash Attention V2 is not installed.\n"
-                          "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
-                          "or install flash attention v2 with `cd server && make install install-flash-attention-v2`"
-                          )
+        raise ImportError(
+            "Flash Attention V2 is not installed.\n"
+            "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
+            "or install flash attention v2 with `cd server && make install install-flash-attention-v2`"
+        )
     if not (is_sm8x or is_sm90):
         raise ImportError(
             f"GPU with CUDA capability {major} {minor} is not supported for "
@@ -32,26 +35,28 @@ except ImportError as e:
     try:
         import flash_attn_cuda
     except ImportError:
-        raise ImportError("Flash Attention is not installed.\n"
-                          "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
-                          "or install flash attention with `cd server && make install install-flash-attention`"
-                          ) from e
+        raise ImportError(
+            "Flash Attention is not installed.\n"
+            "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
+            "or install flash attention with `cd server && make install install-flash-attention`"
+        ) from e
 
     if not (is_sm75 or is_sm8x or is_sm90):
         raise ImportError(
             f"GPU with CUDA capability {major} {minor} is not supported"
         ) from e
+    logger.warning(f"Unable to use Flash Attention V2: {e}")
     HAS_FLASH_ATTN = True
 
 
 def attention(
-        q,
-        k,
-        v,
-        out,
-        cu_seqlens,
-        max_s,
-        softmax_scale,
+    q,
+    k,
+    v,
+    out,
+    cu_seqlens,
+    max_s,
+    softmax_scale,
 ):
     if HAS_FLASH_ATTN_V2:
         return flash_attn_2_cuda.varlen_fwd(
@@ -76,21 +81,27 @@ def attention(
     if k.shape[1] != q.shape[1]:
         # MQA expand
         if k.shape[1] == 1:
-            k = k.expand(-1, q.shape[1], -1, -1)
+            k = k.expand(-1, q.shape[1], -1)
         # Grouped attention reshape
         else:
             original_shape = k.shape
-            k = k.unsqueeze(2).expand(-1, -1, q.shape[1], -1, -1) \
-                .reshape(original_shape[0], -1, original_shape[1], original_shape[2])
+            k = (
+                k.unsqueeze(2)
+                .expand(-1, -1, q.shape[1] // k.shape[1], -1)
+                .reshape(original_shape[0], -1, original_shape[2])
+            )
     if v.shape[1] != q.shape[1]:
         # MQA expand
         if v.shape[1] == 1:
-            v = v.expand(-1, q.shape[1], -1, -1)
+            v = v.expand(-1, q.shape[1], -1)
         # Grouped attention reshape
         else:
             original_shape = v.shape
-            v = v.unsqueeze(2).expand(-1, -1, q.shape[1], -1, -1) \
-                .reshape(original_shape[0], -1, original_shape[1], original_shape[2])
+            v = (
+                v.unsqueeze(2)
+                .expand(-1, -1, q.shape[1] // v.shape[1], -1)
+                .reshape(original_shape[0], -1, original_shape[2])
+            )
     return flash_attn_cuda.fwd(
         q,