From 99b6b5c7957a12ea447a00674c9f1facf8aec743 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Tue, 13 Aug 2024 16:22:52 +0200
Subject: [PATCH] Fixing prefix caching.

---
 backends/v3/src/radix.rs                            |  2 +
 .../layers/attention/cuda.py                        |  4 +-
 .../{flash_infer.py => flashinfer.py}               |  0
 .../models/flash_causal_lm.py                       | 67 +++++++++----------
 4 files changed, 37 insertions(+), 36 deletions(-)
 rename server/text_generation_server/layers/attention/{flash_infer.py => flashinfer.py} (100%)

diff --git a/backends/v3/src/radix.rs b/backends/v3/src/radix.rs
index 5bac1a31..386720e0 100644
--- a/backends/v3/src/radix.rs
+++ b/backends/v3/src/radix.rs
@@ -89,6 +89,8 @@ impl Allocator for RadixAllocator {
         let prefix_len = blocks.len();
         let suffix_len = tokens - prefix_len as u32;
 
+        tracing::info!("Prefix {prefix_len} - Suffix {suffix_len}");
+
         match self.alloc_or_reclaim(suffix_len as usize) {
             Some(suffix_blocks) => blocks.extend(suffix_blocks),
             None => {
diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py
index 18cf70d3..b3b7ea4f 100644
--- a/server/text_generation_server/layers/attention/cuda.py
+++ b/server/text_generation_server/layers/attention/cuda.py
@@ -76,7 +76,7 @@ def paged_attention(
     # sequences or heads is large, we use V1 since there is enough work
     # to parallelize.
     if ATTENTION == "flashinfer":
-        from text_generation_server.layers.attention.flash_infer import decode_state
+        from text_generation_server.layers.attention.flashinfer import decode_state
 
         return decode_state.get().forward(
             query.contiguous(),
@@ -234,7 +234,7 @@ if ATTENTION == "flashinfer":
         softcap=0.0,
     ):
         assert window_size_left == -1, "Windowing is not supported with flash infer"
-        from text_generation_server.layers.attention.flash_infer import (
+        from text_generation_server.layers.attention.flashinfer import (
             prefill_with_paged_kv_state,
         )
 
diff --git a/server/text_generation_server/layers/attention/flash_infer.py b/server/text_generation_server/layers/attention/flashinfer.py
similarity index 100%
rename from server/text_generation_server/layers/attention/flash_infer.py
rename to server/text_generation_server/layers/attention/flashinfer.py
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 82669215..13885f28 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -468,6 +468,7 @@ class FlashCausalLMBatch(Batch):
         block_tables_tensor = block_tables_tensor.to(device)
         prefix_lens_tensor = torch.tensor(prefix_lens, dtype=torch.int32, device=device)
 
+        log_master(logger.info, f"Block tables {block_tables}")
         return cls(
             batch_id=pb.id,
             requests=pb.requests,
@@ -1028,7 +1029,7 @@ class FlashCausalLM(Model):
         self.kv_cache = []
 
         if ATTENTION == "flashinfer":
-            from text_generation_server.layers.attention.flash_infer import (
+            from text_generation_server.layers.attention.flashinfer import (
                 create_prefill_state,
                 create_decode_state,
                 create_prefill_with_paged_kv_state,
@@ -1166,7 +1167,7 @@ class FlashCausalLM(Model):
         self.cuda_graphs[bs]["graph"] = graph
 
         if ATTENTION == "flashinfer":
-            from text_generation_server.layers.attention.flash_infer import (
+            from text_generation_server.layers.attention.flashinfer import (
                 create_decode_state_cuda_graphs,
             )
 
@@ -1411,7 +1412,6 @@ class FlashCausalLM(Model):
             ).view(-1)
             prefix_lens_tensor = (
                 batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length)
-                + arange_int
             ).view(-1)
 
             # Add Copy the block tables for all members
@@ -1452,6 +1452,13 @@ class FlashCausalLM(Model):
             cuda_graph = None
 
         if cu_seqlen_prefill is not None or cuda_graph is None:
+            input_lengths = input_lengths + prefix_lens_tensor
+            if PREFIX_CACHING:
+                block_tables = block_tables_to_ragged(
+                    block_tables=block_tables,
+                    input_lengths=batch.input_lengths,
+                    prefix_lens=batch.prefix_lens,
+                )
             with self._forward_context(
                 block_tables=block_tables,
                 cu_seqlen_prefill=cu_seqlen_prefill,
@@ -1496,9 +1503,9 @@ class FlashCausalLM(Model):
         cuda_graph["slots"].fill_(-1)
         cuda_graph["slots"][: slots.shape[0]] = slots
         cuda_graph["input_lengths"].zero_()
-        cuda_graph["input_lengths"][
-            : input_lengths.shape[0]
-        ] = input_lengths  # + prefix_lens_tensor
+        cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
+            input_lengths + prefix_lens_tensor
+        )
 
         with self._forward_context(
             block_tables=cuda_graph["block_tables"],
@@ -1736,7 +1743,7 @@ class FlashCausalLM(Model):
             left = 0
 
             if n_accepted_ids > 1:
-                log_master(logger.debug, f"Speculated ids {n_accepted_ids - 1}")
+                log_master(logger.info, f"Speculated ids {n_accepted_ids - 1}")
 
             current_stopped = False
             for j in range(index, index + n_accepted_ids):
@@ -1900,40 +1907,32 @@ class FlashCausalLM(Model):
         if ATTENTION != "flashinfer":
             return nullcontext()
 
-        from text_generation_server.layers.attention.flash_infer import (
+        from text_generation_server.layers.attention.flashinfer import (
             use_decode_state,
-            use_prefill_state,
             use_prefill_with_paged_kv_state,
         )
 
         # has_prefix_lens = any(prefix_len > 0 for prefix_len in prefix_lens)
 
         if cu_seqlen_prefill is not None:
-            if True:  # has_prefix_lens:
-                return use_prefill_with_paged_kv_state(
-                    state=(
-                        state if state is not None else self.prefill_with_paged_kv_state
-                    ),
-                    block_tables=block_tables_to_ragged(
-                        block_tables=block_tables,
-                        input_lengths=input_lengths,
-                        prefix_lens=prefix_lens,
-                    ),
-                    cu_seqlens=cu_seqlen_prefill,
-                    input_lengths=input_lengths_tensor,
-                    num_heads=self.num_heads,
-                    num_kv_heads=self.num_kv_heads,
-                    head_size=self.head_size,
-                    page_size=BLOCK_SIZE,
-                )
-            else:
-                return use_prefill_state(
-                    state=state if state is not None else self.prefill_state,
-                    cu_seqlens=cu_seqlen_prefill,
-                    num_heads=self.num_heads,
-                    num_kv_heads=self.num_kv_heads,
-                    head_size=self.head_size,
-                )
+            log_master(logger.info, f"Prefix lens {prefix_lens}")
+            return use_prefill_with_paged_kv_state(
+                state=(
+                    state if state is not None else self.prefill_with_paged_kv_state
+                ),
+                # block_tables=block_tables_to_ragged(
+                #     block_tables=block_tables,
+                #     input_lengths=input_lengths,
+                #     prefix_lens=prefix_lens,
+                # ),
+                block_tables=block_tables,
+                cu_seqlens=cu_seqlen_prefill,
+                input_lengths=input_lengths_tensor,
+                num_heads=self.num_heads,
+                num_kv_heads=self.num_kv_heads,
+                head_size=self.head_size,
+                page_size=BLOCK_SIZE,
+            )
         else:
             assert input_lengths_tensor is not None
             return use_decode_state(
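
The prefill half of this change hinges on block_tables_to_ragged: when PREFIX_CACHING is enabled, forward() now flattens the per-request block tables into a single ragged tensor covering the cached prefix plus the new tokens before entering the flashinfer paged-KV prefill context, and _forward_context passes block_tables through as-is (its own conversion is commented out above). The sketch below illustrates that ragged layout. It is a sketch only: the helper's real body lives in flash_causal_lm.py and may differ, and it assumes one KV-cache block per token (page size 1), which is how the flashinfer path is configured here via BLOCK_SIZE.

    # Sketch of a block_tables_to_ragged-style flattening (assumption, not the
    # repository implementation). Variable names mirror the call site above.
    from typing import List

    import torch


    def ragged_block_tables(
        *,
        block_tables: torch.Tensor,  # [batch, max_blocks] padded per-request block ids
        input_lengths: List[int],    # new (uncached) tokens per request
        prefix_lens: List[int],      # tokens already held in the prefix cache
    ) -> torch.Tensor:
        # Assuming one KV-cache block per token, each request contributes
        # prefix_len + input_length block ids to the flat table.
        total = sum(input_lengths) + sum(prefix_lens)
        ragged = torch.empty(total, dtype=torch.int32, device=block_tables.device)

        offset = 0
        for i, (input_length, prefix_len) in enumerate(zip(input_lengths, prefix_lens)):
            seq_len = prefix_len + input_length
            ragged[offset : offset + seq_len] = block_tables[i][:seq_len]
            offset += seq_len
        return ragged

Consistent with that layout, the input_lengths handed to use_prefill_with_paged_kv_state are expected to be total lengths (cached prefix plus new tokens), which lines up with the input_lengths = input_lengths + prefix_lens_tensor line added just before the with self._forward_context(...) block.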
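
The decode/CUDA-graph half is the static length buffer: on replay, the captured graph reads sequence lengths from the preallocated cuda_graph["input_lengths"] buffer, and before this patch only the new-token counts were copied in (the + prefix_lens_tensor part sat in a comment), so the cached prefix was dropped from the lengths the captured kernels see. Below is a minimal sketch of the corrected refresh; padded_bs and static_input_lengths are hypothetical stand-ins for the captured batch size and the cuda_graphs entry.

    # Sketch of refreshing a captured graph's static length buffer before replay.
    # Shapes and names are illustrative, not the exact cuda_graphs entries.
    import torch

    padded_bs = 8  # batch size the graph was captured with (hypothetical)
    static_input_lengths = torch.zeros(padded_bs, dtype=torch.int32)

    # Live batch of 3 requests, some with a cached prefix.
    input_lengths = torch.tensor([5, 3, 7], dtype=torch.int32)          # new tokens
    prefix_lens_tensor = torch.tensor([32, 0, 16], dtype=torch.int32)   # cached tokens

    # The kernels baked into the graph need total sequence lengths, so the cached
    # prefix is folded in; padded slots beyond the live batch stay at zero.
    static_input_lengths.zero_()
    static_input_lengths[: input_lengths.shape[0]] = input_lengths + prefix_lens_tensor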