From 1b86d0f31d1265abc9e714d51b7a4498df9075c7 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 17 May 2024 08:43:33 +0000 Subject: [PATCH] Using flash decoding Conditional flashdecoding. Fix max_q. Working kvcache Working version with flash decoding. Make it work for mistral. --- server/Makefile-flash-att-v2 | 4 +- .../models/cache_manager.py | 50 ++++--- .../custom_modeling/flash_llama_modeling.py | 92 +++++++++---- .../custom_modeling/flash_mistral_modeling.py | 104 ++++++++++----- .../text_generation_server/models/globals.py | 4 + .../utils/flash_attn.py | 3 +- .../utils/paged_attention.py | 126 ++++++++++++------ 7 files changed, 260 insertions(+), 123 deletions(-) diff --git a/server/Makefile-flash-att-v2 b/server/Makefile-flash-att-v2 index 36ef576ae..bbff00909 100644 --- a/server/Makefile-flash-att-v2 +++ b/server/Makefile-flash-att-v2 @@ -1,11 +1,11 @@ -flash_att_v2_commit_cuda := 23e8fa5a263d1c7122bc46a86ef32030ee7130f9 +flash_att_v2_commit_cuda := v2.5.8 flash_att_v2_commit_rocm := 2554f490101742ccdc56620a938f847f61754be6 flash-attention-v2-cuda: # Clone flash attention pip install -U packaging ninja --no-cache-dir - git clone https://github.com/HazyResearch/flash-attention.git flash-attention-v2 + git clone https://github.com/Dao-AILab/flash-attention.git flash-attention-v2 build-flash-attention-v2-cuda: flash-attention-v2-cuda cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit_cuda) diff --git a/server/text_generation_server/models/cache_manager.py b/server/text_generation_server/models/cache_manager.py index c7705fe86..df6b1adea 100644 --- a/server/text_generation_server/models/cache_manager.py +++ b/server/text_generation_server/models/cache_manager.py @@ -3,8 +3,9 @@ import torch from typing import Optional, List, Tuple from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.models.globals import FLASH_DECODING -BLOCK_SIZE: int = 16 +BLOCK_SIZE: int = 256 if FLASH_DECODING else 16 # Will be set in warmup CACHE_MANAGER: Optional["CacheManager"] = None @@ -30,21 +31,38 @@ class CacheManager: else: x = self.block_size // element_size - self.kv_cache = [ - ( - torch.empty( - (num_blocks, num_heads, head_size // x, self.block_size, x), - dtype=dtype, - device=device, - ), - torch.empty( - (num_blocks, num_heads, head_size, self.block_size), - dtype=dtype, - device=device, - ), - ) - for _ in range(num_layers) - ] + if FLASH_DECODING: + self.kv_cache = [ + ( + torch.empty( + (num_blocks, self.block_size, num_heads, head_size), + dtype=dtype, + device=device, + ), + torch.empty( + (num_blocks, self.block_size, num_heads, head_size), + dtype=dtype, + device=device, + ), + ) + for _ in range(num_layers) + ] + else: + self.kv_cache = [ + ( + torch.empty( + (num_blocks, num_heads, head_size // x, self.block_size, x), + dtype=dtype, + device=device, + ), + torch.empty( + (num_blocks, num_heads, head_size, self.block_size), + dtype=dtype, + device=device, + ), + ) + for _ in range(num_layers) + ] self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32, device="cpu") self.slots = torch.arange( 0, num_blocks * self.block_size, dtype=torch.int64 diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 6e23aa2bd..4e0202a89 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -28,6 
+28,7 @@ from transformers.activations import ACT2FN from typing import Optional, List, Tuple from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.models.globals import FLASH_DECODING from text_generation_server.utils import paged_attention, flash_attn from text_generation_server.layers import ( TensorParallelRowLinear, @@ -151,38 +152,75 @@ class FlashLlamaAttention(torch.nn.Module): self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - paged_attention.reshape_and_cache( - kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots - ) - # output tensor attn_output = torch.empty_like(query) - # Prefill - if cu_seqlen_prefill is not None: - # flash attention - flash_attn.attention( - query, - torch.select(kv, dim=1, index=0), - torch.select(kv, dim=1, index=1), - attn_output, - cu_seqlen_prefill, - max_s, - self.softmax_scale, - ) - # Decode + if FLASH_DECODING: + # Prefill + kv_cache[0].view(-1, self.num_key_value_heads, self.head_size)[slots] = kv[ + :, 0 + ] + kv_cache[1].view(-1, self.num_key_value_heads, self.head_size)[slots] = kv[ + :, 1 + ] + + if cu_seqlen_prefill is not None: + # flash attention + flash_attn.attention( + query, + # torch.select(kv, dim=1, index=0), + # torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], + attn_output, + cu_seqlen_prefill, + block_tables, + max_s, + self.softmax_scale, + ) + # Decode + else: + paged_attention.attention( + attn_output, + query, + kv_cache[0], + kv_cache[1], + self.kv_head_mapping, + self.softmax_scale, + block_tables, + input_lengths, + max_s, + ) else: - paged_attention.attention( - attn_output, - query, - kv_cache[0], - kv_cache[1], - self.kv_head_mapping, - self.softmax_scale, - block_tables, - input_lengths, - max_s, + paged_attention.reshape_and_cache( + kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots ) + # Prefill + if cu_seqlen_prefill is not None: + # flash attention + flash_attn.attention( + query, + torch.select(kv, dim=1, index=0), + torch.select(kv, dim=1, index=1), + attn_output, + cu_seqlen_prefill, + None, + max_s, + self.softmax_scale, + ) + # Decode + else: + paged_attention.attention( + attn_output, + query, + kv_cache[0], + kv_cache[1], + self.kv_head_mapping, + self.softmax_scale, + block_tables, + input_lengths, + max_s, + ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index ef3777dad..d1cd44183 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -27,6 +27,7 @@ from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.models.globals import FLASH_DECODING from text_generation_server.utils import paged_attention, flash_attn from text_generation_server.layers import ( TensorParallelRowLinear, @@ -214,44 +215,79 @@ class MistralAttention(torch.nn.Module): self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - if prefill_cache_indices is not None: - kv_to_cache = kv[prefill_cache_indices] - else: - kv_to_cache = kv - - paged_attention.reshape_and_cache( - kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots - ) - - # output tensor attn_output = torch.empty_like(query) - 
# Prefill - if cu_seqlen_prefill is not None: - # flash attention - flash_attn.attention( - query, - torch.select(kv, dim=1, index=0), - torch.select(kv, dim=1, index=1), - attn_output, - cu_seqlen_prefill, - max_s, - self.softmax_scale, - window_size_left=self.max_past, - ) - # Decode + if FLASH_DECODING: + # Prefill + kv_cache[0].view(-1, self.num_key_value_heads, self.head_size)[slots] = kv[ + :, 0 + ] + kv_cache[1].view(-1, self.num_key_value_heads, self.head_size)[slots] = kv[ + :, 1 + ] + + if cu_seqlen_prefill is not None: + # flash attention + flash_attn.attention( + query, + # torch.select(kv, dim=1, index=0), + # torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], + attn_output, + cu_seqlen_prefill, + block_tables, + max_s, + self.softmax_scale, + ) + # Decode + else: + paged_attention.attention( + attn_output, + query, + kv_cache[0], + kv_cache[1], + self.kv_head_mapping, + self.softmax_scale, + block_tables, + input_lengths, + max_s, + ) else: - paged_attention.attention( - attn_output, - query, - kv_cache[0], - kv_cache[1], - self.kv_head_mapping, - self.softmax_scale, - block_tables, - input_lengths, - max_s, + if prefill_cache_indices is not None: + kv_to_cache = kv[prefill_cache_indices] + else: + kv_to_cache = kv + + paged_attention.reshape_and_cache( + kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots ) + # Prefill + if cu_seqlen_prefill is not None: + # flash attention + flash_attn.attention( + query, + torch.select(kv, dim=1, index=0), + torch.select(kv, dim=1, index=1), + attn_output, + cu_seqlen_prefill, + None, + max_s, + self.softmax_scale, + ) + # Decode + else: + paged_attention.attention( + attn_output, + query, + kv_cache[0], + kv_cache[1], + self.kv_head_mapping, + self.softmax_scale, + block_tables, + input_lengths, + max_s, + ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index e8a119581..0a1f5da9f 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -1,9 +1,13 @@ import torch import os +from loguru import logger MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None # This is overridden by the cli cuda_graphs = os.getenv("CUDA_GRAPHS") +FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"} +if FLASH_DECODING: + logger.info("Using FLASH_DECODING") if cuda_graphs is not None: try: cuda_graphs = [int(item) for item in cuda_graphs.split(",")] diff --git a/server/text_generation_server/utils/flash_attn.py b/server/text_generation_server/utils/flash_attn.py index 4f5cf10b6..df3b61af3 100644 --- a/server/text_generation_server/utils/flash_attn.py +++ b/server/text_generation_server/utils/flash_attn.py @@ -134,6 +134,7 @@ elif HAS_FLASH_ATTN_V2_CUDA: v, out, cu_seqlens, + block_tables, max_s, softmax_scale, window_size_left=-1, @@ -149,7 +150,7 @@ elif HAS_FLASH_ATTN_V2_CUDA: cu_seqlens, cu_seqlens, None, - None, + block_tables, None, max_s, max_s, diff --git a/server/text_generation_server/utils/paged_attention.py b/server/text_generation_server/utils/paged_attention.py index 6cc30e6d5..0d3cccd69 100644 --- a/server/text_generation_server/utils/paged_attention.py +++ b/server/text_generation_server/utils/paged_attention.py @@ -85,53 +85,93 @@ def attention( # V1 to avoid the overhead of reduction. 
Also, if the number of # sequences or heads is large, we use V1 since there is enough work # to parallelize. - use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512) - if use_v1: - ops.paged_attention_v1( - out, + if FLASH_DECODING: + cu_seqlen_q = torch.arange( + input_lengths.shape[0] + 1, device=query.device, dtype=torch.int32 + ) + cu_seqlen_k = torch.cat( + [ + torch.zeros( + (1,), device=input_lengths.device, dtype=input_lengths.dtype + ), + input_lengths.cumsum(dim=-1), + ] + ).to(dtype=torch.int32) + max_q = 1 + max_k = max_s + import flash_attn_2_cuda + + flash_attn_2_cuda.varlen_fwd( query, key_cache, value_cache, - kv_head_mapping, - softmax_scale, - block_tables, - input_lengths, - block_size, - max_s, + out, + cu_seqlen_q, + cu_seqlen_k, + None, + block_tables, + None, + max_q, + max_k, + 0.0, + softmax_scale, + False, + True, + -1, + 0, + False, None, - "auto", - 1.0, ) else: - # Run PagedAttention V2. - assert _PARTITION_SIZE % block_size == 0 - tmp_output = torch.empty( - size=(num_seqs, num_heads, max_num_partitions, head_size), - dtype=out.dtype, - device=out.device, - ) - exp_sums = torch.empty( - size=(num_seqs, num_heads, max_num_partitions), - dtype=torch.float32, - device=out.device, - ) - max_logits = torch.empty_like(exp_sums) + from vllm._C import ops - ops.paged_attention_v2( - out, - exp_sums, - max_logits, - tmp_output, - query, - key_cache, - value_cache, - kv_head_mapping, - softmax_scale, - block_tables, - input_lengths, - block_size, - max_s, - None, - "auto", - 1.0, - ) + use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512) + if use_v1: + ops.paged_attention_v1( + out, + query, + key_cache, + value_cache, + kv_head_mapping, + softmax_scale, + block_tables, + input_lengths, + block_size, + max_s, + None, + "auto", + 1.0, + ) + else: + # Run PagedAttention V2. + assert _PARTITION_SIZE % block_size == 0 + tmp_output = torch.empty( + size=(num_seqs, num_heads, max_num_partitions, head_size), + dtype=out.dtype, + device=out.device, + ) + exp_sums = torch.empty( + size=(num_seqs, num_heads, max_num_partitions), + dtype=torch.float32, + device=out.device, + ) + max_logits = torch.empty_like(exp_sums) + + ops.paged_attention_v2( + out, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + kv_head_mapping, + softmax_scale, + block_tables, + input_lengths, + block_size, + max_s, + None, + "auto", + 1.0, + )
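
Notes:

Flash decoding is strictly opt-in: nothing in the serving path changes unless the server process is started with FLASH_DECODING=1 (or true/True) in its environment, in which case globals.py logs "Using FLASH_DECODING" at startup, BLOCK_SIZE grows from 16 to 256, and the KV cache switches to the flat (num_blocks, block_size, num_heads, head_size) layout that the modeling code writes by slot index. The flash-attention pin also moves from an old HazyResearch commit to the Dao-AILab repository at v2.5.8, which appears to be what provides the extra block_tables argument that utils/flash_attn.py now threads into varlen_fwd.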
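
The following is a minimal, CPU-only sketch of the two cache layouts the patch switches between. The tensor sizes are invented and the real cache is allocated on the GPU during warmup; it is only meant to make the shapes, and the slot-indexed write used in the modeling code, easier to follow.

import os

import torch

FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"}
BLOCK_SIZE = 256 if FLASH_DECODING else 16

# Illustrative sizes only; the real values come from the model config and warmup.
num_blocks, num_kv_heads, head_size = 8, 4, 128
dtype, device = torch.float16, "cpu"

if FLASH_DECODING:
    # Flash-decoding layout: (num_blocks, block_size, heads, head_size).
    key_cache = torch.empty(
        num_blocks, BLOCK_SIZE, num_kv_heads, head_size, dtype=dtype, device=device
    )
    value_cache = torch.empty_like(key_cache)

    # New tokens are written by flattening the block dimension and indexing
    # with per-token slot ids, as the modeling code in this patch does.
    slots = torch.tensor([0, 1, 2])
    kv = torch.randn(len(slots), 2, num_kv_heads, head_size).to(dtype)
    key_cache.view(-1, num_kv_heads, head_size)[slots] = kv[:, 0]
    value_cache.view(-1, num_kv_heads, head_size)[slots] = kv[:, 1]
else:
    # Existing paged-attention layout: the key cache splits head_size by x so the
    # kernel reads vectorized chunks; reshape_and_cache fills it on the GPU.
    x = BLOCK_SIZE // torch.tensor([], dtype=dtype).element_size()
    key_cache = torch.empty(
        num_blocks, num_kv_heads, head_size // x, BLOCK_SIZE, x, dtype=dtype, device=device
    )
    value_cache = torch.empty(
        num_blocks, num_kv_heads, head_size, BLOCK_SIZE, dtype=dtype, device=device
    )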
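
Likewise, this sketch rebuilds only the variable-length metadata that the FLASH_DECODING decode branch of paged_attention.attention hands to flash_attn_2_cuda.varlen_fwd, using made-up sequence lengths; the CUDA kernel itself is not called here.

import torch

# Made-up lengths of the three sequences currently held in the cache.
input_lengths = torch.tensor([5, 9, 3], dtype=torch.int32)
max_s = int(input_lengths.max())

# One query token per sequence at decode time, so the query offsets are 0..num_seqs.
cu_seqlen_q = torch.arange(input_lengths.shape[0] + 1, dtype=torch.int32)

# Key offsets are the exclusive prefix sum of the cached lengths.
cu_seqlen_k = torch.cat(
    [
        torch.zeros((1,), device=input_lengths.device, dtype=input_lengths.dtype),
        input_lengths.cumsum(dim=-1),
    ]
).to(dtype=torch.int32)

max_q = 1      # a single new token per sequence
max_k = max_s  # longest sequence in the cache (the patch passes max_s here)

print(cu_seqlen_q.tolist())  # [0, 1, 2, 3]
print(cu_seqlen_k.tolist())  # [0, 5, 14, 17]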