From 12590fdccebb34f39fb85b7dae29b80fade2b6b0 Mon Sep 17 00:00:00 2001
From: OlivierDehaene
Date: Mon, 23 Oct 2023 12:29:25 +0200
Subject: [PATCH] feat: paged attention v2 (#1183)

---
 server/Makefile-flash-att-v2                  |   2 +-
 server/Makefile-vllm                          |   4 +-
 .../custom_modeling/flash_llama_modeling.py   |  15 +--
 .../custom_modeling/flash_mistral_modeling.py |  14 +--
 .../custom_modeling/flash_neox_modeling.py    |  14 +--
 .../custom_modeling/flash_rw_modeling.py      |  23 ++--
 .../flash_santacoder_modeling.py              |  15 +--
 .../utils/paged_attention.py                  | 100 ++++++++++++++++++
 8 files changed, 126 insertions(+), 61 deletions(-)
 create mode 100644 server/text_generation_server/utils/paged_attention.py

diff --git a/server/Makefile-flash-att-v2 b/server/Makefile-flash-att-v2
index cdea84311..583437b2e 100644
--- a/server/Makefile-flash-att-v2
+++ b/server/Makefile-flash-att-v2
@@ -1,4 +1,4 @@
-flash_att_v2_commit := 601b4dc48dbe9d87c468daa2b4c0c8388b83753c
+flash_att_v2_commit := 02ac572f3ffc4f402e4183aaa6824b45859d3ed3
 
 flash-attention-v2:
 	# Clone flash attention
diff --git a/server/Makefile-vllm b/server/Makefile-vllm
index 2e965da01..c601e452a 100644
--- a/server/Makefile-vllm
+++ b/server/Makefile-vllm
@@ -1,8 +1,8 @@
-vllm_commit := 25dbff97d5a8f2ba331847237b458b2692e9ae78
+vllm_commit := f8a1e39fae05ca610be8d5a78be9d40f5274e5fc
 
 vllm:
 	# Clone vllm
-	git clone https://github.com/OlivierDehaene/vllm.git
+	git clone https://github.com/vllm-project/vllm.git
 
 build-vllm: vllm
 	cd vllm && git fetch && git checkout $(vllm_commit)
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 7c743a880..69608e1c5 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -29,11 +29,7 @@ from typing import Optional, List, Tuple
 # Flash attention imports
 import dropout_layer_norm
 
-# vllm imports
-import vllm_cache_ops
-import vllm_attention_ops
-
-from text_generation_server.utils.flash_attn import attention
+from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
@@ -269,7 +265,7 @@ class FlashLlamaAttention(torch.nn.Module):
         self.rotary_emb(query, cos, sin)
         self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
 
-        vllm_cache_ops.reshape_and_cache(
+        paged_attention.reshape_and_cache(
             kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
         )
 
@@ -279,7 +275,7 @@ class FlashLlamaAttention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            flash_attn.attention(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
@@ -290,9 +286,7 @@ class FlashLlamaAttention(torch.nn.Module):
             )
         # Decode
         else:
-            # kv_cache[1] => [num_blocks, num_heads, head_size, block_size]
-            block_size = kv_cache[1].shape[3]
-            vllm_attention_ops.single_query_cached_kv_attention(
+            paged_attention.attention(
                 attn_output,
                 query,
                 kv_cache[0],
@@ -301,7 +295,6 @@ class FlashLlamaAttention(torch.nn.Module):
                 self.softmax_scale,
                 block_tables,
                 input_lengths,
-                block_size,
                 max_s,
             )
 
diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
index 77b7f230a..2d731406b 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
@@ -29,10 +29,7 @@ from typing import Optional, List, Tuple
 # Flash attention imports
 import dropout_layer_norm
 
-# vllm imports
-import vllm_cache_ops
-import vllm_attention_ops
-
+from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.utils.flash_attn import attention, HAS_FLASH_ATTN_V2
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
@@ -272,7 +269,7 @@ class MistralAttention(torch.nn.Module):
         else:
             kv_to_cache = kv
 
-        vllm_cache_ops.reshape_and_cache(
+        paged_attention.reshape_and_cache(
             kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
         )
 
@@ -282,7 +279,7 @@ class MistralAttention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            flash_attn.attention(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
@@ -294,9 +291,7 @@ class MistralAttention(torch.nn.Module):
             )
         # Decode
         else:
-            # kv_cache[1] => [num_blocks, num_heads, head_size, block_size]
-            block_size = kv_cache[1].shape[3]
-            vllm_attention_ops.single_query_cached_kv_attention(
+            paged_attention.attention(
                 attn_output,
                 query,
                 kv_cache[0],
@@ -305,7 +300,6 @@ class MistralAttention(torch.nn.Module):
                 self.softmax_scale,
                 block_tables,
                 input_lengths,
-                block_size,
                 max_s,
             )
 
diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
index 9dc374dfc..af4ba96b9 100644
--- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
@@ -27,10 +27,7 @@ from transformers.modeling_utils import PreTrainedModel
 from transformers.models.gpt_neox import GPTNeoXConfig
 from typing import Optional, List, Tuple
 
-# vllm imports
-import vllm_cache_ops
-import vllm_attention_ops
-
+from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.utils.flash_attn import attention
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
@@ -141,7 +138,7 @@ class FlashNeoxAttention(torch.nn.Module):
         self.rotary_emb(qkv[:, 0], cos, sin)
         self.rotary_emb(qkv[:, 1], cos, sin)
 
-        vllm_cache_ops.reshape_and_cache(
+        paged_attention.reshape_and_cache(
             qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots
         )
 
@@ -151,7 +148,7 @@ class FlashNeoxAttention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            flash_attn.attention(
                 qkv[:, 0],
                 qkv[:, 1],
                 qkv[:, 2],
@@ -162,9 +159,7 @@ class FlashNeoxAttention(torch.nn.Module):
             )
         # Decode
         else:
-            # kv_cache[1] => [num_blocks, num_heads, head_size, block_size]
-            block_size = kv_cache[1].shape[3]
-            vllm_attention_ops.single_query_cached_kv_attention(
+            paged_attention.attention(
                 attn_output,
                 qkv[:, 0],
                 kv_cache[0],
@@ -173,7 +168,6 @@ class FlashNeoxAttention(torch.nn.Module):
                 self.softmax_scale,
                 block_tables,
                 input_lengths,
-                block_size,
                 max_s,
             )
 
diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
index 8419fa4f4..00f953a64 100644
--- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -6,10 +6,7 @@ from transformers.modeling_utils import PreTrainedModel
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
 
-# vllm imports
-import vllm_cache_ops
-import vllm_attention_ops
-
+from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.utils.flash_attn import attention
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
@@ -191,7 +188,7 @@ class FlashRWAttention(torch.nn.Module):
         self.rotary_emb(query, cos, sin)
         self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
 
-        vllm_cache_ops.reshape_and_cache(
+        paged_attention.reshape_and_cache(
             kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
         )
 
@@ -201,7 +198,7 @@ class FlashRWAttention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            flash_attn.attention(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
@@ -212,9 +209,7 @@ class FlashRWAttention(torch.nn.Module):
             )
         # Decode
         else:
-            # kv_cache[1] => [num_blocks, num_heads_kv, head_size, block_size]
-            block_size = kv_cache[1].shape[3]
-            vllm_attention_ops.single_query_cached_kv_attention(
+            paged_attention.attention(
                 attn_output,
                 query,
                 kv_cache[0],
@@ -223,7 +218,6 @@ class FlashRWAttention(torch.nn.Module):
                 self.softmax_scale,
                 block_tables,
                 input_lengths,
-                block_size,
                 max_s,
             )
 
@@ -310,7 +304,7 @@ class FlashRWLargeAttention(torch.nn.Module):
         self.rotary_emb(query, cos, sin)
         self.rotary_emb(torch.select(kv, dim=2, index=0), cos, sin)
 
-        vllm_cache_ops.reshape_and_cache(
+        paged_attention.reshape_and_cache(
             kv[:, :, 0].contiguous(),
             kv[:, :, 1].contiguous(),
             kv_cache[0],
@@ -324,7 +318,7 @@ class FlashRWLargeAttention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            flash_attn.attention(
                 query,
                 torch.select(kv, dim=2, index=0),
                 torch.select(kv, dim=2, index=1),
@@ -335,9 +329,7 @@ class FlashRWLargeAttention(torch.nn.Module):
             )
         # Decode
         else:
-            # kv_cache[1] => [num_blocks, num_groups, head_size, block_size]
-            block_size = kv_cache[1].shape[3]
-            vllm_attention_ops.single_query_cached_kv_attention(
+            paged_attention.attention(
                 attn_output,
                 query,
                 kv_cache[0],
@@ -346,7 +338,6 @@ class FlashRWLargeAttention(torch.nn.Module):
                 self.softmax_scale,
                 block_tables,
                 input_lengths,
-                block_size,
                 max_s,
             )
 
diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
index 2dd0a5ee4..c3c7617a3 100644
--- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -5,10 +5,7 @@ from torch import nn
 from transformers.activations import ACT2FN
 from typing import Optional, List, Tuple
 
-# vllm imports
-import vllm_cache_ops
-import vllm_attention_ops
-
+from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.utils.flash_attn import attention
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
@@ -18,7 +15,6 @@ from text_generation_server.utils.layers import (
     FastLayerNorm,
     get_linear,
 )
-from safetensors import SafetensorError
 
 
 def load_multi_mqa(
@@ -258,7 +254,7 @@ class FlashMQAttention(torch.nn.Module):
         query = query.view(-1, self.num_heads, self.head_size)
         key_value = key_value.view(-1, 2, 1, self.head_size)
 
-        vllm_cache_ops.reshape_and_cache(
+        paged_attention.reshape_and_cache(
             key_value[:, 0], key_value[:, 1], kv_cache[0], kv_cache[1], slots
         )
@@ -268,7 +264,7 @@ class FlashMQAttention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            flash_attn.attention(
                 query,
                 torch.select(key_value, dim=1, index=0),
                 torch.select(key_value, dim=1, index=1),
@@ -279,9 +275,7 @@ class FlashMQAttention(torch.nn.Module):
             )
         # Decode
         else:
-            # kv_cache[1] => [num_blocks, 1, head_size, block_size]
-            block_size = kv_cache[1].shape[3]
-            vllm_attention_ops.single_query_cached_kv_attention(
+            paged_attention.attention(
                 attn_output,
                 query,
                 kv_cache[0],
@@ -290,7 +284,6 @@ class FlashMQAttention(torch.nn.Module):
                 self.softmax_scale,
                 block_tables,
                 input_lengths,
-                block_size,
                 max_s,
             )
 
diff --git a/server/text_generation_server/utils/paged_attention.py b/server/text_generation_server/utils/paged_attention.py
new file mode 100644
index 000000000..57a595993
--- /dev/null
+++ b/server/text_generation_server/utils/paged_attention.py
@@ -0,0 +1,100 @@
+import torch
+
+# vllm imports
+from vllm import cache_ops
+from vllm import attention_ops
+
+_PARTITION_SIZE = 512
+
+
+def reshape_and_cache(key: torch.Tensor, value: torch.Tensor, key_cache: torch.Tensor, value_cache: torch.Tensor,
+                      slots: torch.Tensor):
+    cache_ops.reshape_and_cache(
+        key, value, key_cache, value_cache, slots
+    )
+
+
+def attention(
+    out: torch.Tensor,
+    query: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    kv_head_mapping: torch.Tensor,
+    softmax_scale: float,
+    block_tables: torch.Tensor,
+    input_lengths: torch.Tensor,
+    max_s: int,
+):
+    # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
+    # Copyright 2023 The vLLM team. All rights
+    # reserved.
+    #
+    # Licensed under the Apache License, Version 2.0 (the "License");
+    # you may not use this file except in compliance with the License.
+    # You may obtain a copy of the License at
+    #
+    #     http://www.apache.org/licenses/LICENSE-2.0
+    #
+    # Unless required by applicable law or agreed to in writing, software
+    # distributed under the License is distributed on an "AS IS" BASIS,
+    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+    # See the License for the specific language governing permissions and
+    # limitations under the License.
+    #
+
+    # value_cache => [num_blocks, num_heads, head_size, block_size]
+    block_size = value_cache.shape[3]
+    num_seqs, num_heads, head_size = query.shape
+    max_num_partitions = (
+        (max_s + _PARTITION_SIZE - 1) //
+        _PARTITION_SIZE)
+    # NOTE(woosuk): We use a simple heuristic to decide whether to use
+    # PagedAttention V1 or V2. If the number of partitions is 1, we use
+    # V1 to avoid the overhead of reduction. Also, if the number of
+    # sequences or heads is large, we use V1 since there is enough work
+    # to parallelize.
+    use_v1 = max_num_partitions == 1 or num_seqs * num_heads > 512
+    if use_v1:
+        attention_ops.paged_attention_v1(
+            out,
+            query,
+            key_cache,
+            value_cache,
+            kv_head_mapping,
+            softmax_scale,
+            block_tables,
+            input_lengths,
+            block_size,
+            max_s,
+            None,
+        )
+    else:
+        # Run PagedAttention V2.
+        assert _PARTITION_SIZE % block_size == 0
+        tmp_output = torch.empty(
+            size=(num_seqs, num_heads, max_num_partitions, head_size),
+            dtype=out.dtype,
+            device=out.device,
+        )
+        exp_sums = torch.empty(
+            size=(num_seqs, num_heads, max_num_partitions),
+            dtype=torch.float32,
+            device=out.device,
+        )
+        max_logits = torch.empty_like(exp_sums)
+        attention_ops.paged_attention_v2(
+            out,
+            exp_sums,
+            max_logits,
+            tmp_output,
+            query,
+            key_cache,
+            value_cache,
+            kv_head_mapping,
+            softmax_scale,
+            block_tables,
+            input_lengths,
+            block_size,
+            max_s,
+            None,
+        )
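
As a reading aid for the new module: `attention` picks between the V1 and V2 kernels with plain integer arithmetic on the decode shape. The longest sequence is split into partitions of `_PARTITION_SIZE = 512` tokens, and the V2 path (with its extra `tmp_output`/`exp_sums`/`max_logits` workspace) is only launched when more than one partition is needed and the batch does not already expose enough parallelism. Below is a minimal standalone sketch of just that heuristic; the helper name `use_paged_attention_v1` is illustrative and not part of the patch, and no vLLM kernels are invoked, so it runs anywhere Python does.

# Standalone sketch of the V1/V2 dispatch heuristic from paged_attention.attention.
# Only the partition arithmetic is reproduced; the real kernels live in vllm.

_PARTITION_SIZE = 512  # tokens covered by one V2 partition, as in the new module


def use_paged_attention_v1(max_s: int, num_seqs: int, num_heads: int) -> bool:
    # Partitions needed to cover the longest sequence in the batch.
    max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
    # V1 when a single partition suffices, or when num_seqs * num_heads already
    # gives the GPU enough independent work (> 512) to stay busy.
    return max_num_partitions == 1 or num_seqs * num_heads > 512


if __name__ == "__main__":
    print(use_paged_attention_v1(max_s=256, num_seqs=8, num_heads=32))    # True: one partition
    print(use_paged_attention_v1(max_s=4096, num_seqs=64, num_heads=32))  # True: large batch
    print(use_paged_attention_v1(max_s=4096, num_seqs=1, num_heads=32))   # False: long sequence, small batch -> V2

On the V2 path the workspace grows with num_seqs x num_heads x max_num_partitions, which is why `exp_sums`, `max_logits`, and `tmp_output` are allocated only in that branch.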