From 91f55ea2b5edf5d83d9fdaa9ed85681f7831cb27 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Fri, 31 May 2024 10:16:30 +0000
Subject: [PATCH] Removing flash decoding part so it gets merged.

---
 router/src/infer.rs                           |  12 +-
 .../layers/attention/cuda.py                  | 127 ++++++------------
 .../layers/attention/rocm.py                  | 127 ++++++------------
 .../layers/attention/xpu.py                   |   5 +-
 .../models/cache_manager.py                   |  50 +++----
 .../custom_modeling/flash_cohere_modeling.py  |  34 +----
 .../custom_modeling/flash_dbrx_modeling.py    |   1 -
 .../custom_modeling/flash_gemma_modeling.py   |   1 -
 .../custom_modeling/flash_gpt2_modeling.py    |   1 -
 .../custom_modeling/flash_llama_modeling.py   |  33 +----
 .../custom_modeling/flash_mistral_modeling.py |   3 +-
 .../custom_modeling/flash_mixtral_modeling.py |   1 -
 .../custom_modeling/flash_neox_modeling.py    |   1 -
 .../custom_modeling/flash_phi_modeling.py     |   1 -
 .../custom_modeling/flash_qwen2_modeling.py   |   1 -
 .../custom_modeling/flash_rw_modeling.py      |  10 +-
 .../flash_santacoder_modeling.py              |   1 -
 .../flash_starcoder2_modeling.py              |   1 -
 .../text_generation_server/models/globals.py  |   3 -
 19 files changed, 123 insertions(+), 290 deletions(-)

diff --git a/router/src/infer.rs b/router/src/infer.rs
index 4167c976..0410de7d 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -70,17 +70,7 @@ impl Infer {
         tokenizer_config: HubTokenizerConfig,
         processor_config: HubProcessorConfig,
     ) -> Self {
-        // Infer shared state
-        let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") {
-            matches!(flashdecoding.to_lowercase().as_str(), "1" | "true")
-        } else {
-            false
-        };
-        let block_size = if flashdecoding { 256 } else { 16 };
-        let block_size = std::env::var("BLOCK_SIZE")
-            .map(|b| b.parse().unwrap_or(block_size))
-            .unwrap_or(block_size);
-        let queue = Queue::new(requires_padding, block_size, window_size, speculate);
+        let queue = Queue::new(requires_padding, 16, window_size, speculate);
         let shared = Arc::new(Shared {
             batching_task: Notify::new(),
         });
diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py
index 508a98dc..583337bd 100644
--- a/server/text_generation_server/layers/attention/cuda.py
+++ b/server/text_generation_server/layers/attention/cuda.py
@@ -1,6 +1,5 @@
 import torch
 from text_generation_server.utils.import_utils import SYSTEM
-from text_generation_server.models.globals import FLASH_DECODING
 
 major, minor = torch.cuda.get_device_capability()
 is_sm75 = major == 7 and minor == 5
@@ -22,14 +21,7 @@ def reshape_and_cache(
     value_cache: torch.Tensor,
     slots: torch.Tensor,
 ):
-    if FLASH_DECODING:
-        shape = key_cache.shape
-        key_cache.view(-1, shape[-2], shape[-1])[slots] = key
-        value_cache.view(-1, shape[-2], shape[-1])[slots] = value
-    else:
-        cache_ops.reshape_and_cache(
-            key, value, key_cache, value_cache, slots, "auto", 1.0
-        )
+    cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
 
 
 def paged_attention(
@@ -40,8 +32,7 @@ def paged_attention(
     kv_head_mapping: torch.Tensor,
     softmax_scale: float,
     block_tables: torch.Tensor,
-    cu_seqlen_q: torch.Tensor,
-    cu_seqlen_k: torch.Tensor,
+    input_lengths: torch.Tensor,
     max_s: int,
 ):
     # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
@@ -65,94 +56,64 @@ def paged_attention(
     block_size = value_cache.shape[3]
     num_seqs, num_heads, head_size = query.shape
     max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
-    input_lengths = cu_seqlen_k
 
     # NOTE(woosuk): We use a simple heuristic to decide whether to use
     # PagedAttention V1 or V2. If the number of partitions is 1, we use
     # V1 to avoid the overhead of reduction. Also, if the number of
     # sequences or heads is large, we use V1 since there is enough work
     # to parallelize.
-    if FLASH_DECODING:
-        max_q = 1
-        max_k = max_s
-        import flash_attn_2_cuda
+    from vllm._C import ops
 
-        flash_attn_2_cuda.varlen_fwd(
+    use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
+    if use_v1:
+        ops.paged_attention_v1(
+            out,
             query,
             key_cache,
             value_cache,
-            out,
-            cu_seqlen_q,
-            cu_seqlen_k,
-            None,
-            block_tables,
-            None,
-            max_q,
-            max_k,
-            0.0,
+            kv_head_mapping,
             softmax_scale,
-            False,
-            True,
-            -1,
-            0,
-            False,
+            block_tables,
+            input_lengths,
+            block_size,
+            max_s,
             None,
+            "auto",
+            1.0,
         )
     else:
-        from vllm._C import ops
-
-        use_v1 = max_s <= 8192 and (
-            max_num_partitions == 1 or num_seqs * num_heads > 512
+        # Run PagedAttention V2.
+        assert _PARTITION_SIZE % block_size == 0
+        tmp_output = torch.empty(
+            size=(num_seqs, num_heads, max_num_partitions, head_size),
+            dtype=out.dtype,
+            device=out.device,
         )
-        if use_v1:
-            ops.paged_attention_v1(
-                out,
-                query,
-                key_cache,
-                value_cache,
-                kv_head_mapping,
-                softmax_scale,
-                block_tables,
-                input_lengths,
-                block_size,
-                max_s,
-                None,
-                "auto",
-                1.0,
-            )
-        else:
-            # Run PagedAttention V2.
-            assert _PARTITION_SIZE % block_size == 0
-            tmp_output = torch.empty(
-                size=(num_seqs, num_heads, max_num_partitions, head_size),
-                dtype=out.dtype,
-                device=out.device,
-            )
-            exp_sums = torch.empty(
-                size=(num_seqs, num_heads, max_num_partitions),
-                dtype=torch.float32,
-                device=out.device,
-            )
-            max_logits = torch.empty_like(exp_sums)
+        exp_sums = torch.empty(
+            size=(num_seqs, num_heads, max_num_partitions),
+            dtype=torch.float32,
+            device=out.device,
+        )
+        max_logits = torch.empty_like(exp_sums)
 
-            ops.paged_attention_v2(
-                out,
-                exp_sums,
-                max_logits,
-                tmp_output,
-                query,
-                key_cache,
-                value_cache,
-                kv_head_mapping,
-                softmax_scale,
-                block_tables,
-                input_lengths,
-                block_size,
-                max_s,
-                None,
-                "auto",
-                1.0,
-            )
+        ops.paged_attention_v2(
+            out,
+            exp_sums,
+            max_logits,
+            tmp_output,
+            query,
+            key_cache,
+            value_cache,
+            kv_head_mapping,
+            softmax_scale,
+            block_tables,
+            input_lengths,
+            block_size,
+            max_s,
+            None,
+            "auto",
+            1.0,
+        )
 
 
 try:
diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py
index ed260334..2e2b7ba9 100644
--- a/server/text_generation_server/layers/attention/rocm.py
+++ b/server/text_generation_server/layers/attention/rocm.py
@@ -1,7 +1,6 @@
 import os
 import torch
 from text_generation_server.utils.import_utils import SYSTEM
-from text_generation_server.models.globals import FLASH_DECODING
 from loguru import logger
 major, minor = torch.cuda.get_device_capability()
 is_sm75 = major == 7 and minor == 5
@@ -28,14 +27,7 @@ def reshape_and_cache(
     value_cache: torch.Tensor,
     slots: torch.Tensor,
 ):
-    if FLASH_DECODING:
-        shape = key_cache.shape
-        key_cache.view(-1, shape[-2], shape[-1])[slots] = key
-        value_cache.view(-1, shape[-2], shape[-1])[slots] = value
-    else:
-        cache_ops.reshape_and_cache(
-            key, value, key_cache, value_cache, slots, "auto", 1.0
-        )
+    cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
 
 
 def paged_attention(
@@ -46,8 +38,7 @@ def paged_attention(
     kv_head_mapping: torch.Tensor,
     softmax_scale: float,
     block_tables: torch.Tensor,
-    cu_seqlen_q: torch.Tensor,
-    cu_seqlen_k: torch.Tensor,
+    input_lengths: torch.Tensor,
     max_s: int,
 ):
     # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
@@ -71,94 +62,64 @@ def paged_attention(
     block_size = value_cache.shape[3]
     num_seqs, num_heads, head_size = query.shape
     max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
-    input_lengths = cu_seqlen_k
 
     # NOTE(woosuk): We use a simple heuristic to decide whether to use
     # PagedAttention V1 or V2. If the number of partitions is 1, we use
     # V1 to avoid the overhead of reduction. Also, if the number of
     # sequences or heads is large, we use V1 since there is enough work
     # to parallelize.
-    if FLASH_DECODING:
-        max_q = 1
-        max_k = max_s
-        import flash_attn_2_cuda
+    from vllm._C import ops
 
-        flash_attn_2_cuda.varlen_fwd(
+    use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
+    if use_v1:
+        ops.paged_attention_v1(
+            out,
             query,
             key_cache,
             value_cache,
-            out,
-            cu_seqlen_q,
-            cu_seqlen_k,
-            None,
-            block_tables,
-            None,
-            max_q,
-            max_k,
-            0.0,
+            kv_head_mapping,
             softmax_scale,
-            False,
-            True,
-            -1,
-            0,
-            False,
+            block_tables,
+            input_lengths,
+            block_size,
+            max_s,
             None,
+            "auto",
+            1.0,
        )
     else:
-        from vllm._C import ops
-
-        use_v1 = max_s <= 8192 and (
-            max_num_partitions == 1 or num_seqs * num_heads > 512
+        # Run PagedAttention V2.
+        assert _PARTITION_SIZE % block_size == 0
+        tmp_output = torch.empty(
+            size=(num_seqs, num_heads, max_num_partitions, head_size),
+            dtype=out.dtype,
+            device=out.device,
         )
-        if use_v1:
-            ops.paged_attention_v1(
-                out,
-                query,
-                key_cache,
-                value_cache,
-                kv_head_mapping,
-                softmax_scale,
-                block_tables,
-                input_lengths,
-                block_size,
-                max_s,
-                None,
-                "auto",
-                1.0,
-            )
-        else:
-            # Run PagedAttention V2.
-            assert _PARTITION_SIZE % block_size == 0
-            tmp_output = torch.empty(
-                size=(num_seqs, num_heads, max_num_partitions, head_size),
-                dtype=out.dtype,
-                device=out.device,
-            )
-            exp_sums = torch.empty(
-                size=(num_seqs, num_heads, max_num_partitions),
-                dtype=torch.float32,
-                device=out.device,
-            )
-            max_logits = torch.empty_like(exp_sums)
+        exp_sums = torch.empty(
+            size=(num_seqs, num_heads, max_num_partitions),
+            dtype=torch.float32,
+            device=out.device,
+        )
+        max_logits = torch.empty_like(exp_sums)
 
-            ops.paged_attention_v2(
-                out,
-                exp_sums,
-                max_logits,
-                tmp_output,
-                query,
-                key_cache,
-                value_cache,
-                kv_head_mapping,
-                softmax_scale,
-                block_tables,
-                input_lengths,
-                block_size,
-                max_s,
-                None,
-                "auto",
-                1.0,
-            )
+        ops.paged_attention_v2(
+            out,
+            exp_sums,
+            max_logits,
+            tmp_output,
+            query,
+            key_cache,
+            value_cache,
+            kv_head_mapping,
+            softmax_scale,
+            block_tables,
+            input_lengths,
+            block_size,
+            max_s,
+            None,
+            "auto",
+            1.0,
+        )
 
 
 try:
diff --git a/server/text_generation_server/layers/attention/xpu.py b/server/text_generation_server/layers/attention/xpu.py
index 050235a5..a716fcdd 100644
--- a/server/text_generation_server/layers/attention/xpu.py
+++ b/server/text_generation_server/layers/attention/xpu.py
@@ -59,8 +59,7 @@ def paged_attention(
     kv_head_mapping: torch.Tensor,
     softmax_scale: float,
     block_tables: torch.Tensor,
-    cu_seqlen_q: torch.Tensor,
-    cu_seqlen_k: torch.Tensor,
+    input_lengths: torch.Tensor,
     max_s: int,
 ):
     query = query.contiguous()
@@ -73,7 +72,7 @@ def paged_attention(
         kv_head_mapping,
         softmax_scale,
         block_tables,
-        cu_seqlen_q,
+        input_lengths,
         block_size,
         max_s,
         None,
diff --git a/server/text_generation_server/models/cache_manager.py b/server/text_generation_server/models/cache_manager.py
index df6b1ade..c7705fe8 100644
--- a/server/text_generation_server/models/cache_manager.py
+++ b/server/text_generation_server/models/cache_manager.py
@@ -3,9 +3,8 @@ import torch
 
 from typing import Optional, List, Tuple
 from text_generation_server.utils.import_utils import SYSTEM
-from text_generation_server.models.globals import FLASH_DECODING
 
-BLOCK_SIZE: int = 256 if FLASH_DECODING else 16
+BLOCK_SIZE: int = 16
 # Will be set in warmup
 CACHE_MANAGER: Optional["CacheManager"] = None
 
@@ -31,38 +30,21 @@ class CacheManager:
         else:
             x = self.block_size // element_size
 
-        if FLASH_DECODING:
-            self.kv_cache = [
-                (
-                    torch.empty(
-                        (num_blocks, self.block_size, num_heads, head_size),
-                        dtype=dtype,
-                        device=device,
-                    ),
-                    torch.empty(
-                        (num_blocks, self.block_size, num_heads, head_size),
-                        dtype=dtype,
-                        device=device,
-                    ),
-                )
-                for _ in range(num_layers)
-            ]
-        else:
-            self.kv_cache = [
-                (
-                    torch.empty(
-                        (num_blocks, num_heads, head_size // x, self.block_size, x),
-                        dtype=dtype,
-                        device=device,
-                    ),
-                    torch.empty(
-                        (num_blocks, num_heads, head_size, self.block_size),
-                        dtype=dtype,
-                        device=device,
-                    ),
-                )
-                for _ in range(num_layers)
-            ]
+        self.kv_cache = [
+            (
+                torch.empty(
+                    (num_blocks, num_heads, head_size // x, self.block_size, x),
+                    dtype=dtype,
+                    device=device,
+                ),
+                torch.empty(
+                    (num_blocks, num_heads, head_size, self.block_size),
+                    dtype=dtype,
+                    device=device,
+                ),
+            )
+            for _ in range(num_layers)
+        ]
         self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32, device="cpu")
         self.slots = torch.arange(
             0, num_blocks * self.block_size, dtype=torch.int64
diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py
index 088e3062..31109bc9 100644
--- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py
@@ -30,7 +30,6 @@ from text_generation_server.layers.attention import (
     attention,
     reshape_and_cache,
 )
-from text_generation_server.models.globals import FLASH_DECODING
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@@ -260,9 +259,8 @@ class FlashCohereAttention(torch.nn.Module):
         cu_seqlen_prefill,
         kv_cache,
         block_tables,
-        cu_seqlen_q,
-        cu_seqlen_k,
         slots,
+        input_lengths,
         max_s,
     ):
         qkv = self.query_key_value(hidden_states)
@@ -314,8 +312,7 @@ class FlashCohereAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                cu_seqlen_q,
-                cu_seqlen_k,
+                input_lengths,
                 max_s,
             )
 
@@ -389,9 +386,8 @@ class FlashCohereLayer(nn.Module):
         cu_seqlen_prefill,
         kv_cache,
         block_tables,
-        cu_seqlen_q,
-        cu_seqlen_k,
         slots,
+        input_lengths,
         max_s,
     ):
         normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@@ -404,9 +400,8 @@ class FlashCohereLayer(nn.Module):
             cu_seqlen_prefill,
             kv_cache,
             block_tables,
-            cu_seqlen_q,
-            cu_seqlen_k,
             slots,
+            input_lengths,
             max_s,
         )
 
@@ -469,24 +464,6 @@ class FlashCohereModel(torch.nn.Module):
         )
 
         residual = None
-        if cu_seqlen_prefill is None and FLASH_DECODING:
-            cu_seqlen_q = torch.arange(
-                input_lengths.shape[0] + 1,
-                device=input_ids.device,
-                dtype=torch.int32,
-            )
-            cu_seqlen_k = torch.cat(
-                [
-                    torch.zeros(
-                        (1,), device=input_lengths.device, dtype=input_lengths.dtype
-                    ),
-                    input_lengths.cumsum(dim=-1),
-                ]
-            ).to(dtype=torch.int32)
-        else:
-            cu_seqlen_q = None
-            cu_seqlen_k = input_lengths
-
         for i, layer in enumerate(self.layers):
             hidden_states, residual = layer(
                 hidden_states,
@@ -496,9 +473,8 @@ class FlashCohereModel(torch.nn.Module):
                 cu_seqlen_prefill,
                 kv_cache[i],
                 block_tables,
-                cu_seqlen_q,
-                cu_seqlen_k,
                 slots,
+                input_lengths,
                 max_s,
             )
 
diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
index 0f3a455d..497956e3 100644
--- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
@@ -455,7 +455,6 @@ class DbrxAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                None,
                 input_lengths,
                 max_s,
             )
diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
index 345950ea..89ca8b5b 100644
--- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
@@ -253,7 +253,6 @@ class FlashGemmaAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                None,
                 input_lengths,
                 max_s,
             )
diff --git a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py
index b67ac203..52a7c283 100644
--- a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py
@@ -244,7 +244,6 @@ class FlashGPT2Attention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                None,
                 input_lengths,
                 max_s,
             )
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 4feae4cf..c0fa09fd 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -33,7 +33,6 @@ from text_generation_server.layers.attention import (
     attention,
     reshape_and_cache,
 )
-from text_generation_server.models.globals import FLASH_DECODING
 from text_generation_server.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
@@ -134,8 +133,7 @@ class FlashLlamaAttention(torch.nn.Module):
         kv_cache,
         block_tables,
         slots,
-        cu_seqlen_q,
-        cu_seqlen_k,
+        input_lengths,
         max_s,
     ):
         qkv = self.query_key_value(hidden_states)
@@ -178,8 +176,7 @@ class FlashLlamaAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                cu_seqlen_q,
-                cu_seqlen_k,
+                input_lengths,
                 max_s,
             )
 
@@ -280,8 +277,7 @@ class FlashLlamaLayer(nn.Module):
         kv_cache,
         block_tables,
         slots,
-        cu_seqlen_q,
-        cu_seqlen_k,
+        input_lengths,
         max_s,
     ):
         normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@@ -295,8 +291,7 @@ class FlashLlamaLayer(nn.Module):
             kv_cache,
             block_tables,
             slots,
-            cu_seqlen_q,
-            cu_seqlen_k,
+            input_lengths,
             max_s,
         )
 
@@ -363,23 +358,6 @@ class FlashLlamaModel(torch.nn.Module):
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
             position_ids, max_s, hidden_states.dtype
         )
-        if cu_seqlen_prefill is None and FLASH_DECODING:
-            cu_seqlen_q = torch.arange(
-                input_lengths.shape[0] + 1,
-                device=inputs_embeds.device,
-                dtype=torch.int32,
-            )
-            cu_seqlen_k = torch.cat(
-                [
-                    torch.zeros(
-                        (1,), device=input_lengths.device, dtype=input_lengths.dtype
-                    ),
-                    input_lengths.cumsum(dim=-1),
-                ]
-            ).to(dtype=torch.int32)
-        else:
-            cu_seqlen_q = None
-            cu_seqlen_k = input_lengths
 
         residual = None
         for i, layer in enumerate(self.layers):
@@ -392,8 +370,7 @@ class FlashLlamaModel(torch.nn.Module):
                 kv_cache[i],
                 block_tables,
                 slots,
-                cu_seqlen_q,
-                cu_seqlen_k,
+                input_lengths,
                 max_s,
             )
 
diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
index d36d41cd..77a8a384 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
@@ -28,8 +28,8 @@ from typing import Optional, List, Tuple
 
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
-    attention,
     paged_attention,
+    attention,
     reshape_and_cache,
 )
 from text_generation_server.layers import (
@@ -220,7 +220,6 @@ class MistralAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                None,
                 input_lengths,
                 max_s,
             )
diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
index 5e2cb2f9..37cd6f3b 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
@@ -299,7 +299,6 @@ class MixtralAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                None,
                 input_lengths,
                 max_s,
             )
diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
index 87d3b09a..59e7bf8b 100644
--- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
@@ -176,7 +176,6 @@ class FlashNeoxAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                None,
                 input_lengths,
                 max_s,
             )
diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
index f3d5633c..af3206dd 100644
--- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
@@ -215,7 +215,6 @@ class FlashPhiAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                None,
                 input_lengths,
                 max_s,
             )
diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
index fa815acf..2b035c2e 100644
--- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
@@ -176,7 +176,6 @@ class Qwen2Attention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                None,
                 input_lengths,
                 max_s,
             )
diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
index 50962a1a..ab6cf02a 100644
--- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -198,7 +198,9 @@ class FlashRWAttention(torch.nn.Module):
         # Inplace rotary
         self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
 
-        reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
+        paged_attention.reshape_and_cache(
+            kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
+        )
 
         # output
         attn_output = torch.empty_like(query)
@@ -206,7 +208,7 @@ class FlashRWAttention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            flash_attn.attention(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
@@ -217,7 +219,7 @@ class FlashRWAttention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention(
+            paged_attention.attention(
                 attn_output,
                 query,
                 kv_cache[0],
@@ -225,7 +227,6 @@ class FlashRWAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                None,
                 input_lengths,
                 max_s,
             )
@@ -349,7 +350,6 @@ class FlashRWLargeAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                None,
                 input_lengths,
                 max_s,
             )
diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
index 5cfde3d5..c8397000 100644
--- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -309,7 +309,6 @@ class FlashMQAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                None,
                 input_lengths,
                 max_s,
             )
diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
index b0797fa9..37486e9d 100644
--- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
@@ -263,7 +263,6 @@ class Starcoder2Attention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                None,
                 input_lengths,
                 max_s,
             )
diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py
index 0a1f5da9..11a9f030 100644
--- a/server/text_generation_server/models/globals.py
+++ b/server/text_generation_server/models/globals.py
@@ -5,9 +5,6 @@ from loguru import logger
 MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
 # This is overridden by the cli
 cuda_graphs = os.getenv("CUDA_GRAPHS")
-FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"}
-if FLASH_DECODING:
-    logger.info("Using FLASH_DECODING")
 if cuda_graphs is not None:
     try:
         cuda_graphs = [int(item) for item in cuda_graphs.split(",")]