diff --git a/router/src/infer.rs b/router/src/infer.rs index 1447e7567..0410de7df 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -70,7 +70,6 @@ impl Infer { tokenizer_config: HubTokenizerConfig, processor_config: HubProcessorConfig, ) -> Self { - // Infer shared state let queue = Queue::new(requires_padding, 16, window_size, speculate); let shared = Arc::new(Shared { batching_task: Notify::new(), diff --git a/server/Makefile-flash-att-v2 b/server/Makefile-flash-att-v2 index 36ef576ae..bbff00909 100644 --- a/server/Makefile-flash-att-v2 +++ b/server/Makefile-flash-att-v2 @@ -1,11 +1,11 @@ -flash_att_v2_commit_cuda := 23e8fa5a263d1c7122bc46a86ef32030ee7130f9 +flash_att_v2_commit_cuda := v2.5.8 flash_att_v2_commit_rocm := 2554f490101742ccdc56620a938f847f61754be6 flash-attention-v2-cuda: # Clone flash attention pip install -U packaging ninja --no-cache-dir - git clone https://github.com/HazyResearch/flash-attention.git flash-attention-v2 + git clone https://github.com/Dao-AILab/flash-attention.git flash-attention-v2 build-flash-attention-v2-cuda: flash-attention-v2-cuda cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit_cuda) diff --git a/server/text_generation_server/layers/attention/__init__.py b/server/text_generation_server/layers/attention/__init__.py new file mode 100644 index 000000000..e6cb4edff --- /dev/null +++ b/server/text_generation_server/layers/attention/__init__.py @@ -0,0 +1,13 @@ +from text_generation_server.utils.import_utils import SYSTEM +import os + +if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false": + raise ImportError("`USE_FLASH_ATTENTION` is false.") +if SYSTEM == "cuda": + from .cuda import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING +elif SYSTEM == "rocm": + from .rocm import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING +elif SYSTEM == "xpu": + from .xpu import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING +else: + raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention") diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py new file mode 100644 index 000000000..583337bdb --- /dev/null +++ b/server/text_generation_server/layers/attention/cuda.py @@ -0,0 +1,245 @@ +import torch +from text_generation_server.utils.import_utils import SYSTEM + +major, minor = torch.cuda.get_device_capability() +is_sm75 = major == 7 and minor == 5 +_PARTITION_SIZE = 512 + +try: + from vllm._C import cache_ops + from vllm._C import ops +except Exception as e: + raise ImportError( + f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}" + ) + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slots: torch.Tensor, +): + cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0) + + +def paged_attention( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + kv_head_mapping: torch.Tensor, + softmax_scale: float, + block_tables: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, +): + # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py + # Copyright 2023 The vLLM team. All rights + # reserved. 
+ # + # Licensed under the Apache License, Version 2.0 (the "License"); + # you may not use this file except in compliance with the License. + # You may obtain a copy of the License at + # + # http://www.apache.org/licenses/LICENSE-2.0 + # + # Unless required by applicable law or agreed to in writing, software + # distributed under the License is distributed on an "AS IS" BASIS, + # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + # See the License for the specific language governing permissions and + # limitations under the License. + # + + # value_cache => [num_blocks, num_heads, head_size, block_size] + block_size = value_cache.shape[3] + num_seqs, num_heads, head_size = query.shape + max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE + + # NOTE(woosuk): We use a simple heuristic to decide whether to use + # PagedAttention V1 or V2. If the number of partitions is 1, we use + # V1 to avoid the overhead of reduction. Also, if the number of + # sequences or heads is large, we use V1 since there is enough work + # to parallelize. + from vllm._C import ops + + use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512) + if use_v1: + ops.paged_attention_v1( + out, + query, + key_cache, + value_cache, + kv_head_mapping, + softmax_scale, + block_tables, + input_lengths, + block_size, + max_s, + None, + "auto", + 1.0, + ) + else: + # Run PagedAttention V2. + assert _PARTITION_SIZE % block_size == 0 + tmp_output = torch.empty( + size=(num_seqs, num_heads, max_num_partitions, head_size), + dtype=out.dtype, + device=out.device, + ) + exp_sums = torch.empty( + size=(num_seqs, num_heads, max_num_partitions), + dtype=torch.float32, + device=out.device, + ) + max_logits = torch.empty_like(exp_sums) + + ops.paged_attention_v2( + out, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + kv_head_mapping, + softmax_scale, + block_tables, + input_lengths, + block_size, + max_s, + None, + "auto", + 1.0, + ) + + +try: + import flash_attn_2_cuda + + V2 = True +except ImportError: + try: + import flash_attn_cuda + + V2 = False + except ImportError as e: + if major >= 8: + architecture_suffix = f"-{SYSTEM}" + raise ImportError( + "Flash Attention V2 is not installed.\n" + "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " + f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`" + ) + elif is_sm75: + raise ImportError( + "Flash Attention is not installed.\n" + "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " + "or install flash attention with `cd server && make install install-flash-attention`" + ) from e + else: + raise ImportError( + f"GPU with CUDA capability {major} {minor} is not supported" + ) from e + + +SUPPORTS_WINDOWING = V2 +if V2: + + def attention( + q, + k, + v, + out, + cu_seqlens, + max_s, + softmax_scale, + window_size_left=-1, + causal=True, + ): + if window_size_left <= 0 and window_size_left != -1: + raise ValueError("`window_size_left` must be > 0 or -1") + return flash_attn_2_cuda.varlen_fwd( + q, + k, + v, + out, + cu_seqlens, + cu_seqlens, + None, + None, + None, + max_s, + max_s, + 0.0, + softmax_scale, + False, + causal, + window_size_left, + 0, + False, + None, + ) + +else: + + def attention( + q, + k, + v, + out, + cu_seqlens, + max_s, + softmax_scale, + window_size_left=-1, + ): + if window_size_left != -1: + raise NotImplementedError( + "window_size_left is 
only available with flash attn v2" + ) + + # Flash attention v1 requires q, k and v to have the same number of heads + if k.shape[1] != q.shape[1]: + # MQA expand + if k.shape[1] == 1: + k = k.expand(-1, q.shape[1], -1) + # Grouped attention reshape + else: + original_shape = k.shape + k = ( + k.unsqueeze(2) + .expand(-1, -1, q.shape[1] // k.shape[1], -1) + .reshape(original_shape[0], -1, original_shape[2]) + ) + if v.shape[1] != q.shape[1]: + # MQA expand + if v.shape[1] == 1: + v = v.expand(-1, q.shape[1], -1) + # Grouped attention reshape + else: + original_shape = v.shape + v = ( + v.unsqueeze(2) + .expand(-1, -1, q.shape[1] // v.shape[1], -1) + .reshape(original_shape[0], -1, original_shape[2]) + ) + + return flash_attn_cuda.fwd( + q, + k, + v, + out, + cu_seqlens, + cu_seqlens, + max_s, + max_s, + 0.0, + softmax_scale, + False, + True, + False, + 0, + None, + ) diff --git a/server/text_generation_server/utils/flash_attn_triton.py b/server/text_generation_server/layers/attention/flash_attn_triton.py similarity index 100% rename from server/text_generation_server/utils/flash_attn_triton.py rename to server/text_generation_server/layers/attention/flash_attn_triton.py diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py new file mode 100644 index 000000000..2d3601c84 --- /dev/null +++ b/server/text_generation_server/layers/attention/rocm.py @@ -0,0 +1,295 @@ +import os +import torch +from text_generation_server.utils.import_utils import SYSTEM +from loguru import logger + +major, minor = torch.cuda.get_device_capability() +is_sm75 = major == 7 and minor == 5 +_PARTITION_SIZE = 512 + +use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"} +ENGINE = "triton" if use_triton else "ck" + +try: + from vllm._C import cache_ops + from vllm._C import ops +except Exception as e: + raise ImportError( + f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}" + ) + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slots: torch.Tensor, +): + cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0) + + +def paged_attention( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + kv_head_mapping: torch.Tensor, + softmax_scale: float, + block_tables: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, +): + # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py + # Copyright 2023 The vLLM team. All rights + # reserved. + # + # Licensed under the Apache License, Version 2.0 (the "License"); + # you may not use this file except in compliance with the License. + # You may obtain a copy of the License at + # + # http://www.apache.org/licenses/LICENSE-2.0 + # + # Unless required by applicable law or agreed to in writing, software + # distributed under the License is distributed on an "AS IS" BASIS, + # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + # See the License for the specific language governing permissions and + # limitations under the License. 
+ # + + # value_cache => [num_blocks, num_heads, head_size, block_size] + block_size = value_cache.shape[3] + num_seqs, num_heads, head_size = query.shape + max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE + + # NOTE(woosuk): We use a simple heuristic to decide whether to use + # PagedAttention V1 or V2. If the number of partitions is 1, we use + # V1 to avoid the overhead of reduction. Also, if the number of + # sequences or heads is large, we use V1 since there is enough work + # to parallelize. + from vllm._C import ops + + use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512) + if use_v1: + ops.paged_attention_v1( + out, + query, + key_cache, + value_cache, + kv_head_mapping, + softmax_scale, + block_tables, + input_lengths, + block_size, + max_s, + None, + "auto", + 1.0, + ) + else: + # Run PagedAttention V2. + assert _PARTITION_SIZE % block_size == 0 + tmp_output = torch.empty( + size=(num_seqs, num_heads, max_num_partitions, head_size), + dtype=out.dtype, + device=out.device, + ) + exp_sums = torch.empty( + size=(num_seqs, num_heads, max_num_partitions), + dtype=torch.float32, + device=out.device, + ) + max_logits = torch.empty_like(exp_sums) + + ops.paged_attention_v2( + out, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + kv_head_mapping, + softmax_scale, + block_tables, + input_lengths, + block_size, + max_s, + None, + "auto", + 1.0, + ) + + +if ENGINE != "triton": + try: + import flash_attn_2_cuda + + logger.info("ROCm: using Flash Attention 2 Composable Kernel implementation.") + except ImportError: + try: + import flash_attn_cuda + + ENGINE = "v1" + logger.info("ROCm: using Flash Attention 1") + except ImportError as e: + if major >= 8: + architecture_suffix = f"-{SYSTEM}" + raise ImportError( + "Flash Attention V2 is not installed.\n" + "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " + f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`" + ) + elif is_sm75: + raise ImportError( + "Flash Attention is not installed.\n" + "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " + "or install flash attention with `cd server && make install install-flash-attention`" + ) from e + else: + + for idx in range(torch.cuda.device_count()): + name = torch.cuda.get_device_name(idx) + if "MI210" not in name and "MI250" not in name: + raise ImportError( + f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention" + ) + raise ImportError( + f"AMD GPU with ROCm capability {major} {minor} is not supported" + ) from e + + +SUPPORTS_WINDOWING = ENGINE != "v1" +if ENGINE == "ck": + + def attention( + q, + k, + v, + out, + cu_seqlens, + max_s, + softmax_scale, + window_size_left=-1, + causal=True, + ): + if window_size_left <= 0 and window_size_left != -1: + raise ValueError("`window_size_left` must be > 0 or -1") + if window_size_left != -1: + raise ValueError( + f"ROCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})." 
) + + # RoCm flash API does not take the window_size_left and window_size_right arguments. + return flash_attn_2_cuda.varlen_fwd( + q, + k, + v, + out, + cu_seqlens, + cu_seqlens, + max_s, + max_s, + 0.0, + softmax_scale, + False, + causal, + False, + None, + ) + +elif ENGINE == "triton": + from .flash_attn_triton import triton_attention + + def attention( + q, + k, + v, + out, + cu_seqlens, + max_s, + softmax_scale, + window_size_left=-1, + causal=True, + ): + if window_size_left != -1: + raise ValueError( + f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})." + ) + output, _ = triton_attention( + q, + k, + v, + out, + cu_seqlens, + cu_seqlens, + max_s, + max_s, + causal, + softmax_scale, + ) + return output + +else: + + def attention( + q, + k, + v, + out, + cu_seqlens, + max_s, + softmax_scale, + window_size_left=-1, + ): + if window_size_left != -1: + raise NotImplementedError( + "window_size_left is only available with flash attn v2" + ) + + # Flash attention v1 requires q, k and v to have the same number of heads + if k.shape[1] != q.shape[1]: + # MQA expand + if k.shape[1] == 1: + k = k.expand(-1, q.shape[1], -1) + # Grouped attention reshape + else: + original_shape = k.shape + k = ( + k.unsqueeze(2) + .expand(-1, -1, q.shape[1] // k.shape[1], -1) + .reshape(original_shape[0], -1, original_shape[2]) + ) + if v.shape[1] != q.shape[1]: + # MQA expand + if v.shape[1] == 1: + v = v.expand(-1, q.shape[1], -1) + # Grouped attention reshape + else: + original_shape = v.shape + v = ( + v.unsqueeze(2) + .expand(-1, -1, q.shape[1] // v.shape[1], -1) + .reshape(original_shape[0], -1, original_shape[2]) + ) + + return flash_attn_cuda.fwd( + q, + k, + v, + out, + cu_seqlens, + cu_seqlens, + max_s, + max_s, + 0.0, + softmax_scale, + False, + True, + False, + 0, + None, + ) diff --git a/server/text_generation_server/layers/attention/xpu.py b/server/text_generation_server/layers/attention/xpu.py new file mode 100644 index 000000000..d9a096f96 --- /dev/null +++ b/server/text_generation_server/layers/attention/xpu.py @@ -0,0 +1,76 @@ +import intel_extension_for_pytorch as ipex +import torch + +SUPPORTS_WINDOWING = False + + +def attention( + q, + k, + v, + out, + cu_seqlens, + max_s, + softmax_scale, + window_size_left=-1, +): + if window_size_left != -1: + raise ValueError( + f"XPU version of Flash Attention does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
+ ) + return ipex.llm.functional.varlen_attention( + q, + k, + v, + out, + cu_seqlens, + cu_seqlens, + max_s, + max_s, + 0.0, + softmax_scale, + False, + True, + False, + None, + ) + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slots: torch.Tensor, +): + ipex.llm.modules.PagedAttention.reshape_and_cache( + key, value, key_cache, value_cache, slots + ) + + +def paged_attention( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + kv_head_mapping: torch.Tensor, + softmax_scale: float, + block_tables: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, +): + query = query.contiguous() + block_size = value_cache.shape[3] + return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention( + out, + query, + key_cache, + value_cache, + kv_head_mapping, + softmax_scale, + block_tables, + input_lengths, + block_size, + max_s, + None, + ) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index d086f87b5..dbe490395 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -80,15 +80,11 @@ try: from text_generation_server.models.flash_phi import FlashPhi from text_generation_server.models.flash_starcoder2 import FlashStarcoder2 from text_generation_server.models.flash_dbrx import FlashDbrx - from text_generation_server.utils.flash_attn import ( - HAS_FLASH_ATTN_V2_CUDA, - HAS_FLASH_ATTN_V2_ROCM, - ) + from text_generation_server.layers.attention import SUPPORTS_WINDOWING except ImportError as e: logger.warning(f"Could not import Flash Attention enabled models: {e}") + SUPPORTS_WINDOWING = False FLASH_ATTENTION = False - HAS_FLASH_ATTN_V2_CUDA = False - HAS_FLASH_ATTN_V2_ROCM = False if FLASH_ATTENTION: __all__.append(FlashGPT2) @@ -262,6 +258,7 @@ def get_model( dtype: Optional[str], trust_remote_code: bool, ) -> Model: + global FLASH_ATTENTION if dtype is None: if quantize in ["awq", "exl2", "gptq"]: # These quantizers only work with float16 params. 
@@ -412,6 +409,12 @@ def get_model( raise RuntimeError( "Sharding is currently not supported with `exl2` quantization" ) + sliding_window = config_dict.get("sliding_window", -1) + if sliding_window != -1 and not SUPPORTS_WINDOWING: + logger.warning( + f"Flash attention is available, but doesn't support windowing which is required by model {model_id}" + ) + FLASH_ATTENTION = False if model_type == MAMBA: return Mamba( @@ -699,11 +702,7 @@ def get_model( if model_type == MISTRAL: sliding_window = config_dict.get("sliding_window", -1) - if ( - ((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION) - or HAS_FLASH_ATTN_V2_CUDA - or HAS_FLASH_ATTN_V2_ROCM - ): + if FLASH_ATTENTION: return FlashMistral( model_id, revision, @@ -726,11 +725,7 @@ def get_model( if model_type == MIXTRAL: sliding_window = config_dict.get("sliding_window", -1) - if ( - ((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION) - or HAS_FLASH_ATTN_V2_CUDA - or HAS_FLASH_ATTN_V2_ROCM - ): + if FLASH_ATTENTION: return FlashMixtral( model_id, revision, @@ -753,11 +748,7 @@ def get_model( if model_type == STARCODER2: sliding_window = config_dict.get("sliding_window", -1) - if ( - ((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION) - or HAS_FLASH_ATTN_V2_CUDA - or HAS_FLASH_ATTN_V2_ROCM - ): + if FLASH_ATTENTION: return FlashStarcoder2( model_id, revision, @@ -781,11 +772,7 @@ def get_model( if model_type == QWEN2: sliding_window = config_dict.get("sliding_window", -1) - if ( - ((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION) - or HAS_FLASH_ATTN_V2_CUDA - or HAS_FLASH_ATTN_V2_ROCM - ): + if FLASH_ATTENTION: return FlashQwen2( model_id, revision, diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index bd8b80160..31109bc97 100644 --- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -25,7 +25,11 @@ from torch import nn from transformers.activations import ACT2FN from typing import Optional, List, Tuple -from text_generation_server.utils import paged_attention, flash_attn +from text_generation_server.layers.attention import ( + paged_attention, + attention, + reshape_and_cache, +) from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers import ( TensorParallelRowLinear, @@ -281,7 +285,7 @@ class FlashCohereAttention(torch.nn.Module): self.rotary_emb(query, key, cos, sin) - paged_attention.reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots) + reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots) # output tensor attn_output = torch.empty_like(query) @@ -289,7 +293,7 @@ class FlashCohereAttention(torch.nn.Module): # Prefill if cu_seqlen_prefill is not None: # flash attention - flash_attn.attention( + attention( query, key, value, @@ -300,7 +304,7 @@ class FlashCohereAttention(torch.nn.Module): ) # Decode else: - paged_attention.attention( + paged_attention( attn_output, query, kv_cache[0], diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index 56bfb9d02..497956e32 100644 --- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++
b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -27,7 +27,11 @@ from text_generation_server.utils.import_utils import SYSTEM if SYSTEM != "xpu": from vllm.model_executor.layers.fused_moe import fused_moe -from text_generation_server.utils import paged_attention, flash_attn +from text_generation_server.layers.attention import ( + paged_attention, + attention, + reshape_and_cache, +) from text_generation_server.layers import ( FastLinear, TensorParallelRowLinear, @@ -424,9 +428,7 @@ class DbrxAttention(torch.nn.Module): self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - paged_attention.reshape_and_cache( - kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots - ) + reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) # output tensor attn_output = torch.empty_like(query) @@ -434,7 +436,7 @@ class DbrxAttention(torch.nn.Module): # Prefill if cu_seqlen_prefill is not None: # flash attention - flash_attn.attention( + attention( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), @@ -445,7 +447,7 @@ class DbrxAttention(torch.nn.Module): ) # Decode else: - paged_attention.attention( + paged_attention( attn_output, query, kv_cache[0], diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index cff4b5d53..89ca8b5b3 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -26,7 +26,11 @@ from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple -from text_generation_server.utils import paged_attention, flash_attn +from text_generation_server.layers.attention import ( + paged_attention, + attention, + reshape_and_cache, +) from text_generation_server.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, @@ -221,9 +225,7 @@ class FlashGemmaAttention(torch.nn.Module): self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - paged_attention.reshape_and_cache( - kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots - ) + reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) # output tensor attn_output = torch.empty_like(query) @@ -231,7 +233,7 @@ class FlashGemmaAttention(torch.nn.Module): # Prefill if cu_seqlen_prefill is not None: # flash attention - flash_attn.attention( + attention( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), @@ -243,7 +245,7 @@ class FlashGemmaAttention(torch.nn.Module): ) # Decode else: - paged_attention.attention( + paged_attention( attn_output, query, kv_cache[0], diff --git a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py index d2599f7a4..52a7c283a 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py @@ -25,7 +25,11 @@ from torch import nn from transformers.activations import ACT2FN from typing import Optional, List, Tuple -from text_generation_server.utils import paged_attention, flash_attn +from text_generation_server.layers.attention import ( + paged_attention, + attention, + reshape_and_cache, +) from text_generation_server.layers import ( TensorParallelRowLinear, 
TensorParallelColumnLinear, @@ -213,7 +217,7 @@ class FlashGPT2Attention(torch.nn.Module): key = key.view(-1, self.num_heads, self.head_size) value = value.view(-1, self.num_heads, self.head_size) - paged_attention.reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots) + reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots) # output tensor attn_output = torch.empty_like(query) @@ -221,7 +225,7 @@ class FlashGPT2Attention(torch.nn.Module): # Prefill if cu_seqlen_prefill is not None: # flash attention - flash_attn.attention( + attention( query, key, value, @@ -232,7 +236,7 @@ class FlashGPT2Attention(torch.nn.Module): ) # Decode else: - paged_attention.attention( + paged_attention( attn_output, query, kv_cache[0], diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index fa3a78f84..c0fa09fd3 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -28,7 +28,11 @@ from transformers.activations import ACT2FN from typing import Optional, List, Tuple from text_generation_server.utils.import_utils import SYSTEM -from text_generation_server.utils import paged_attention, flash_attn +from text_generation_server.layers.attention import ( + paged_attention, + attention, + reshape_and_cache, +) from text_generation_server.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, @@ -145,9 +149,7 @@ class FlashLlamaAttention(torch.nn.Module): self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - paged_attention.reshape_and_cache( - kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots - ) + reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) # output tensor attn_output = torch.empty_like(query) @@ -155,7 +157,7 @@ class FlashLlamaAttention(torch.nn.Module): # Prefill if cu_seqlen_prefill is not None: # flash attention - flash_attn.attention( + attention( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), @@ -166,7 +168,7 @@ class FlashLlamaAttention(torch.nn.Module): ) # Decode else: - paged_attention.attention( + paged_attention( attn_output, query, kv_cache[0], diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index 65043dee2..77a8a3845 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -27,7 +27,11 @@ from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple from text_generation_server.utils.import_utils import SYSTEM -from text_generation_server.utils import paged_attention, flash_attn +from text_generation_server.layers.attention import ( + paged_attention, + attention, + reshape_and_cache, +) from text_generation_server.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, @@ -186,7 +190,7 @@ class MistralAttention(torch.nn.Module): else: kv_to_cache = kv - paged_attention.reshape_and_cache( + reshape_and_cache( kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots ) @@ -196,7 +200,7 @@ class MistralAttention(torch.nn.Module): # Prefill if cu_seqlen_prefill is not None: # flash attention - flash_attn.attention( + attention( query, torch.select(kv, dim=1, 
index=0), torch.select(kv, dim=1, index=1), @@ -208,7 +212,7 @@ class MistralAttention(torch.nn.Module): ) # Decode else: - paged_attention.attention( + paged_attention( attn_output, query, kv_cache[0], diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index be2d6c451..37cd6f3b7 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -33,7 +33,11 @@ from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple from loguru import logger -from text_generation_server.utils import paged_attention, flash_attn +from text_generation_server.layers.attention import ( + paged_attention, + attention, + reshape_and_cache, +) from text_generation_server.layers import ( FastLinear, TensorParallelRowLinear, @@ -265,7 +269,7 @@ class MixtralAttention(torch.nn.Module): else: kv_to_cache = kv - paged_attention.reshape_and_cache( + reshape_and_cache( kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots ) @@ -275,7 +279,7 @@ class MixtralAttention(torch.nn.Module): # Prefill if cu_seqlen_prefill is not None: # flash attention - flash_attn.attention( + attention( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), @@ -287,7 +291,7 @@ class MixtralAttention(torch.nn.Module): ) # Decode else: - paged_attention.attention( + paged_attention( attn_output, query, kv_cache[0], diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index d45cab2e0..59e7bf8b2 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -27,8 +27,11 @@ from transformers.modeling_utils import PreTrainedModel from transformers.models.gpt_neox import GPTNeoXConfig from typing import Optional, List, Tuple -from text_generation_server.utils import paged_attention, flash_attn -from text_generation_server.utils.flash_attn import attention +from text_generation_server.layers.attention import ( + paged_attention, + attention, + reshape_and_cache, +) from text_generation_server.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, @@ -146,9 +149,7 @@ class FlashNeoxAttention(torch.nn.Module): # Inplace rotary self.rotary_emb(qkv[:, 0], qkv[:, 1], cos, sin) - paged_attention.reshape_and_cache( - qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots - ) + reshape_and_cache(qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots) # output tensor attn_output = torch.empty_like(qkv[:, 0]) @@ -156,7 +157,7 @@ class FlashNeoxAttention(torch.nn.Module): # Prefill if cu_seqlen_prefill is not None: # flash attention - flash_attn.attention( + attention( qkv[:, 0], qkv[:, 1], qkv[:, 2], @@ -167,7 +168,7 @@ class FlashNeoxAttention(torch.nn.Module): ) # Decode else: - paged_attention.attention( + paged_attention( attn_output, qkv[:, 0], kv_cache[0], diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index f2efb5386..af3206dd7 100644 --- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py 
@@ -6,7 +6,11 @@ from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple -from text_generation_server.utils import paged_attention, flash_attn +from text_generation_server.layers.attention import ( + paged_attention, + attention, + reshape_and_cache, +) from text_generation_server.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, @@ -185,16 +189,14 @@ class FlashPhiAttention(torch.nn.Module): ) # Reshape key and value and cache - paged_attention.reshape_and_cache( - kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots - ) + reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) # output tensor attn_output = torch.empty_like(query) # Prefill if cu_seqlen_prefill is not None: - flash_attn.attention( + attention( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), @@ -205,7 +207,7 @@ class FlashPhiAttention(torch.nn.Module): ) # Decode else: - paged_attention.attention( + paged_attention( attn_output, query, kv_cache[0], diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index 3a6d2db52..2b035c2e7 100644 --- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -5,7 +5,11 @@ from torch import nn from transformers.activations import ACT2FN from typing import Optional, List, Tuple -from text_generation_server.utils import paged_attention, flash_attn +from text_generation_server.layers.attention import ( + paged_attention, + attention, + reshape_and_cache, +) from text_generation_server.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, @@ -142,7 +146,7 @@ class Qwen2Attention(torch.nn.Module): else: kv_to_cache = kv - paged_attention.reshape_and_cache( + reshape_and_cache( kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots ) @@ -152,7 +156,7 @@ class Qwen2Attention(torch.nn.Module): # Prefill if cu_seqlen_prefill is not None: # flash attention - flash_attn.attention( + attention( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), @@ -164,7 +168,7 @@ class Qwen2Attention(torch.nn.Module): ) # Decode else: - paged_attention.attention( + paged_attention( attn_output, query, kv_cache[0], diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index fa463a195..d489c3ba2 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -15,7 +15,11 @@ from text_generation_server.layers import ( ) from text_generation_server.layers.layernorm import FastLayerNorm from text_generation_server.layers.rotary import PositionRotaryEmbedding -from text_generation_server.utils import flash_attn, paged_attention +from text_generation_server.layers.attention import ( + attention, + paged_attention, + reshape_and_cache, +) def load_row(config, prefix: str, weights, bias: bool): @@ -194,9 +198,7 @@ class FlashRWAttention(torch.nn.Module): # Inplace rotary self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - paged_attention.reshape_and_cache( - kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots - ) + reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], 
kv_cache[1], slots) # output attn_output = torch.empty_like(query) @@ -204,7 +206,7 @@ class FlashRWAttention(torch.nn.Module): # Prefill if cu_seqlen_prefill is not None: # flash attention - flash_attn.attention( + attention( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), @@ -215,7 +217,7 @@ class FlashRWAttention(torch.nn.Module): ) # Decode else: - paged_attention.attention( + paged_attention( attn_output, query, kv_cache[0], @@ -313,7 +315,7 @@ class FlashRWLargeAttention(torch.nn.Module): # Inplace rotary self.rotary_emb(query, torch.select(kv, dim=2, index=0), cos, sin) - paged_attention.reshape_and_cache( + reshape_and_cache( kv[:, :, 0].contiguous(), kv[:, :, 1].contiguous(), kv_cache[0], @@ -327,7 +329,7 @@ class FlashRWLargeAttention(torch.nn.Module): # Prefill if cu_seqlen_prefill is not None: # flash attention - flash_attn.attention( + attention( query, torch.select(kv, dim=2, index=0), torch.select(kv, dim=2, index=1), @@ -338,7 +340,7 @@ class FlashRWLargeAttention(torch.nn.Module): ) # Decode else: - paged_attention.attention( + paged_attention( attn_output, query, kv_cache[0], diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index cfa4243f6..c8397000e 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -6,7 +6,11 @@ from transformers.activations import ACT2FN from typing import Optional, List, Tuple from text_generation_server.layers.gptq import GPTQWeight -from text_generation_server.utils import paged_attention, flash_attn +from text_generation_server.layers.attention import ( + paged_attention, + attention, + reshape_and_cache, +) from text_generation_server.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, @@ -276,7 +280,7 @@ class FlashMQAttention(torch.nn.Module): query = query.view(-1, self.num_heads, self.head_size) key_value = key_value.view(-1, 2, 1, self.head_size) - paged_attention.reshape_and_cache( + reshape_and_cache( key_value[:, 0], key_value[:, 1], kv_cache[0], kv_cache[1], slots ) @@ -286,7 +290,7 @@ class FlashMQAttention(torch.nn.Module): # Prefill if cu_seqlen_prefill is not None: # flash attention - flash_attn.attention( + attention( query, torch.select(key_value, dim=1, index=0), torch.select(key_value, dim=1, index=1), @@ -297,7 +301,7 @@ class FlashMQAttention(torch.nn.Module): ) # Decode else: - paged_attention.attention( + paged_attention( attn_output, query, kv_cache[0], diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index 3e2ce4f97..37486e9db 100644 --- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -26,7 +26,11 @@ from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple -from text_generation_server.utils import paged_attention, flash_attn +from text_generation_server.layers.attention import ( + paged_attention, + attention, + reshape_and_cache, +) from text_generation_server.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, @@ -229,7 +233,7 @@ class 
Starcoder2Attention(torch.nn.Module): else: kv_to_cache = kv - paged_attention.reshape_and_cache( + reshape_and_cache( kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots ) @@ -239,7 +243,7 @@ class Starcoder2Attention(torch.nn.Module): # Prefill if cu_seqlen_prefill is not None: # flash attention - flash_attn.attention( + attention( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), @@ -251,7 +255,7 @@ class Starcoder2Attention(torch.nn.Module): ) # Decode else: - paged_attention.attention( + paged_attention( attn_output, query, kv_cache[0], diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index e8a119581..11a9f030f 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -1,5 +1,6 @@ import torch import os +from loguru import logger MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None # This is overridden by the cli diff --git a/server/text_generation_server/utils/flash_attn.py b/server/text_generation_server/utils/flash_attn.py deleted file mode 100644 index 4f5cf10b6..000000000 --- a/server/text_generation_server/utils/flash_attn.py +++ /dev/null @@ -1,293 +0,0 @@ -import os -import torch - -from loguru import logger -import math - -from text_generation_server.utils.import_utils import SYSTEM - -if SYSTEM != "xpu": - from text_generation_server.utils.flash_attn_triton import triton_attention - -if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false": - raise ImportError("`USE_FLASH_ATTENTION` is false.") -HAS_FLASH_ATTN = False -HAS_FLASH_ATTN_V2_CUDA = False -HAS_FLASH_ATTN_V2_ROCM = False -ROCM_USE_FLASH_ATTN_V2_CK = False -ROCM_USE_FLASH_ATTN_V2_TRITON = False - - -if SYSTEM in {"cuda", "rocm"}: - if not torch.cuda.is_available(): - raise ImportError("CUDA is not available") - - major, minor = torch.cuda.get_device_capability() - is_sm75 = major == 7 and minor == 5 - is_sm8x = major == 8 and minor >= 0 - is_sm90 = major == 9 and minor == 0 - is_sm94 = major == 9 and minor == 4 - - if SYSTEM == "rocm": - if ( - os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() == "true" - or os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "0") == "1" - ): - ROCM_USE_FLASH_ATTN_V2_TRITON = True - logger.info("ROCm: using Flash Attention 2 Triton implementation.") - else: - ROCM_USE_FLASH_ATTN_V2_CK = True - logger.info( - "ROCm: using Flash Attention 2 Composable Kernel implementation." 
- ) - - try: - try: - import flash_attn_2_cuda - except ImportError: - architecture_suffix = f"-{SYSTEM}" - raise ImportError( - "Flash Attention V2 is not installed.\n" - "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " - f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`" - ) - if SYSTEM == "cuda" and not (is_sm8x or is_sm90): - raise ImportError( - f"GPU with CUDA capability {major} {minor} is not supported for " - "Flash Attention V2" - ) - elif SYSTEM == "rocm" and not (is_sm8x or is_sm90 or is_sm94): - raise ImportError( - f"AMD GPU with compute capability {major} {minor} is not supported for " - "Flash Attention V2" - ) - HAS_FLASH_ATTN_V2_CUDA = SYSTEM == "cuda" - HAS_FLASH_ATTN_V2_ROCM = SYSTEM == "rocm" - except ImportError as e: - try: - import flash_attn_cuda - except ImportError: - raise ImportError( - "Flash Attention is not installed.\n" - "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " - "or install flash attention with `cd server && make install install-flash-attention`" - ) from e - - if SYSTEM == "cuda" and not (is_sm75 or is_sm8x or is_sm90): - raise ImportError( - f"GPU with CUDA capability {major} {minor} is not supported" - ) from e - elif SYSTEM == "rocm": - for idx in range(torch.cuda.device_count()): - if "MI210" not in torch.cuda.get_device_name( - idx - ) and "MI250" not in torch.cuda.get_device_name(idx): - raise ImportError( - f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention" - ) - - logger.warning(f"Unable to use Flash Attention V2: {e}") - HAS_FLASH_ATTN = True - -if SYSTEM == "xpu": - import intel_extension_for_pytorch as ipex - - def attention( - q, - k, - v, - out, - cu_seqlens, - max_s, - softmax_scale, - window_size_left=-1, - ): - if window_size_left <= 0 and window_size_left != -1: - raise ValueError("`window_size_left` must be > 0 or -1") - - if window_size_left != -1: - raise ValueError( - f"XPU version of Flash Attention does not support window attention (window_size_left != -1, got window_size_left={window_size_left})." - ) - return ipex.llm.functional.varlen_attention( - q, - k, - v, - out, - cu_seqlens, - cu_seqlens, - max_s, - max_s, - 0.0, - softmax_scale, - False, - True, - False, - None, - ) - -elif HAS_FLASH_ATTN_V2_CUDA: - - def attention( - q, - k, - v, - out, - cu_seqlens, - max_s, - softmax_scale, - window_size_left=-1, - causal=True, - ): - if window_size_left <= 0 and window_size_left != -1: - raise ValueError("`window_size_left` must be > 0 or -1") - return flash_attn_2_cuda.varlen_fwd( - q, - k, - v, - out, - cu_seqlens, - cu_seqlens, - None, - None, - None, - max_s, - max_s, - 0.0, - softmax_scale, - False, - causal, - window_size_left, - 0, - False, - None, - ) - -elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_CK: - - def attention( - q, - k, - v, - out, - cu_seqlens, - max_s, - softmax_scale, - window_size_left=-1, - causal=True, - ): - if window_size_left <= 0 and window_size_left != -1: - raise ValueError("`window_size_left` must be > 0 or -1") - if window_size_left != -1: - raise ValueError( - f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})." - ) - - # RoCm flash API does not take the window_size_left and window_size_right arguments. 
- return flash_attn_2_cuda.varlen_fwd( - q, - k, - v, - out, - cu_seqlens, - cu_seqlens, - max_s, - max_s, - 0.0, - softmax_scale, - False, - causal, - False, - None, - ) - -elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_TRITON: - - def attention( - q, - k, - v, - out, - cu_seqlens, - max_s, - softmax_scale, - window_size_left=-1, - causal=True, - ): - output, _ = triton_attention( - q, - k, - v, - out, - cu_seqlens, - cu_seqlens, - max_s, - max_s, - causal, - softmax_scale, - ) - return output - -elif HAS_FLASH_ATTN: - - def attention( - q, - k, - v, - out, - cu_seqlens, - max_s, - softmax_scale, - window_size_left=-1, - ): - if window_size_left != -1: - raise NotImplementedError( - "window_size_left is only available with flash attn v2" - ) - - # Flash attention v1 requires q, k and v to have the same number of heads - if k.shape[1] != q.shape[1]: - # MQA expand - if k.shape[1] == 1: - k = k.expand(-1, q.shape[1], -1) - # Grouped attention reshape - else: - original_shape = k.shape - k = ( - k.unsqueeze(2) - .expand(-1, -1, q.shape[1] // k.shape[1], -1) - .reshape(original_shape[0], -1, original_shape[2]) - ) - if v.shape[1] != q.shape[1]: - # MQA expand - if v.shape[1] == 1: - v = v.expand(-1, q.shape[1], -1) - # Grouped attention reshape - else: - original_shape = v.shape - v = ( - v.unsqueeze(2) - .expand(-1, -1, q.shape[1] // v.shape[1], -1) - .reshape(original_shape[0], -1, original_shape[2]) - ) - - return flash_attn_cuda.fwd( - q, - k, - v, - out, - cu_seqlens, - cu_seqlens, - max_s, - max_s, - 0.0, - softmax_scale, - False, - True, - False, - 0, - None, - ) - -else: - raise NotImplementedError("flash attention is not installed") diff --git a/server/text_generation_server/utils/import_utils.py b/server/text_generation_server/utils/import_utils.py index 40e576460..d79e36c22 100644 --- a/server/text_generation_server/utils/import_utils.py +++ b/server/text_generation_server/utils/import_utils.py @@ -1,4 +1,5 @@ import torch +from loguru import logger def is_xpu_available(): @@ -48,3 +49,4 @@ else: empty_cache = noop synchronize = noop get_free_memory = noop +logger.info(f"Detected system {SYSTEM}") diff --git a/server/text_generation_server/utils/paged_attention.py b/server/text_generation_server/utils/paged_attention.py deleted file mode 100644 index 6cc30e6d5..000000000 --- a/server/text_generation_server/utils/paged_attention.py +++ /dev/null @@ -1,137 +0,0 @@ -import torch -from text_generation_server.utils.import_utils import SYSTEM - -_PARTITION_SIZE = 512 - -if SYSTEM == "xpu": - import intel_extension_for_pytorch as ipex -else: - try: - from vllm._C import cache_ops - from vllm._C import ops - except Exception as e: - raise ImportError( - f"Could not import vllm paged attention. Make sure your installation is correct. 
Complete error: {e}" - ) - - -def reshape_and_cache( - key: torch.Tensor, - value: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - slots: torch.Tensor, -): - if SYSTEM == "xpu": - ipex.llm.modules.PagedAttention.reshape_and_cache( - key, value, key_cache, value_cache, slots - ) - else: - cache_ops.reshape_and_cache( - key, value, key_cache, value_cache, slots, "auto", 1.0 - ) - - -def attention( - out: torch.Tensor, - query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - kv_head_mapping: torch.Tensor, - softmax_scale: float, - block_tables: torch.Tensor, - input_lengths: torch.Tensor, - max_s: int, -): - # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py - # Copyright 2023 The vLLM team. All rights - # reserved. - # - # Licensed under the Apache License, Version 2.0 (the "License"); - # you may not use this file except in compliance with the License. - # You may obtain a copy of the License at - # - # http://www.apache.org/licenses/LICENSE-2.0 - # - # Unless required by applicable law or agreed to in writing, software - # distributed under the License is distributed on an "AS IS" BASIS, - # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - # See the License for the specific language governing permissions and - # limitations under the License. - # - - # value_cache => [num_blocks, num_heads, head_size, block_size] - block_size = value_cache.shape[3] - num_seqs, num_heads, head_size = query.shape - max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE - if SYSTEM == "xpu": - query = query.contiguous() - return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention( - out, - query, - key_cache, - value_cache, - kv_head_mapping, - softmax_scale, - block_tables, - input_lengths, - block_size, - max_s, - None, - ) - - # NOTE(woosuk): We use a simple heuristic to decide whether to use - # PagedAttention V1 or V2. If the number of partitions is 1, we use - # V1 to avoid the overhead of reduction. Also, if the number of - # sequences or heads is large, we use V1 since there is enough work - # to parallelize. - use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512) - if use_v1: - ops.paged_attention_v1( - out, - query, - key_cache, - value_cache, - kv_head_mapping, - softmax_scale, - block_tables, - input_lengths, - block_size, - max_s, - None, - "auto", - 1.0, - ) - else: - # Run PagedAttention V2. - assert _PARTITION_SIZE % block_size == 0 - tmp_output = torch.empty( - size=(num_seqs, num_heads, max_num_partitions, head_size), - dtype=out.dtype, - device=out.device, - ) - exp_sums = torch.empty( - size=(num_seqs, num_heads, max_num_partitions), - dtype=torch.float32, - device=out.device, - ) - max_logits = torch.empty_like(exp_sums) - - ops.paged_attention_v2( - out, - exp_sums, - max_logits, - tmp_output, - query, - key_cache, - value_cache, - kv_head_mapping, - softmax_scale, - block_tables, - input_lengths, - block_size, - max_s, - None, - "auto", - 1.0, - )
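
Note on the refactored interface: after this diff, every flash model imports the same three hardware-agnostic entry points from text_generation_server.layers.attention, and the per-hardware modules (cuda.py, rocm.py, xpu.py) provide matching implementations. Below is a minimal sketch of the prefill/decode call pattern the modeling files above follow; the tensor names (query, kv, kv_cache, slots, block_tables, input_lengths, kv_head_mapping) and the helper name forward_attention are illustrative only, not identifiers taken verbatim from any one model file.

import torch
from text_generation_server.layers.attention import (
    attention,
    paged_attention,
    reshape_and_cache,
)


def forward_attention(
    query, kv, kv_cache, slots, cu_seqlen_prefill,
    block_tables, input_lengths, max_s, softmax_scale, kv_head_mapping,
):
    # Write the freshly computed key/value pairs into their paged KV-cache slots.
    reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)

    # Output tensor shared by both branches.
    attn_output = torch.empty_like(query)

    if cu_seqlen_prefill is not None:
        # Prefill: flash attention over the whole (variable-length) prompt batch.
        attention(
            query, kv[:, 0], kv[:, 1], attn_output,
            cu_seqlen_prefill, max_s, softmax_scale,
        )
    else:
        # Decode: one query token per sequence, keys/values read from the paged cache.
        paged_attention(
            attn_output, query, kv_cache[0], kv_cache[1], kv_head_mapping,
            softmax_scale, block_tables, input_lengths, max_s,
        )
    return attn_output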