Fixing prefix caching.

Nicolas Patry 2024-08-13 16:22:52 +02:00
parent b2933b72d0
commit 99b6b5c795
GPG Key ID: 64AF4752B2967863
4 changed files with 37 additions and 36 deletions

View File

@@ -89,6 +89,8 @@ impl Allocator for RadixAllocator {
         let prefix_len = blocks.len();
         let suffix_len = tokens - prefix_len as u32;
+        tracing::info!("Prefix {prefix_len} - Suffix {suffix_len}");
         match self.alloc_or_reclaim(suffix_len as usize) {
             Some(suffix_blocks) => blocks.extend(suffix_blocks),
             None => {
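
Note on the hunk above: the radix allocator serves a request by reusing whatever prefix of its tokens is already cached as blocks, then allocating (or reclaiming) blocks only for the remaining suffix; the new tracing line logs that split. Below is a minimal sketch of the arithmetic, assuming a block size of 1 so cached blocks map one-to-one to prefix tokens; the helper name is illustrative, not the allocator's actual API.

def split_prefix_suffix(tokens: int, cached_blocks: list) -> tuple:
    # Mirrors the prefix/suffix split above: the cached blocks cover the start
    # of the prompt, and only the remaining suffix needs fresh allocation.
    prefix_len = len(cached_blocks)      # tokens already covered by the prefix cache
    suffix_len = tokens - prefix_len     # tokens that still need new blocks
    return prefix_len, suffix_len

# Example: a 7-token prompt whose first 3 tokens hit the cache.
assert split_prefix_suffix(7, cached_blocks=[12, 13, 14]) == (3, 4)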

View File

@@ -76,7 +76,7 @@ def paged_attention(
     # sequences or heads is large, we use V1 since there is enough work
     # to parallelize.
     if ATTENTION == "flashinfer":
-        from text_generation_server.layers.attention.flash_infer import decode_state
+        from text_generation_server.layers.attention.flashinfer import decode_state
         return decode_state.get().forward(
             query.contiguous(),
@@ -234,7 +234,7 @@ if ATTENTION == "flashinfer":
         softcap=0.0,
     ):
         assert window_size_left == -1, "Windowing is not supported with flash infer"
-        from text_generation_server.layers.attention.flash_infer import (
+        from text_generation_server.layers.attention.flashinfer import (
             prefill_with_paged_kv_state,
         )

View File

@@ -468,6 +468,7 @@ class FlashCausalLMBatch(Batch):
         block_tables_tensor = block_tables_tensor.to(device)
         prefix_lens_tensor = torch.tensor(prefix_lens, dtype=torch.int32, device=device)
+        log_master(logger.info, f"Block tables {block_tables}")
         return cls(
             batch_id=pb.id,
             requests=pb.requests,
@@ -1028,7 +1029,7 @@ class FlashCausalLM(Model):
         self.kv_cache = []
         if ATTENTION == "flashinfer":
-            from text_generation_server.layers.attention.flash_infer import (
+            from text_generation_server.layers.attention.flashinfer import (
                 create_prefill_state,
                 create_decode_state,
                 create_prefill_with_paged_kv_state,
@@ -1166,7 +1167,7 @@ class FlashCausalLM(Model):
         self.cuda_graphs[bs]["graph"] = graph
         if ATTENTION == "flashinfer":
-            from text_generation_server.layers.attention.flash_infer import (
+            from text_generation_server.layers.attention.flashinfer import (
                 create_decode_state_cuda_graphs,
             )
@@ -1411,7 +1412,6 @@ class FlashCausalLM(Model):
             ).view(-1)
             prefix_lens_tensor = (
                 batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length)
-                + arange_int
             ).view(-1)
             # Add Copy the block tables for all members
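
Aside on the `+ arange_int` removal above: when the batch is padded out for CUDA graphs and speculation, the per-position input lengths grow, but a request's cached prefix length is the same at every one of its positions, so the per-request value should only be repeated, not offset. A small, self-contained PyTorch illustration of what the repeat produces (toy values; `reshape` is used here because a bare expanded view cannot be flattened with `view`):

import torch

# Two requests (B=2) padded to new_length=3 positions each; their cached
# prefixes are 5 and 0 tokens long.
B, new_length = 2, 3
prefix_lens_tensor = torch.tensor([5, 0], dtype=torch.int32)

repeated = prefix_lens_tensor.unsqueeze(-1).expand(B, new_length).reshape(-1)
# Each prefix length is repeated once per position, with no arange offset:
assert repeated.tolist() == [5, 5, 5, 0, 0, 0]
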
@@ -1452,6 +1452,13 @@ class FlashCausalLM(Model):
             cuda_graph = None
         if cu_seqlen_prefill is not None or cuda_graph is None:
+            input_lengths = input_lengths + prefix_lens_tensor
+            if PREFIX_CACHING:
+                block_tables = block_tables_to_ragged(
+                    block_tables=block_tables,
+                    input_lengths=batch.input_lengths,
+                    prefix_lens=batch.prefix_lens,
+                )
             with self._forward_context(
                 block_tables=block_tables,
                 cu_seqlen_prefill=cu_seqlen_prefill,
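
Context for the `PREFIX_CACHING` branch above: flashinfer's paged-KV interface takes one flat, ragged list of block ids with per-request lengths rather than a padded `(batch, max_blocks)` table, and the attention kernels need lengths that cover the cached prefix as well, which is why `input_lengths` is extended by `prefix_lens_tensor`. The sketch below is a hypothetical re-implementation of what `block_tables_to_ragged` appears to do, judging from its arguments; the real helper in text_generation_server may differ, and the sketch assumes one token per block (BLOCK_SIZE == 1).

import torch

def block_tables_to_ragged_sketch(*, block_tables, input_lengths, prefix_lens):
    # Hypothetical re-implementation for illustration only: flatten a padded
    # (batch, max_blocks) table into one ragged 1D tensor, keeping for each
    # request just the blocks that cover its cached prefix plus its new tokens
    # (assumes BLOCK_SIZE == 1, i.e. one token per block).
    flat = []
    for row, input_length, prefix_len in zip(block_tables, input_lengths, prefix_lens):
        flat.extend(row[: input_length + prefix_len].tolist())
    return torch.tensor(flat, dtype=torch.int32)

# Two requests padded to 4 block slots: the first has a 2-token cached prefix
# and 1 new token, the second no cached prefix and 2 new tokens.
padded = torch.tensor([[7, 8, 9, 0], [3, 4, 0, 0]], dtype=torch.int32)
ragged = block_tables_to_ragged_sketch(
    block_tables=padded, input_lengths=[1, 2], prefix_lens=[2, 0]
)
assert ragged.tolist() == [7, 8, 9, 3, 4]
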
@@ -1496,9 +1503,9 @@ class FlashCausalLM(Model):
             cuda_graph["slots"].fill_(-1)
             cuda_graph["slots"][: slots.shape[0]] = slots
             cuda_graph["input_lengths"].zero_()
-            cuda_graph["input_lengths"][
-                : input_lengths.shape[0]
-            ] = input_lengths  # + prefix_lens_tensor
+            cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
+                input_lengths + prefix_lens_tensor
+            )
             with self._forward_context(
                 block_tables=cuda_graph["block_tables"],
@@ -1736,7 +1743,7 @@ class FlashCausalLM(Model):
             left = 0
             if n_accepted_ids > 1:
-                log_master(logger.debug, f"Speculated ids {n_accepted_ids - 1}")
+                log_master(logger.info, f"Speculated ids {n_accepted_ids - 1}")
             current_stopped = False
             for j in range(index, index + n_accepted_ids):
@@ -1900,40 +1907,32 @@ class FlashCausalLM(Model):
         if ATTENTION != "flashinfer":
             return nullcontext()
-        from text_generation_server.layers.attention.flash_infer import (
+        from text_generation_server.layers.attention.flashinfer import (
             use_decode_state,
-            use_prefill_state,
             use_prefill_with_paged_kv_state,
         )
         # has_prefix_lens = any(prefix_len > 0 for prefix_len in prefix_lens)
         if cu_seqlen_prefill is not None:
-            if True:  # has_prefix_lens:
-                return use_prefill_with_paged_kv_state(
-                    state=(
-                        state if state is not None else self.prefill_with_paged_kv_state
-                    ),
-                    block_tables=block_tables_to_ragged(
-                        block_tables=block_tables,
-                        input_lengths=input_lengths,
-                        prefix_lens=prefix_lens,
-                    ),
-                    cu_seqlens=cu_seqlen_prefill,
-                    input_lengths=input_lengths_tensor,
-                    num_heads=self.num_heads,
-                    num_kv_heads=self.num_kv_heads,
-                    head_size=self.head_size,
-                    page_size=BLOCK_SIZE,
-                )
-            else:
-                return use_prefill_state(
-                    state=state if state is not None else self.prefill_state,
-                    cu_seqlens=cu_seqlen_prefill,
-                    num_heads=self.num_heads,
-                    num_kv_heads=self.num_kv_heads,
-                    head_size=self.head_size,
-                )
+            log_master(logger.info, f"Prefix lens {prefix_lens}")
+            return use_prefill_with_paged_kv_state(
+                state=(
+                    state if state is not None else self.prefill_with_paged_kv_state
+                ),
+                # block_tables=block_tables_to_ragged(
+                #     block_tables=block_tables,
+                #     input_lengths=input_lengths,
+                #     prefix_lens=prefix_lens,
+                # ),
+                block_tables=block_tables,
+                cu_seqlens=cu_seqlen_prefill,
+                input_lengths=input_lengths_tensor,
+                num_heads=self.num_heads,
+                num_kv_heads=self.num_kv_heads,
+                head_size=self.head_size,
+                page_size=BLOCK_SIZE,
+            )
         else:
             assert input_lengths_tensor is not None
             return use_decode_state(
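
Finally, on the `_forward_context` hunk: the prefill path is now pinned to the paged-KV prefill state (the unused `use_prefill_state` branch is dropped), and the `decode_state.get()` call seen in `paged_attention` earlier suggests these `use_*_state` helpers are context managers over context-local state that the attention kernels read back during the forward pass. A generic sketch of that pattern, under that assumption; names and internals are illustrative, not the actual module's code.

from contextlib import contextmanager
from contextvars import ContextVar

# Context-local slot that attention kernels can read while a forward pass runs.
decode_state = ContextVar("decode_state")

@contextmanager
def use_decode_state(state, **plan_kwargs):
    # In the real module this is where the per-batch planning (block tables,
    # sequence lengths, head/page configuration, ...) would happen.
    token = decode_state.set(state)
    try:
        yield state
    finally:
        decode_state.reset(token)

# Usage mirrors the diff: the model wraps forward() in the context manager,
# and the kernel later calls decode_state.get().forward(query.contiguous(), ...).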