From 107fcfe9b67d5817d3af63f26d834ec7aa960844 Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Mon, 17 Jul 2023 17:34:55 +0200
Subject: [PATCH] feat(server): flash attention v2

---
 server/Makefile-flash-att                                | 2 +-
 server/text_generation_server/models/__init__.py         | 5 ++---
 .../models/custom_modeling/flash_llama_modeling.py       | 4 ++--
 .../models/custom_modeling/flash_neox_modeling.py        | 4 ++--
 .../models/custom_modeling/flash_rw_modeling.py          | 6 +++---
 .../models/custom_modeling/flash_santacoder_modeling.py  | 4 ++--
 6 files changed, 12 insertions(+), 13 deletions(-)

diff --git a/server/Makefile-flash-att b/server/Makefile-flash-att
index bc1d37ef..b779482f 100644
--- a/server/Makefile-flash-att
+++ b/server/Makefile-flash-att
@@ -1,4 +1,4 @@
-flash_att_commit := 3a9bfd076f98746c73362328958dbc68d145fbec
+flash_att_commit := 4f285b354796fb17df8636485b9a04df3ebbb7dc
 
 flash-attention:
 	# Clone flash attention
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index fd97f8b1..8f843b0e 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -57,14 +57,13 @@ try:
         raise ImportError("CUDA is not available")
 
     major, minor = torch.cuda.get_device_capability()
-    is_sm75 = major == 7 and minor == 5
     is_sm8x = major == 8 and minor >= 0
     is_sm90 = major == 9 and minor == 0
 
-    supported = is_sm75 or is_sm8x or is_sm90
+    supported = is_sm8x or is_sm90
     if not supported:
         FLASH_ATT_ERROR_MESSAGE = (
-            "{} requires a CUDA device with capability 7.5, > 8.0 or 9.0. "
+            "{} requires a CUDA device with capability > 8.0 or 9.0. "
             "No compatible CUDA device found."
         )
         raise ImportError(
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index d9f3c7b8..b6bf0a4b 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -26,7 +26,7 @@ from transformers.activations import ACT2FN
 from typing import Optional, List, Tuple
 
 # Flash attention imports
-import flash_attn_cuda
+import flash_attn_2_cuda
 import dropout_layer_norm
 
 # vllm imports
@@ -164,7 +164,7 @@ class FlashLlamaAttention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            flash_attn_cuda.fwd(
+            flash_attn_2_cuda.varlen_fwd(
                 qkv[:, 0],
                 qkv[:, 1],
                 qkv[:, 2],
diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
index b2dce226..5ce80be6 100644
--- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
@@ -28,7 +28,7 @@ from transformers.models.gpt_neox import GPTNeoXConfig
 from typing import Optional, List, Tuple
 
 # Flash attention imports
-import flash_attn_cuda
+import flash_attn_2_cuda
 
 # vllm imports
 import vllm_cache_ops
@@ -153,7 +153,7 @@ class FlashNeoxAttention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            flash_attn_cuda.fwd(
+            flash_attn_2_cuda.varlen_fwd(
                 qkv[:, 0],
                 qkv[:, 1],
                 qkv[:, 2],
diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
index acac2744..051e0c66 100644
--- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -7,7 +7,7 @@ from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
 
 # Flash attention imports
-import flash_attn_cuda
+import flash_attn_2_cuda
 
 # vllm imports
 import vllm_cache_ops
@@ -187,7 +187,7 @@ class FlashRWAttention(torch.nn.Module):
             kv = kv.expand(-1, 2, self.num_heads, self.head_size)
 
             # flash attention
-            flash_attn_cuda.fwd(
+            flash_attn_2_cuda.varlen_fwd(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
@@ -322,7 +322,7 @@ class FlashRWLargeAttention(torch.nn.Module):
             )
 
             # flash attention
-            flash_attn_cuda.fwd(
+            flash_attn_2_cuda.varlen_fwd(
                 query,
                 torch.select(kv, dim=2, index=0),
                 torch.select(kv, dim=2, index=1),
diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
index a19623a5..925bd23c 100644
--- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -6,7 +6,7 @@ from transformers.activations import ACT2FN
 from typing import Optional, List, Tuple
 
 # Flash attention imports
-import flash_attn_cuda
+import flash_attn_2_cuda
 
 # vllm imports
 import vllm_cache_ops
@@ -275,7 +275,7 @@ class FlashMQAttention(torch.nn.Module):
             key_value = key_value.expand(-1, 2, self.num_heads, self.head_size)
 
             # flash attention
-            flash_attn_cuda.fwd(
+            flash_attn_2_cuda.varlen_fwd(
                 query,
                 torch.select(key_value, dim=1, index=0),
                 torch.select(key_value, dim=1, index=1),
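
Note (not part of the patch): FlashAttention v2 ships its kernels in the flash_attn_2_cuda extension and exposes the variable-length forward as varlen_fwd instead of fwd, and it no longer targets Turing GPUs, which is why the sm_75 branch is dropped from the capability check. The diff context above cuts each varlen_fwd call off after the first three tensor arguments. The sketch below is a minimal illustration of how such a prefill call is typically completed; the prefill_attention helper, the trailing argument order, and the tensor shapes are assumptions based on the flash-attention v2 varlen binding, not something this patch shows.

# Illustrative sketch only -- not part of the patch. The trailing arguments
# (cu_seqlens, max_seqlen, dropout, scale, flags, generator) are assumptions
# based on the flash-attention v2 varlen_fwd interface.
import torch
import flash_attn_2_cuda  # built from the flash-attention commit pinned in Makefile-flash-att


def prefill_attention(qkv, cu_seqlen_prefill, max_s, softmax_scale):
    """Causal prefill attention over packed variable-length sequences.

    qkv: [total_tokens, 3, num_heads, head_size] packed query/key/value
    cu_seqlen_prefill: int32 cumulative sequence lengths, shape [batch + 1]
    """
    query = qkv[:, 0]
    attn_output = torch.empty_like(query)
    flash_attn_2_cuda.varlen_fwd(
        qkv[:, 0],          # query
        qkv[:, 1],          # key
        qkv[:, 2],          # value
        attn_output,        # pre-allocated output buffer
        cu_seqlen_prefill,  # cu_seqlens_q
        cu_seqlen_prefill,  # cu_seqlens_k
        max_s,              # max_seqlen_q
        max_s,              # max_seqlen_k
        0.0,                # dropout_p (inference)
        softmax_scale,
        False,              # zero_tensors
        True,               # causal
        False,              # return_softmax
        None,               # generator
    )
    return attn_output

Passing the same cumulative-sequence-length tensor for queries and keys reflects the prefill case, where queries and keys cover the same packed tokens; the decode path in these models goes through the vllm paged-attention kernels instead.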