From b70ae0969f11bae03a3c6194fc8c592a1d8a65b3 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 20 Aug 2024 11:15:30 +0200 Subject: [PATCH] Prefix caching (#2402) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Prefix caching WIP * Fixing prefix attention. * Fixing flashinfer import. * Fixing black. * Fixing medusa (still wrong outputs, but functional). * Just medusa values now. * Fixing medusa without prefix caching. * Fixing prefix caching. * Medusa requires reshaping. * Removing the logs. * Remove router.nix * Fixup: - Remove logs - Disable VLMs (they do not work) - Disable prefix caching when user wants prefill logprobs. * Update flake.lock --------- Co-authored-by: Danièˆl de Kok --- backends/v3/src/queue.rs | 13 +- backends/v3/src/radix.rs | 1 + flake.lock | 6 +- flake.nix | 1 + .../layers/attention/__init__.py | 7 +- .../layers/attention/cuda.py | 27 +- .../{flash_infer.py => flashinfer.py} | 76 ++++++ .../text_generation_server/layers/medusa.py | 2 + .../custom_modeling/flash_cohere_modeling.py | 2 + .../custom_modeling/flash_dbrx_modeling.py | 2 + .../flash_deepseek_v2_modeling.py | 2 + .../custom_modeling/flash_gemma2_modeling.py | 2 + .../custom_modeling/flash_gemma_modeling.py | 2 + .../custom_modeling/flash_gpt2_modeling.py | 2 + .../custom_modeling/flash_llama_modeling.py | 2 + .../custom_modeling/flash_mistral_modeling.py | 2 + .../custom_modeling/flash_mixtral_modeling.py | 2 + .../custom_modeling/flash_neox_modeling.py | 2 + .../custom_modeling/flash_phi_modeling.py | 2 + .../custom_modeling/flash_qwen2_modeling.py | 2 + .../custom_modeling/flash_rw_modeling.py | 4 + .../flash_santacoder_modeling.py | 2 + .../flash_starcoder2_modeling.py | 2 + .../models/flash_causal_lm.py | 252 ++++++++++++++---- .../text_generation_server/models/globals.py | 6 +- .../models/vlm_causal_lm.py | 89 +++++-- 26 files changed, 405 insertions(+), 107 deletions(-) rename server/text_generation_server/layers/attention/{flash_infer.py => flashinfer.py} (65%) diff --git a/backends/v3/src/queue.rs b/backends/v3/src/queue.rs index 0fb05a982..faa57c113 100644 --- a/backends/v3/src/queue.rs +++ b/backends/v3/src/queue.rs @@ -316,10 +316,15 @@ impl State { + self.speculate - 1; - match block_allocator - .allocate(tokens, entry.request.input_ids.clone()) - .await - { + // If users wants the prefill logprobs, we cannot reuse the cache. + // So no input_ids for the radix tree. + let input_ids = if entry.request.decoder_input_details { + None + } else { + entry.request.input_ids.clone() + }; + + match block_allocator.allocate(tokens, input_ids).await { None => { // Entry is over budget // Add it back to the front diff --git a/backends/v3/src/radix.rs b/backends/v3/src/radix.rs index ef963532e..5bac1a31f 100644 --- a/backends/v3/src/radix.rs +++ b/backends/v3/src/radix.rs @@ -205,6 +205,7 @@ pub struct RadixTrie { /// call that a real time lookup would require. time: u64, } + impl Default for RadixTrie { fn default() -> Self { Self::new() diff --git a/flake.lock b/flake.lock index d0c2adbcb..cd5d6d2a6 100644 --- a/flake.lock +++ b/flake.lock @@ -900,11 +900,11 @@ ] }, "locked": { - "lastModified": 1723515680, - "narHash": "sha256-nHdKymsHCVIh0Wdm4MvSgxcTTg34FJIYHRQkQYaSuvk=", + "lastModified": 1723602049, + "narHash": "sha256-Z/noCSn9WPkv7O77dWKLcBxe4Ub4bWyNzsL5JhjaQfw=", "owner": "oxalica", "repo": "rust-overlay", - "rev": "4ee3d9e9569f70d7bb40f28804d6fe950c81eab3", + "rev": "ea0bf33a11a26a62c60123c49d96011da396602c", "type": "github" }, "original": { diff --git a/flake.nix b/flake.nix index 6bfe74e85..299e6b3d8 100644 --- a/flake.nix +++ b/flake.nix @@ -84,6 +84,7 @@ grpcio-status grpcio-tools hf-transfer + ipdb loguru mamba-ssm marlin-kernels diff --git a/server/text_generation_server/layers/attention/__init__.py b/server/text_generation_server/layers/attention/__init__.py index f9b1715ef..56fc53194 100644 --- a/server/text_generation_server/layers/attention/__init__.py +++ b/server/text_generation_server/layers/attention/__init__.py @@ -6,7 +6,12 @@ from .common import Seqlen if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false": raise ImportError("`USE_FLASH_ATTENTION` is false.") if SYSTEM == "cuda": - from .cuda import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING + from .cuda import ( + attention, + paged_attention, + reshape_and_cache, + SUPPORTS_WINDOWING, + ) elif SYSTEM == "rocm": from .rocm import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING elif SYSTEM == "ipex": diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py index 8703eb94f..b3b7ea4fe 100644 --- a/server/text_generation_server/layers/attention/cuda.py +++ b/server/text_generation_server/layers/attention/cuda.py @@ -76,7 +76,7 @@ def paged_attention( # sequences or heads is large, we use V1 since there is enough work # to parallelize. if ATTENTION == "flashinfer": - from text_generation_server.layers.attention.flash_infer import decode_state + from text_generation_server.layers.attention.flashinfer import decode_state return decode_state.get().forward( query.contiguous(), @@ -221,9 +221,11 @@ SUPPORTS_WINDOWING = V2 if ATTENTION == "flashinfer": def attention( - q, - k, - v, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, cu_seqlens, max_s, softmax_scale, @@ -231,14 +233,15 @@ if ATTENTION == "flashinfer": causal=True, softcap=0.0, ): - from text_generation_server.layers.attention.flash_infer import prefill_state + assert window_size_left == -1, "Windowing is not supported with flash infer" + from text_generation_server.layers.attention.flashinfer import ( + prefill_with_paged_kv_state, + ) - return prefill_state.get().forward( - q, - k, - v, + return prefill_with_paged_kv_state.get().forward( + q.contiguous(), causal=causal, - window_left=window_size_left, + paged_kv_cache=(key_cache, value_cache), logits_soft_cap=softcap, sm_scale=softmax_scale, ) @@ -249,6 +252,8 @@ elif V2: q, k, v, + key_cache: torch.Tensor, + value_cache: torch.Tensor, cu_seqlens, max_s, softmax_scale, @@ -289,6 +294,8 @@ else: q, k, v, + key_cache: torch.Tensor, + value_cache: torch.Tensor, cu_seqlens, max_s, softmax_scale, diff --git a/server/text_generation_server/layers/attention/flash_infer.py b/server/text_generation_server/layers/attention/flashinfer.py similarity index 65% rename from server/text_generation_server/layers/attention/flash_infer.py rename to server/text_generation_server/layers/attention/flashinfer.py index 56b53b2c9..e1ef62c5b 100644 --- a/server/text_generation_server/layers/attention/flash_infer.py +++ b/server/text_generation_server/layers/attention/flashinfer.py @@ -9,6 +9,10 @@ prefill_state: ContextVar[flashinfer.BatchPrefillWithRaggedKVCacheWrapper] = Con "prefill_state" ) +prefill_with_paged_kv_state: ContextVar[ + flashinfer.BatchPrefillWithPagedKVCacheWrapper +] = ContextVar("prefill_with_paged_kv_state") + decode_state: ContextVar[flashinfer.BatchDecodeWithPagedKVCacheWrapper] = ContextVar( "decode_state" ) @@ -24,6 +28,78 @@ def get_workspace(device): return workspace +def create_prefill_with_paged_kv_state( + *, + device: torch.device, +): + """Create a prefill state that uses the KV cache.""" + workspace_buffer = get_workspace(device) + return flashinfer.BatchPrefillWithPagedKVCacheWrapper( + workspace_buffer, kv_layout="NHD", use_cuda_graph=False + ) + + +@contextmanager +def use_prefill_with_paged_kv_state( + *, + state: flashinfer.BatchPrefillWithPagedKVCacheWrapper, + block_tables: torch.Tensor, + cu_seqlens: torch.Tensor, + input_lengths: torch.Tensor, + num_heads: int, + num_kv_heads: int, + head_size: int, + page_size: int, + query_dtype: str = "float16", +): + """ + Context manager to set the active flashinfer prefill state to the given + `state` and parameters. This state will be used by all calls to the + `attention` function while the context manager is active. + """ + + indptr = torch.zeros( + input_lengths.shape[0] + 1, device=input_lengths.device, dtype=torch.int32 + ) + # Round up to page size and then calculate the cumulative sum to get + # the indices into the block table. + torch.add(input_lengths, page_size - 1, out=indptr[1:]) + indptr[1:].div_(page_size, rounding_mode="floor") + indptr[1:].cumsum_(-1) + + # Get the lengths of the last page in a block. + if page_size == 1: + last_page_len = torch.ones( + input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device + ) + else: + last_page_len = torch.empty( + input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device + ) + torch.sub(input_lengths, 1, out=last_page_len) + last_page_len.remainder_(page_size) + last_page_len += 1 + + token = prefill_with_paged_kv_state.set(state) + try: + state.begin_forward( + qo_indptr=cu_seqlens, + paged_kv_indptr=indptr, + paged_kv_indices=block_tables, + paged_kv_last_page_len=last_page_len, + num_qo_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_size, + q_data_type=query_dtype, + page_size=page_size, + ) + yield + finally: + state.end_forward() + if token is not None: + prefill_with_paged_kv_state.reset(token) + + def create_prefill_state( *, device: torch.device, diff --git a/server/text_generation_server/layers/medusa.py b/server/text_generation_server/layers/medusa.py index 7579ccdbd..139c4dc25 100644 --- a/server/text_generation_server/layers/medusa.py +++ b/server/text_generation_server/layers/medusa.py @@ -32,6 +32,8 @@ class MedusaModel(torch.nn.Module): ) def forward(self, x): + if not self.heads: + return None speculative_logits = torch.stack([head(x) for head in self.heads], dim=1) return speculative_logits diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index e02a31d9a..1eb8c6c31 100644 --- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -298,6 +298,8 @@ class FlashCohereAttention(torch.nn.Module): query, key, value, + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index d3d1d1efc..fc0dca5bf 100644 --- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -337,6 +337,8 @@ class DbrxAttention(torch.nn.Module): query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py index 0905d3c29..b25becd5c 100644 --- a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py @@ -365,6 +365,8 @@ class DeepseekV2Attention(torch.nn.Module): query, key, value, + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py index 54d212e6c..faf0f3258 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py @@ -238,6 +238,8 @@ class FlashGemma2Attention(torch.nn.Module): query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index 178efadbe..33738a59d 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -232,6 +232,8 @@ class FlashGemmaAttention(torch.nn.Module): query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py index a19cff8cc..d30b5a0ab 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py @@ -232,6 +232,8 @@ class FlashGPT2Attention(torch.nn.Module): query, key, value, + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 9ea19a87d..3253d2dc0 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -220,6 +220,8 @@ class FlashLlamaAttention(torch.nn.Module): query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index dda53ff3d..5a150267d 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -219,6 +219,8 @@ class MistralAttention(torch.nn.Module): query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index 85431c6c9..ad426ffe7 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -276,6 +276,8 @@ class MixtralAttention(torch.nn.Module): query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index 67237d5c5..b684e035f 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -173,6 +173,8 @@ class FlashNeoxAttention(torch.nn.Module): qkv[:, 0], qkv[:, 1], qkv[:, 2], + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index a9e183480..efe27c137 100644 --- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -194,6 +194,8 @@ class FlashPhiAttention(torch.nn.Module): query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index 865cc85de..879b8abd7 100644 --- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -137,6 +137,8 @@ class Qwen2Attention(torch.nn.Module): query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index 10f995a3d..c72a9b90b 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -208,6 +208,8 @@ class FlashRWAttention(torch.nn.Module): query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, @@ -326,6 +328,8 @@ class FlashRWLargeAttention(torch.nn.Module): query, torch.select(kv, dim=2, index=0), torch.select(kv, dim=2, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index c26767822..109304be9 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -293,6 +293,8 @@ class FlashMQAttention(torch.nn.Module): query, torch.select(key_value, dim=1, index=0), torch.select(key_value, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index e562eb896..200d4ef0c 100644 --- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -242,6 +242,8 @@ class Starcoder2Attention(torch.nn.Module): query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 5e2fd20a1..dd4203e06 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -43,6 +43,7 @@ from text_generation_server.models.globals import ( ATTENTION, BLOCK_SIZE, CUDA_GRAPHS, + PREFIX_CACHING, get_adapter_to_index, ) from text_generation_server.layers.attention import Seqlen @@ -138,6 +139,9 @@ class FlashCausalLMBatch(Batch): block_tables_tensor: torch.Tensor # tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences slots: torch.Tensor + # size [b], containing the number of blocks that can be retrieved from the cache + prefix_lens: List[int] + prefix_lens_tensor: torch.Tensor max_seqlen: int @@ -146,6 +150,9 @@ class FlashCausalLMBatch(Batch): prefill_next_token_indices: Optional[torch.tensor] prefill_cu_outlens: Optional[List[int]] + # Prefixes + prefix_ids: List[List[int]] + # All tokens all_input_ids: List[List[int]] all_input_ids_tensor: torch.Tensor @@ -213,6 +220,7 @@ class FlashCausalLMBatch(Batch): prefix_offsets = [] read_offsets = [] all_input_ids = [] + prefix_ids = [] requests_idx_mapping = {} all_prefill_logprobs = True @@ -230,7 +238,7 @@ class FlashCausalLMBatch(Batch): # Cumulative length cumulative_length = 0 - cumulative_max_length = 0 + cumulative_slot_tokens = 0 prefill_out_cumulative_length = 0 num_blocks = 0 @@ -240,6 +248,7 @@ class FlashCausalLMBatch(Batch): block_tables = [] slots = [] + prefix_lens = [] # Parse batch for i, (r, tokenized_input) in enumerate( @@ -255,6 +264,19 @@ class FlashCausalLMBatch(Batch): ): tokenized_input = tokenized_input[1:] + orig_input_length = len(tokenized_input) + + if PREFIX_CACHING: + prefix_len = r.prefix_len + if prefix_len == orig_input_length: + assert prefix_len > 0 + prefix_len -= 1 + else: + prefix_len = 0 + + prefix_ids.append(tokenized_input[:prefix_len]) + tokenized_input = tokenized_input[prefix_len:] + input_length = len(tokenized_input) input_lengths.append(input_length) @@ -264,7 +286,9 @@ class FlashCausalLMBatch(Batch): all_input_ids.append(tokenized_input) # Position ids - request_position_ids = torch.arange(0, input_length, dtype=torch.int32) + request_position_ids = torch.arange( + prefix_len, orig_input_length, dtype=torch.int32 + ) position_ids.append(request_position_ids) # Add cumulative lengths of all previous inputs @@ -288,11 +312,17 @@ class FlashCausalLMBatch(Batch): # Remove one as the first token des not have a past speculative_length = get_speculate() speculative_length = 0 if speculative_length is None else speculative_length - total_tokens = input_length + max_new_tokens - 1 + speculative_length + + # Tokens that need to be mapped to blocks. + block_tokens = orig_input_length + max_new_tokens - 1 + speculative_length + + # Tokens that need to be mapped to slots. We don't need slots for the + # cached prefix (if present). + slot_tokens = input_length + max_new_tokens - 1 + speculative_length # blocks and slots can be empty (for example in warmup) if not r.blocks: - needed_blocks = math.ceil(total_tokens / BLOCK_SIZE) + needed_blocks = math.ceil(block_tokens / BLOCK_SIZE) request_blocks = [ b for b in range(num_blocks, num_blocks + needed_blocks) ] @@ -303,16 +333,20 @@ class FlashCausalLMBatch(Batch): ] else: request_blocks = r.blocks - request_slots = r.slots + request_slots = r.slots[ + prefix_len: #: orig_input_length + max_new_tokens + speculative_length + ] block_tables.append(request_blocks) - slots.extend(request_slots[:total_tokens]) + + slots.extend(request_slots) + prefix_lens.append(prefix_len) num_blocks += len(request_blocks) - start_slots.append(cumulative_max_length) + start_slots.append(cumulative_slot_tokens) request_slot_indices = torch.arange( - cumulative_max_length, - cumulative_max_length + input_length, + cumulative_slot_tokens, + cumulative_slot_tokens + input_length, dtype=torch.int64, ) slot_indices.append(request_slot_indices) @@ -348,7 +382,7 @@ class FlashCausalLMBatch(Batch): # Update cumulative_length += input_length - cumulative_max_length += total_tokens + cumulative_slot_tokens += slot_tokens max_seqlen = max(max_seqlen, input_length) max_blocks = max(max_blocks, len(request_blocks)) max_length = max( @@ -425,12 +459,14 @@ class FlashCausalLMBatch(Batch): ) slots = torch.tensor(slots, dtype=torch.int64, device=device) + block_tables_tensor = torch.zeros( (len(block_tables), max_blocks), dtype=torch.int32, device="cpu" ) for i, request_blocks in enumerate(block_tables): block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks) block_tables_tensor = block_tables_tensor.to(device) + prefix_lens_tensor = torch.tensor(prefix_lens, dtype=torch.int32, device=device) return cls( batch_id=pb.id, @@ -445,6 +481,8 @@ class FlashCausalLMBatch(Batch): block_tables=block_tables, block_tables_tensor=block_tables_tensor, slots=slots, + prefix_lens=prefix_lens, + prefix_lens_tensor=prefix_lens_tensor, max_seqlen=max_seqlen, prefill_head_indices=prefill_head_indices, prefill_next_token_indices=prefill_next_token_indices, @@ -455,6 +493,7 @@ class FlashCausalLMBatch(Batch): read_offsets=read_offsets, all_input_ids=all_input_ids, all_input_ids_tensor=all_input_ids_tensor, + prefix_ids=prefix_ids, next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, top_n_tokens=top_n_tokens, @@ -510,8 +549,10 @@ class FlashCausalLMBatch(Batch): start_slots = [] block_tables = [] all_input_ids = [] + prefix_ids = [] input_lengths = [] + prefix_lens = [] prefix_offsets = [] read_offsets = [] @@ -533,11 +574,14 @@ class FlashCausalLMBatch(Batch): # Get length request_input_length = self.input_lengths[idx] + prefix_len = self.prefix_lens[idx] max_seqlen = max(max_seqlen, request_input_length) all_input_ids.append(self.all_input_ids[idx]) + prefix_ids.append(self.prefix_ids[idx]) input_lengths.append(request_input_length) + prefix_lens.append(prefix_len) prefix_offsets.append(self.prefix_offsets[idx]) read_offsets.append(self.read_offsets[idx]) @@ -582,6 +626,7 @@ class FlashCausalLMBatch(Batch): block_tables_tensor = self.block_tables_tensor[indices] input_lengths_tensor = self.input_lengths_tensor[indices] slots = self.slots[slot_filtering_indices] + prefix_lens_tensor = self.prefix_lens_tensor[indices] next_token_chooser = self.next_token_chooser.filter(indices) top_n_tokens_tensor = self.top_n_tokens_tensor[indices] speculative_ids = ( @@ -617,10 +662,13 @@ class FlashCausalLMBatch(Batch): prefill_cu_outlens=None, input_lengths=input_lengths, input_lengths_tensor=input_lengths_tensor, + prefix_lens=prefix_lens, + prefix_lens_tensor=prefix_lens_tensor, prefix_offsets=prefix_offsets, read_offsets=read_offsets, all_input_ids=all_input_ids, all_input_ids_tensor=all_input_ids_tensor, + prefix_ids=prefix_ids, next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, top_n_tokens=top_n_tokens, @@ -681,6 +729,7 @@ class FlashCausalLMBatch(Batch): block_tables_tensor = batches[0].block_tables_tensor.new_zeros( (total_batch_size, max_blocks) ) + prefix_lens_tensor = batches[0].prefix_lens_tensor.new_empty(total_batch_size) all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros( (total_batch_size, max_length) ) @@ -698,7 +747,9 @@ class FlashCausalLMBatch(Batch): start_slots = [] block_tables = [] + prefix_lens = [] all_input_ids = [] + prefix_ids = [] input_lengths = [] prefix_offsets = [] @@ -760,10 +811,14 @@ class FlashCausalLMBatch(Batch): start_index:end_index, : batch.block_tables_tensor.shape[1] ] = batch.block_tables_tensor[:, :max_blocks] + prefix_lens_tensor[start_index:end_index] = batch.prefix_lens_tensor + start_slots.append(batch.start_slots + cumulative_slots) block_tables.extend(batch.block_tables) + prefix_lens.extend(batch.prefix_lens) all_input_ids.extend(batch.all_input_ids) + prefix_ids.extend(batch.prefix_ids) input_lengths.extend(batch.input_lengths) prefix_offsets.extend(batch.prefix_offsets) @@ -809,6 +864,8 @@ class FlashCausalLMBatch(Batch): slot_indices=slot_indices, block_tables=block_tables, block_tables_tensor=block_tables_tensor, + prefix_lens=prefix_lens, + prefix_lens_tensor=prefix_lens_tensor, slots=slots, max_seqlen=max_seqlen, prefill_head_indices=None, @@ -820,6 +877,7 @@ class FlashCausalLMBatch(Batch): read_offsets=read_offsets, all_input_ids=all_input_ids, all_input_ids_tensor=all_input_ids_tensor, + prefix_ids=prefix_ids, next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, top_n_tokens=top_n_tokens, @@ -970,19 +1028,22 @@ class FlashCausalLM(Model): self.kv_cache = [] if ATTENTION == "flashinfer": - from text_generation_server.layers.attention.flash_infer import ( + from text_generation_server.layers.attention.flashinfer import ( create_prefill_state, create_decode_state, + create_prefill_with_paged_kv_state, ) self.prefill_state = create_prefill_state(device=device) + self.prefill_with_paged_kv_state = create_prefill_with_paged_kv_state( + device=device + ) - if not CUDA_GRAPHS: - self.decode_state = create_decode_state( - device=device, - num_heads=self.num_heads, - num_kv_heads=self.num_kv_heads, - ) + self.decode_state = create_decode_state( + device=device, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + ) super().__init__( model_id=model_id, @@ -1074,12 +1135,23 @@ class FlashCausalLM(Model): input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device) position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device) slots = torch.arange(bs, dtype=torch.int64, device=self.device) - input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s - block_tables = ( - torch.arange(max_bt, dtype=torch.int32, device=self.device) - .repeat(bs) - .reshape((bs, max_bt)) + input_lengths = [max_s] * bs + prefix_lengths = [0] * bs + input_lengths_tensor = ( + torch.ones(bs, dtype=torch.int32, device=self.device) * max_s ) + prefix_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device) + block_tables = torch.arange( + max_bt, dtype=torch.int32, device=self.device + ).repeat(bs) + block_tables = block_tables.reshape((bs, max_bt)) + + if ATTENTION == "flashinfer": + block_tables = block_tables_to_ragged( + block_tables=block_tables, + input_lengths=input_lengths, + prefix_lens=prefix_lengths, + ) self.cuda_graphs[bs] = { "input_ids": input_ids, @@ -1087,14 +1159,14 @@ class FlashCausalLM(Model): "kv_cache": self.kv_cache, "block_tables": block_tables, "slots": slots, - "input_lengths": input_lengths, + "input_lengths": input_lengths_tensor, } - input_lengths_ = Seqlen(input_lengths=input_lengths) + input_lengths_ = Seqlen(input_lengths=input_lengths_tensor) graph = torch.cuda.CUDAGraph() self.cuda_graphs[bs]["graph"] = graph if ATTENTION == "flashinfer": - from text_generation_server.layers.attention.flash_infer import ( + from text_generation_server.layers.attention.flashinfer import ( create_decode_state_cuda_graphs, ) @@ -1104,7 +1176,7 @@ class FlashCausalLM(Model): last_page_len = torch.ones(bs, dtype=torch.int32, device=self.device) state = create_decode_state_cuda_graphs( device=input_ids.device, - block_tables=block_tables.view(-1), + block_tables=block_tables, block_tables_ptr=block_tables_ptr, last_page_len=last_page_len, num_heads=self.num_heads, @@ -1120,7 +1192,10 @@ class FlashCausalLM(Model): block_tables=block_tables, cu_seqlen_prefill=None, input_lengths=input_lengths, + input_lengths_tensor=input_lengths_tensor, state=state, + prefix_lens=prefix_lengths, + prefix_lens_tensor=prefix_lengths_tensor, ): self.model.forward( input_ids=input_ids, @@ -1138,7 +1213,7 @@ class FlashCausalLM(Model): torch.cuda.synchronize() with torch.cuda.graph(graph, pool=MEM_POOL): - input_lengths = Seqlen(input_lengths=input_lengths) + input_lengths_tensor = Seqlen(input_lengths=input_lengths_tensor) logits, speculative_logits = self.model.forward( input_ids=input_ids, position_ids=position_ids, @@ -1146,7 +1221,7 @@ class FlashCausalLM(Model): kv_cache=self.kv_cache, block_tables=block_tables, slots=slots, - input_lengths=input_lengths, + input_lengths=input_lengths_tensor, max_s=max_s, prefill_cache_indices=None, lm_head_indices=None, @@ -1334,6 +1409,9 @@ class FlashCausalLM(Model): input_lengths = ( input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int ).view(-1) + prefix_lens_tensor = ( + batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length) + ).reshape(-1) # Add Copy the block tables for all members block_tables = ( @@ -1354,6 +1432,7 @@ class FlashCausalLM(Model): block_tables = batch.block_tables_tensor slots = batch.slots[batch.slot_indices] input_lengths = batch.input_lengths_tensor + prefix_lens_tensor = batch.prefix_lens_tensor max_s = batch.max_seqlen lm_head_indices = batch.prefill_head_indices @@ -1372,10 +1451,20 @@ class FlashCausalLM(Model): cuda_graph = None if cu_seqlen_prefill is not None or cuda_graph is None: + input_lengths = input_lengths + prefix_lens_tensor + if PREFIX_CACHING: + block_tables = block_tables_to_ragged( + block_tables=block_tables, + input_lengths=batch.input_lengths, + prefix_lens=batch.prefix_lens, + ) with self._forward_context( block_tables=block_tables, cu_seqlen_prefill=cu_seqlen_prefill, - input_lengths=input_lengths, + input_lengths=batch.input_lengths, + input_lengths_tensor=input_lengths, + prefix_lens=batch.prefix_lens, + prefix_lens_tensor=prefix_lens_tensor, ): input_lengths = Seqlen(input_lengths=input_lengths) logits, speculative_logits = self.model.forward( @@ -1399,20 +1488,32 @@ class FlashCausalLM(Model): # Static inputs are potentially padded cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids - cuda_graph["block_tables"][ - : block_tables.shape[0], : block_tables.shape[1] - ] = block_tables + if ATTENTION == "flashinfer": + block_tables = block_tables_to_ragged( + block_tables=block_tables, + input_lengths=batch.input_lengths, + prefix_lens=batch.prefix_lens, + ) + cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables + else: + cuda_graph["block_tables"][ + : block_tables.shape[0], : block_tables.shape[1] + ] = block_tables cuda_graph["slots"].fill_(-1) cuda_graph["slots"][: slots.shape[0]] = slots cuda_graph["input_lengths"].zero_() - cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths + cuda_graph["input_lengths"][: input_lengths.shape[0]] = ( + input_lengths + prefix_lens_tensor + ) - state = cuda_graph.get("state") with self._forward_context( - block_tables=block_tables, + block_tables=cuda_graph["block_tables"], cu_seqlen_prefill=None, - input_lengths=input_lengths, - state=state, + input_lengths=batch.input_lengths, + input_lengths_tensor=cuda_graph["input_lengths"], + prefix_lens=batch.prefix_lens, + prefix_lens_tensor=prefix_lens_tensor, + state=cuda_graph.get("state"), ): # Replay the graph cuda_graph["graph"].replay() @@ -1610,6 +1711,7 @@ class FlashCausalLM(Model): batch.read_offsets, batch.stopping_criterias, batch.all_input_ids, + batch.prefix_ids, batch.next_token_chooser.do_sample, batch.next_token_chooser.seeds, batch.top_n_tokens, @@ -1627,6 +1729,7 @@ class FlashCausalLM(Model): read_offset, stopping_criteria, all_input_ids, + prefix_ids, do_sample, seed, top_n_tokens, @@ -1701,18 +1804,18 @@ class FlashCausalLM(Model): out_end_index = batch.prefill_cu_outlens[i + 1] # Remove generated token to only have prefill and add nan for first prompt token - request_prefill_logprobs = [float("nan")] + prefill_logprobs[ - out_start_index : out_end_index - 1 - ] + request_prefill_logprobs = ( + [float("nan")] * (len(prefix_ids) + 1) + ) + prefill_logprobs[out_start_index : out_end_index - 1] prefill_token_ids = all_input_ids[:-1] prefill_texts = self.tokenizer.batch_decode( - prefill_token_ids, + prefix_ids + prefill_token_ids, clean_up_tokenization_spaces=False, skip_special_tokens=False, ) prefill_tokens = Tokens( - prefill_token_ids, + prefix_ids + prefill_token_ids, request_prefill_logprobs, prefill_texts, is_special=[], @@ -1794,33 +1897,68 @@ class FlashCausalLM(Model): *, block_tables: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], - input_lengths: torch.Tensor, + input_lengths: List[int], + input_lengths_tensor: torch.Tensor, + prefix_lens: List[int], + prefix_lens_tensor: torch.Tensor, state: Optional[Any] = None, ) -> ContextManager: if ATTENTION != "flashinfer": return nullcontext() - from text_generation_server.layers.attention.flash_infer import ( + from text_generation_server.layers.attention.flashinfer import ( use_decode_state, - use_prefill_state, + use_prefill_with_paged_kv_state, ) + # has_prefix_lens = any(prefix_len > 0 for prefix_len in prefix_lens) + if cu_seqlen_prefill is not None: - return use_prefill_state( - state=state if state is not None else self.prefill_state, + return use_prefill_with_paged_kv_state( + state=( + state if state is not None else self.prefill_with_paged_kv_state + ), + # block_tables=block_tables_to_ragged( + # block_tables=block_tables, + # input_lengths=input_lengths, + # prefix_lens=prefix_lens, + # ), + block_tables=block_tables, cu_seqlens=cu_seqlen_prefill, - num_heads=self.num_heads, - num_kv_heads=self.num_kv_heads, - head_size=self.head_size, - ) - else: - assert input_lengths is not None - return use_decode_state( - state=state if state is not None else self.decode_state, - input_lengths=input_lengths, - block_tables=block_tables.view(-1), + input_lengths=input_lengths_tensor, num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, head_size=self.head_size, page_size=BLOCK_SIZE, ) + else: + assert input_lengths_tensor is not None + return use_decode_state( + state=state if state is not None else self.decode_state, + input_lengths=input_lengths_tensor, + block_tables=block_tables, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + head_size=self.head_size, + page_size=BLOCK_SIZE, + ) + + +def block_tables_to_ragged( + *, block_tables: torch.Tensor, input_lengths: List[int], prefix_lens: List[int] +) -> torch.Tensor: + """Convert block table to ragged format compatible with FlashInfer.""" + assert len(input_lengths) == len(prefix_lens) + + total_len = sum(input_lengths) + sum(prefix_lens) + block_tables_ragged = torch.empty( + total_len, dtype=torch.int32, device=block_tables.device + ) + + offset = 0 + for i, (input_length, prefix_len) in enumerate(zip(input_lengths, prefix_lens)): + seq_len = prefix_len + input_length + block_tables_ragged[offset : offset + seq_len] = block_tables[i][:seq_len] + offset += seq_len + + return block_tables_ragged diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index abc354212..d5133f5e2 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -5,9 +5,8 @@ from typing import Dict, Optional from text_generation_server.utils.log import log_master -PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING", False) -log_master(logger.info, f"Using Attention = {PREFIX_CACHING}") - +PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING", "0").lower() in {"1", "true"} +log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}") ATTENTION = os.getenv("ATTENTION", "flashinfer" if PREFIX_CACHING else "paged") _expected = {"paged", "flashdecoding", "flashinfer"} assert ( @@ -29,7 +28,6 @@ elif ATTENTION == "flashinfer": else: BLOCK_SIZE = 16 - cuda_graphs = os.getenv("CUDA_GRAPHS") if cuda_graphs is not None: try: diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 7de54aa44..2ed1a119d 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -11,7 +11,9 @@ from text_generation_server.pb import generate_pb2 from text_generation_server.models.flash_causal_lm import ( FlashCausalLMBatch, FlashCausalLM, + block_tables_to_ragged, ) +from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION from text_generation_server.utils.log import log_master from transformers import AutoProcessor from text_generation_server.layers.attention import Seqlen @@ -254,6 +256,8 @@ class VlmCausalLM(FlashCausalLM): trust_remote_code: bool, **kwargs, ): + if PREFIX_CACHING: + raise NotImplementedError("Vlm do not work with prefix caching yet") if processor_kwargs is None: processor_kwargs = {} self.processor = processor_class.from_pretrained( @@ -310,6 +314,9 @@ class VlmCausalLM(FlashCausalLM): input_lengths = ( input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int ).view(-1) + prefix_lens_tensor = ( + batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length) + ).reshape(-1) # Add Copy the block tables for all members block_tables = ( @@ -330,6 +337,7 @@ class VlmCausalLM(FlashCausalLM): block_tables = batch.block_tables_tensor slots = batch.slots[batch.slot_indices] input_lengths = batch.input_lengths_tensor + prefix_lens_tensor = batch.prefix_lens_tensor max_s = batch.max_seqlen lm_head_indices = batch.prefill_head_indices @@ -349,43 +357,68 @@ class VlmCausalLM(FlashCausalLM): else: cuda_graph = None if cu_seqlen_prefill is not None or cuda_graph is None: - input_lengths = Seqlen(input_lengths=input_lengths) - logits, speculative_logits = self.model.forward( - input_ids=input_ids, - position_ids=position_ids, - cu_seqlen_prefill=cu_seqlen_prefill, - kv_cache=kv_cache, + input_lengths = input_lengths + prefix_lens_tensor + if PREFIX_CACHING: + block_tables = block_tables_to_ragged( + block_tables=block_tables, + input_lengths=batch.input_lengths, + prefix_lens=batch.prefix_lens, + ) + with self._forward_context( block_tables=block_tables, - slots=slots, - input_lengths=input_lengths, - max_s=max_s, - prefill_cache_indices=batch.prefill_cache_indices, - lm_head_indices=lm_head_indices, - pixel_values=batch.pixel_values, - pixel_attention_mask=batch.pixel_attention_mask, - image_sizes=batch.image_sizes, - ) - if batch.prefill_cache_indices is not None: - batch.prefill_cache_indices = None - if batch.pixel_values is not None: - batch.pixel_values = None - if batch.pixel_attention_mask is not None: - batch.pixel_attention_mask = None - if batch.image_sizes is not None: - batch.image_sizes = None - return logits, speculative_logits + cu_seqlen_prefill=cu_seqlen_prefill, + input_lengths=batch.input_lengths, + input_lengths_tensor=input_lengths, + prefix_lens=batch.prefix_lens, + prefix_lens_tensor=prefix_lens_tensor, + ): + input_lengths = Seqlen(input_lengths=input_lengths) + logits, speculative_logits = self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + block_tables=block_tables, + slots=slots, + input_lengths=input_lengths, + max_s=max_s, + prefill_cache_indices=batch.prefill_cache_indices, + lm_head_indices=lm_head_indices, + pixel_values=batch.pixel_values, + pixel_attention_mask=batch.pixel_attention_mask, + image_sizes=batch.image_sizes, + ) + if batch.prefill_cache_indices is not None: + batch.prefill_cache_indices = None + if batch.pixel_values is not None: + batch.pixel_values = None + if batch.pixel_attention_mask is not None: + batch.pixel_attention_mask = None + if batch.image_sizes is not None: + batch.image_sizes = None + return logits, speculative_logits # Copy inputs to the static inputs of the cuda graph # Static inputs are potentially padded cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids - cuda_graph["block_tables"][ - : block_tables.shape[0], : block_tables.shape[1] - ] = block_tables + if ATTENTION == "flashinfer": + block_tables = block_tables_to_ragged( + block_tables=block_tables, + input_lengths=batch.input_lengths, + prefix_lens=batch.prefix_lens, + ) + cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables + else: + cuda_graph["block_tables"][ + : block_tables.shape[0], : block_tables.shape[1] + ] = block_tables cuda_graph["slots"].fill_(-1) cuda_graph["slots"][: slots.shape[0]] = slots cuda_graph["input_lengths"].zero_() - cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths + cuda_graph["input_lengths"][: input_lengths.shape[0]] = ( + input_lengths + prefix_lens_tensor + ) # Replay the graph cuda_graph["graph"].replay()