Prefix caching WIP

Daniël de Kok 2024-08-09 11:47:14 +00:00 committed by Nicolas Patry
parent e4201f44cf
commit 44a77dcb9e
21 changed files with 308 additions and 53 deletions

View File

@@ -205,6 +205,7 @@ pub struct RadixTrie {
    /// call that a real time lookup would require.
    time: u64,
}

impl Default for RadixTrie {
    fn default() -> Self {
        Self::new()

View File

@@ -6,7 +6,12 @@ from .common import Seqlen
if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
    raise ImportError("`USE_FLASH_ATTENTION` is false.")
if SYSTEM == "cuda":
-    from .cuda import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
+    from .cuda import (
+        attention,
+        paged_attention,
+        reshape_and_cache,
+        SUPPORTS_WINDOWING,
+    )
elif SYSTEM == "rocm":
    from .rocm import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
elif SYSTEM == "ipex":

View File

@@ -221,9 +221,11 @@ SUPPORTS_WINDOWING = V2
if ATTENTION == "flashinfer":

    def attention(
-        q,
-        k,
-        v,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
        cu_seqlens,
        max_s,
        softmax_scale,
@@ -231,14 +233,15 @@ if ATTENTION == "flashinfer":
        causal=True,
        softcap=0.0,
    ):
-        from text_generation_server.layers.attention.flash_infer import prefill_state
+        assert window_size_left == -1, "Windowing is not supported with flash infer"
+        from text_generation_server.layers.attention.flash_infer import (
+            prefill_with_paged_kv_state,
+        )

-        return prefill_state.get().forward(
-            q,
-            k,
-            v,
+        return prefill_with_paged_kv_state.get().forward(
+            q.contiguous(),
            causal=causal,
-            window_left=window_size_left,
+            paged_kv_cache=(key_cache, value_cache),
            logits_soft_cap=softcap,
            sm_scale=softmax_scale,
        )
@@ -249,6 +252,8 @@ elif V2:
        q,
        k,
        v,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
        cu_seqlens,
        max_s,
        softmax_scale,
@@ -289,6 +294,8 @@ else:
        q,
        k,
        v,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
        cu_seqlens,
        max_s,
        softmax_scale,
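The flashinfer branch above does not receive its prefill wrapper as an argument; it fetches whatever state the caller installed into a ContextVar. The following self-contained sketch illustrates that pattern only; FakePrefillState is a hypothetical stand-in for flashinfer's BatchPrefillWithPagedKVCacheWrapper and is not part of this commit.

from contextlib import contextmanager
from contextvars import ContextVar


class FakePrefillState:
    """Stand-in for flashinfer.BatchPrefillWithPagedKVCacheWrapper."""

    def forward(self, q, paged_kv_cache=None, **kwargs):
        # A real wrapper would run paged prefill attention here.
        return f"attention over q={q} with cache={paged_kv_cache}"


prefill_with_paged_kv_state: ContextVar[FakePrefillState] = ContextVar(
    "prefill_with_paged_kv_state"
)


@contextmanager
def use_prefill_with_paged_kv_state(state: FakePrefillState):
    token = prefill_with_paged_kv_state.set(state)
    try:
        yield
    finally:
        prefill_with_paged_kv_state.reset(token)


def attention(q, key_cache, value_cache):
    # Same lookup as the flashinfer branch above: the active state comes from
    # the ContextVar instead of being threaded through every model's forward().
    return prefill_with_paged_kv_state.get().forward(
        q, paged_kv_cache=(key_cache, value_cache)
    )


with use_prefill_with_paged_kv_state(FakePrefillState()):
    print(attention("q", "k_cache", "v_cache"))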

View File

@@ -9,6 +9,10 @@ prefill_state: ContextVar[flashinfer.BatchPrefillWithRaggedKVCacheWrapper] = Con
    "prefill_state"
)

+prefill_with_paged_kv_state: ContextVar[
+    flashinfer.BatchPrefillWithPagedKVCacheWrapper
+] = ContextVar("prefill_with_paged_kv_state")
+
decode_state: ContextVar[flashinfer.BatchDecodeWithPagedKVCacheWrapper] = ContextVar(
    "decode_state"
)
@@ -24,6 +28,78 @@ def get_workspace(device):
    return workspace


+def create_prefill_with_paged_kv_state(
+    *,
+    device: torch.device,
+):
+    """Create a prefill state that uses the KV cache."""
+    workspace_buffer = get_workspace(device)
+    return flashinfer.BatchPrefillWithPagedKVCacheWrapper(
+        workspace_buffer, kv_layout="NHD", use_cuda_graph=False
+    )
+
+
+@contextmanager
+def use_prefill_with_paged_kv_state(
+    *,
+    state: flashinfer.BatchPrefillWithPagedKVCacheWrapper,
+    block_tables: torch.Tensor,
+    cu_seqlens: torch.Tensor,
+    input_lengths: torch.Tensor,
+    num_heads: int,
+    num_kv_heads: int,
+    head_size: int,
+    page_size: int,
+    query_dtype: str = "float16",
+):
+    """
+    Context manager to set the active flashinfer prefill state to the given
+    `state` and parameters. This state will be used by all calls to the
+    `attention` function while the context manager is active.
+    """
+    indptr = torch.zeros(
+        input_lengths.shape[0] + 1, device=input_lengths.device, dtype=torch.int32
+    )
+    # Round up to page size and then calculate the cumulative sum to get
+    # the indices into the block table.
+    torch.add(input_lengths, page_size - 1, out=indptr[1:])
+    indptr[1:].div_(page_size, rounding_mode="floor")
+    indptr[1:].cumsum_(-1)
+
+    # Get the lengths of the last page in a block.
+    if page_size == 1:
+        last_page_len = torch.ones(
+            input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device
+        )
+    else:
+        last_page_len = torch.empty(
+            input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device
+        )
+        torch.sub(input_lengths, 1, out=last_page_len)
+        last_page_len.remainder_(page_size)
+        last_page_len += 1
+
+    token = prefill_with_paged_kv_state.set(state)
+    try:
+        state.begin_forward(
+            qo_indptr=cu_seqlens,
+            paged_kv_indptr=indptr,
+            paged_kv_indices=block_tables,
+            paged_kv_last_page_len=last_page_len,
+            num_qo_heads=num_heads,
+            num_kv_heads=num_kv_heads,
+            head_dim=head_size,
+            q_data_type=query_dtype,
+            page_size=page_size,
+        )
+        yield
+    finally:
+        state.end_forward()
+        if token is not None:
+            prefill_with_paged_kv_state.reset(token)
+
+
def create_prefill_state(
    *,
    device: torch.device,
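To make the index arithmetic added above concrete, the short sketch below reproduces the `indptr` and `last_page_len` computation on CPU for example input lengths; the numbers are illustrative only.

import torch

page_size = 16
input_lengths = torch.tensor([5, 16, 40], dtype=torch.int32)

# Round each sequence length up to whole pages, then take the cumulative sum
# to get per-sequence offsets into the flattened block table.
indptr = torch.zeros(input_lengths.shape[0] + 1, dtype=torch.int32)
torch.add(input_lengths, page_size - 1, out=indptr[1:])
indptr[1:].div_(page_size, rounding_mode="floor")
indptr[1:].cumsum_(-1)
print(indptr)  # -> [0, 1, 2, 5]: 1, 1 and 3 pages per sequence

# Number of tokens stored in the last (possibly partial) page of each sequence.
last_page_len = (input_lengths - 1).remainder(page_size) + 1
print(last_page_len)  # -> [5, 16, 8]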

View File

@@ -298,6 +298,8 @@ class FlashCohereAttention(torch.nn.Module):
                query,
                key,
                value,
+                kv_cache[0],
+                kv_cache[1],
                cu_seqlen_prefill,
                max_s,
                self.softmax_scale,
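All of the model files below make the same change: the prefill attention call now also receives kv_cache[0] and kv_cache[1], the paged key and value caches, so the flashinfer path can read cached prefix tokens directly from them. As a rough illustration of the layout involved, the sketch below writes per-token keys into a paged cache using slot indices; the [num_blocks, block_size, num_kv_heads, head_size] shape is an assumption for illustration, not taken verbatim from this diff.

import torch

num_blocks, block_size, num_kv_heads, head_size = 4, 16, 2, 8
key_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_size)
value_cache = torch.zeros_like(key_cache)

# Three new tokens with their assigned slots (slot = block_id * block_size + offset).
keys = torch.randn(3, num_kv_heads, head_size)
values = torch.randn(3, num_kv_heads, head_size)
slots = torch.tensor([16, 17, 18])  # all land in block 1

block_ids = slots // block_size
offsets = slots % block_size
key_cache[block_ids, offsets] = keys
value_cache[block_ids, offsets] = values

# A prefix-cached request can later reuse block 1 without recomputing these keys.
assert torch.equal(key_cache[1, 0:3], keys)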

View File

@@ -337,6 +337,8 @@ class DbrxAttention(torch.nn.Module):
                query,
                torch.select(kv, dim=1, index=0),
                torch.select(kv, dim=1, index=1),
+                kv_cache[0],
+                kv_cache[1],
                cu_seqlen_prefill,
                max_s,
                self.softmax_scale,

View File

@@ -365,6 +365,8 @@ class DeepseekV2Attention(torch.nn.Module):
                query,
                key,
                value,
+                kv_cache[0],
+                kv_cache[1],
                cu_seqlen_prefill,
                max_s,
                self.softmax_scale,

View File

@@ -238,6 +238,8 @@ class FlashGemma2Attention(torch.nn.Module):
                query,
                torch.select(kv, dim=1, index=0),
                torch.select(kv, dim=1, index=1),
+                kv_cache[0],
+                kv_cache[1],
                cu_seqlen_prefill,
                max_s,
                self.softmax_scale,

View File

@@ -232,6 +232,8 @@ class FlashGemmaAttention(torch.nn.Module):
                query,
                torch.select(kv, dim=1, index=0),
                torch.select(kv, dim=1, index=1),
+                kv_cache[0],
+                kv_cache[1],
                cu_seqlen_prefill,
                max_s,
                self.softmax_scale,

View File

@@ -232,6 +232,8 @@ class FlashGPT2Attention(torch.nn.Module):
                query,
                key,
                value,
+                kv_cache[0],
+                kv_cache[1],
                cu_seqlen_prefill,
                max_s,
                self.softmax_scale,

View File

@@ -21,6 +21,7 @@
from contextlib import contextmanager
from typing import List, Optional, Tuple

+from loguru import logger
import torch
import torch.distributed
@@ -220,6 +221,8 @@ class FlashLlamaAttention(torch.nn.Module):
                query,
                torch.select(kv, dim=1, index=0),
                torch.select(kv, dim=1, index=1),
+                kv_cache[0],
+                kv_cache[1],
                cu_seqlen_prefill,
                max_s,
                self.softmax_scale,

View File

@@ -219,6 +219,8 @@ class MistralAttention(torch.nn.Module):
                query,
                torch.select(kv, dim=1, index=0),
                torch.select(kv, dim=1, index=1),
+                kv_cache[0],
+                kv_cache[1],
                cu_seqlen_prefill,
                max_s,
                self.softmax_scale,

View File

@@ -276,6 +276,8 @@ class MixtralAttention(torch.nn.Module):
                query,
                torch.select(kv, dim=1, index=0),
                torch.select(kv, dim=1, index=1),
+                kv_cache[0],
+                kv_cache[1],
                cu_seqlen_prefill,
                max_s,
                self.softmax_scale,

View File

@@ -173,6 +173,8 @@ class FlashNeoxAttention(torch.nn.Module):
                qkv[:, 0],
                qkv[:, 1],
                qkv[:, 2],
+                kv_cache[0],
+                kv_cache[1],
                cu_seqlen_prefill,
                max_s,
                self.softmax_scale,

View File

@@ -194,6 +194,8 @@ class FlashPhiAttention(torch.nn.Module):
                query,
                torch.select(kv, dim=1, index=0),
                torch.select(kv, dim=1, index=1),
+                kv_cache[0],
+                kv_cache[1],
                cu_seqlen_prefill,
                max_s,
                self.softmax_scale,

View File

@@ -137,6 +137,8 @@ class Qwen2Attention(torch.nn.Module):
                query,
                torch.select(kv, dim=1, index=0),
                torch.select(kv, dim=1, index=1),
+                kv_cache[0],
+                kv_cache[1],
                cu_seqlen_prefill,
                max_s,
                self.softmax_scale,

View File

@@ -208,6 +208,8 @@ class FlashRWAttention(torch.nn.Module):
                query,
                torch.select(kv, dim=1, index=0),
                torch.select(kv, dim=1, index=1),
+                kv_cache[0],
+                kv_cache[1],
                cu_seqlen_prefill,
                max_s,
                self.softmax_scale,
@@ -326,6 +328,8 @@ class FlashRWLargeAttention(torch.nn.Module):
                query,
                torch.select(kv, dim=2, index=0),
                torch.select(kv, dim=2, index=1),
+                kv_cache[0],
+                kv_cache[1],
                cu_seqlen_prefill,
                max_s,
                self.softmax_scale,

View File

@@ -293,6 +293,8 @@ class FlashMQAttention(torch.nn.Module):
                query,
                torch.select(key_value, dim=1, index=0),
                torch.select(key_value, dim=1, index=1),
+                kv_cache[0],
+                kv_cache[1],
                cu_seqlen_prefill,
                max_s,
                self.softmax_scale,

View File

@@ -242,6 +242,8 @@ class Starcoder2Attention(torch.nn.Module):
                query,
                torch.select(kv, dim=1, index=0),
                torch.select(kv, dim=1, index=1),
+                kv_cache[0],
+                kv_cache[1],
                cu_seqlen_prefill,
                max_s,
                self.softmax_scale,

View File

@@ -20,6 +20,9 @@ from typing import Any, ContextManager, Iterable, Optional, Tuple, List, Type, D
from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
+from text_generation_server.layers.attention.flash_infer import (
+    create_prefill_with_paged_kv_state,
+)
from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models import Model
@@ -43,6 +46,7 @@ from text_generation_server.models.globals import (
    ATTENTION,
    BLOCK_SIZE,
    CUDA_GRAPHS,
+    PREFIX_CACHING,
    get_adapter_to_index,
)
from text_generation_server.layers.attention import Seqlen
@@ -138,6 +142,9 @@ class FlashCausalLMBatch(Batch):
    block_tables_tensor: torch.Tensor
    # tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
    slots: torch.Tensor
+    # size [b], containing the number of blocks that can be retrieved from the cache
+    prefix_lens: List[int]
+    prefix_lens_tensor: torch.Tensor

    max_seqlen: int
@@ -146,6 +153,9 @@ class FlashCausalLMBatch(Batch):
    prefill_next_token_indices: Optional[torch.tensor]
    prefill_cu_outlens: Optional[List[int]]

+    # Prefixes
+    prefix_ids: List[List[int]]
+
    # All tokens
    all_input_ids: List[List[int]]
    all_input_ids_tensor: torch.Tensor
@@ -213,6 +223,7 @@ class FlashCausalLMBatch(Batch):
        prefix_offsets = []
        read_offsets = []
        all_input_ids = []
+        prefix_ids = []
        requests_idx_mapping = {}

        all_prefill_logprobs = True
@@ -230,7 +241,7 @@ class FlashCausalLMBatch(Batch):
        # Cumulative length
        cumulative_length = 0
-        cumulative_max_length = 0
+        cumulative_slot_tokens = 0
        prefill_out_cumulative_length = 0

        num_blocks = 0
@@ -240,6 +251,7 @@ class FlashCausalLMBatch(Batch):
        block_tables = []
        slots = []
+        prefix_lens = []

        # Parse batch
        for i, (r, tokenized_input) in enumerate(
@@ -255,6 +267,19 @@ class FlashCausalLMBatch(Batch):
        ):
                tokenized_input = tokenized_input[1:]

+            orig_input_length = len(tokenized_input)
+
+            if PREFIX_CACHING:
+                prefix_len = r.prefix_len
+                if prefix_len == orig_input_length:
+                    assert prefix_len > 0
+                    prefix_len -= 1
+            else:
+                prefix_len = 0
+
+            prefix_ids.append(tokenized_input[:prefix_len])
+            tokenized_input = tokenized_input[prefix_len:]
+
            input_length = len(tokenized_input)
            input_lengths.append(input_length)
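The hunk above is the core of prefix caching on the batch side: the router reports how many leading tokens (`r.prefix_len`) are already in the KV cache, the batch keeps those as `prefix_ids`, and only the remaining suffix runs through prefill. A small standalone sketch of that splitting rule, with a made-up `prefix_len` since the real value comes from the router:

PREFIX_CACHING = True


def split_prefix(tokenized_input: list, prefix_len: int):
    orig_input_length = len(tokenized_input)
    if PREFIX_CACHING:
        # Never cache the whole request: at least one token must go through
        # prefill so the model still produces logits for the next token.
        if prefix_len == orig_input_length:
            assert prefix_len > 0
            prefix_len -= 1
    else:
        prefix_len = 0
    prefix_ids = tokenized_input[:prefix_len]
    suffix = tokenized_input[prefix_len:]
    return prefix_ids, suffix


prefix_ids, suffix = split_prefix([101, 7592, 2088, 102], prefix_len=4)
print(len(prefix_ids), len(suffix))  # -> 3 1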
@@ -264,7 +289,9 @@ class FlashCausalLMBatch(Batch):
            all_input_ids.append(tokenized_input)

            # Position ids
-            request_position_ids = torch.arange(0, input_length, dtype=torch.int32)
+            request_position_ids = torch.arange(
+                prefix_len, orig_input_length, dtype=torch.int32
+            )
            position_ids.append(request_position_ids)

            # Add cumulative lengths of all previous inputs
@@ -288,11 +315,17 @@ class FlashCausalLMBatch(Batch):
            # Remove one as the first token des not have a past
            speculative_length = get_speculate()
            speculative_length = 0 if speculative_length is None else speculative_length
-            total_tokens = input_length + max_new_tokens - 1 + speculative_length
+
+            # Tokens that need to be mapped to blocks.
+            block_tokens = orig_input_length + max_new_tokens - 1 + speculative_length
+
+            # Tokens that need to be mapped to slots. We don't need slots for the
+            # cached prefix (if present).
+            slot_tokens = input_length + max_new_tokens - 1 + speculative_length

            # blocks and slots can be empty (for example in warmup)
            if not r.blocks:
-                needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
+                needed_blocks = math.ceil(block_tokens / BLOCK_SIZE)
                request_blocks = [
                    b for b in range(num_blocks, num_blocks + needed_blocks)
                ]
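The distinction above matters for allocation: blocks must cover the full sequence including the cached prefix, while slots are only needed for the tokens this request will actually write. A worked example with illustrative numbers:

import math

BLOCK_SIZE = 16
prefix_len = 20          # tokens already present in the cache
input_length = 12        # new prompt tokens that will be prefilled
max_new_tokens = 32
speculative_length = 0

orig_input_length = prefix_len + input_length  # 32

# Every token of the final sequence needs a block table entry...
block_tokens = orig_input_length + max_new_tokens - 1 + speculative_length  # 63
# ...but only non-cached tokens need a slot to write into.
slot_tokens = input_length + max_new_tokens - 1 + speculative_length        # 43

needed_blocks = math.ceil(block_tokens / BLOCK_SIZE)  # 4
print(block_tokens, slot_tokens, needed_blocks)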
@@ -303,16 +336,20 @@ class FlashCausalLMBatch(Batch):
                ]
            else:
                request_blocks = r.blocks
-                request_slots = r.slots
+                request_slots = r.slots[
+                    prefix_len:  #: orig_input_length + max_new_tokens + speculative_length
+                ]

            block_tables.append(request_blocks)
-            slots.extend(request_slots[:total_tokens])
+
+            slots.extend(request_slots)
+            prefix_lens.append(prefix_len)
            num_blocks += len(request_blocks)
-            start_slots.append(cumulative_max_length)
+            start_slots.append(cumulative_slot_tokens)

            request_slot_indices = torch.arange(
-                cumulative_max_length,
-                cumulative_max_length + input_length,
+                cumulative_slot_tokens,
+                cumulative_slot_tokens + input_length,
                dtype=torch.int64,
            )
            slot_indices.append(request_slot_indices)
@@ -348,7 +385,7 @@ class FlashCausalLMBatch(Batch):
            # Update
            cumulative_length += input_length
-            cumulative_max_length += total_tokens
+            cumulative_slot_tokens += slot_tokens
            max_seqlen = max(max_seqlen, input_length)
            max_blocks = max(max_blocks, len(request_blocks))
            max_length = max(
@@ -425,12 +462,14 @@ class FlashCausalLMBatch(Batch):
        )
        slots = torch.tensor(slots, dtype=torch.int64, device=device)

        block_tables_tensor = torch.zeros(
            (len(block_tables), max_blocks), dtype=torch.int32, device="cpu"
        )
        for i, request_blocks in enumerate(block_tables):
            block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)
        block_tables_tensor = block_tables_tensor.to(device)
+        prefix_lens_tensor = torch.tensor(prefix_lens, dtype=torch.int32, device=device)

        return cls(
            batch_id=pb.id,
@@ -445,6 +484,8 @@ class FlashCausalLMBatch(Batch):
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            slots=slots,
+            prefix_lens=prefix_lens,
+            prefix_lens_tensor=prefix_lens_tensor,
            max_seqlen=max_seqlen,
            prefill_head_indices=prefill_head_indices,
            prefill_next_token_indices=prefill_next_token_indices,
@@ -455,6 +496,7 @@ class FlashCausalLMBatch(Batch):
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
+            prefix_ids=prefix_ids,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
@@ -510,6 +552,7 @@ class FlashCausalLMBatch(Batch):
        start_slots = []
        block_tables = []
        all_input_ids = []
+        prefix_ids = []

        input_lengths = []
        prefix_offsets = []
@@ -536,6 +579,7 @@ class FlashCausalLMBatch(Batch):
            max_seqlen = max(max_seqlen, request_input_length)

            all_input_ids.append(self.all_input_ids[idx])
+            prefix_ids.append(self.prefix_ids[idx])

            input_lengths.append(request_input_length)
            prefix_offsets.append(self.prefix_offsets[idx])
@@ -621,6 +665,7 @@ class FlashCausalLMBatch(Batch):
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
+            prefix_ids=prefix_ids,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
@@ -681,6 +726,7 @@ class FlashCausalLMBatch(Batch):
        block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
            (total_batch_size, max_blocks)
        )
+        prefix_lens_tensor = batches[0].prefix_lens_tensor.new_empty(total_batch_size)
        all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
            (total_batch_size, max_length)
        )
@@ -698,7 +744,9 @@ class FlashCausalLMBatch(Batch):

        start_slots = []
        block_tables = []
+        prefix_lens = []
        all_input_ids = []
+        prefix_ids = []

        input_lengths = []
        prefix_offsets = []
@@ -760,10 +808,14 @@ class FlashCausalLMBatch(Batch):
                start_index:end_index, : batch.block_tables_tensor.shape[1]
            ] = batch.block_tables_tensor[:, :max_blocks]

+            prefix_lens_tensor[start_index:end_index] = batch.prefix_lens_tensor
+
            start_slots.append(batch.start_slots + cumulative_slots)

            block_tables.extend(batch.block_tables)
+            prefix_lens.extend(batch.prefix_lens)
            all_input_ids.extend(batch.all_input_ids)
+            prefix_ids.extend(batch.prefix_ids)

            input_lengths.extend(batch.input_lengths)
            prefix_offsets.extend(batch.prefix_offsets)
@@ -809,6 +861,8 @@ class FlashCausalLMBatch(Batch):
            slot_indices=slot_indices,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
+            prefix_lens=prefix_lens,
+            prefix_lens_tensor=prefix_lens_tensor,
            slots=slots,
            max_seqlen=max_seqlen,
            prefill_head_indices=None,
@@ -820,6 +874,7 @@ class FlashCausalLMBatch(Batch):
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
+            prefix_ids=prefix_ids,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
@@ -976,6 +1031,9 @@ class FlashCausalLM(Model):
            )

        self.prefill_state = create_prefill_state(device=device)
+        self.prefill_with_paged_kv_state = create_prefill_with_paged_kv_state(
+            device=device
+        )

        if not CUDA_GRAPHS:
            self.decode_state = create_decode_state(
@@ -1074,12 +1132,23 @@ class FlashCausalLM(Model):
        input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
        position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
        slots = torch.arange(bs, dtype=torch.int64, device=self.device)
-        input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
-        block_tables = (
-            torch.arange(max_bt, dtype=torch.int32, device=self.device)
-            .repeat(bs)
-            .reshape((bs, max_bt))
+        input_lengths = [max_s] * bs
+        prefix_lengths = [0] * bs
+        input_lengths_tensor = (
+            torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
        )
+        prefix_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device)
+        block_tables = torch.arange(
+            max_bt, dtype=torch.int32, device=self.device
+        ).repeat(bs)
+        block_tables = block_tables.reshape((bs, max_bt))
+
+        if ATTENTION == "flashinfer":
+            block_tables = block_tables_to_ragged(
+                block_tables=block_tables,
+                input_lengths=input_lengths,
+                prefix_lens=prefix_lengths,
+            )

        self.cuda_graphs[bs] = {
            "input_ids": input_ids,
@@ -1087,9 +1156,9 @@ class FlashCausalLM(Model):
            "kv_cache": self.kv_cache,
            "block_tables": block_tables,
            "slots": slots,
-            "input_lengths": input_lengths,
+            "input_lengths": input_lengths_tensor,
        }
-        input_lengths_ = Seqlen(input_lengths=input_lengths)
+        input_lengths_ = Seqlen(input_lengths=input_lengths_tensor)

        graph = torch.cuda.CUDAGraph()
        self.cuda_graphs[bs]["graph"] = graph
@@ -1104,7 +1173,7 @@ class FlashCausalLM(Model):
            last_page_len = torch.ones(bs, dtype=torch.int32, device=self.device)
            state = create_decode_state_cuda_graphs(
                device=input_ids.device,
-                block_tables=block_tables.view(-1),
+                block_tables=block_tables,
                block_tables_ptr=block_tables_ptr,
                last_page_len=last_page_len,
                num_heads=self.num_heads,
@@ -1120,7 +1189,10 @@ class FlashCausalLM(Model):
            block_tables=block_tables,
            cu_seqlen_prefill=None,
            input_lengths=input_lengths,
+            input_lengths_tensor=input_lengths_tensor,
            state=state,
+            prefix_lens=prefix_lengths,
+            prefix_lens_tensor=prefix_lengths_tensor,
        ):
            self.model.forward(
                input_ids=input_ids,
@@ -1138,7 +1210,7 @@ class FlashCausalLM(Model):
            torch.cuda.synchronize()

            with torch.cuda.graph(graph, pool=MEM_POOL):
-                input_lengths = Seqlen(input_lengths=input_lengths)
+                input_lengths_tensor = Seqlen(input_lengths=input_lengths_tensor)
                logits, speculative_logits = self.model.forward(
                    input_ids=input_ids,
                    position_ids=position_ids,
@@ -1146,7 +1218,7 @@ class FlashCausalLM(Model):
                    kv_cache=self.kv_cache,
                    block_tables=block_tables,
                    slots=slots,
-                    input_lengths=input_lengths,
+                    input_lengths=input_lengths_tensor,
                    max_s=max_s,
                    prefill_cache_indices=None,
                    lm_head_indices=None,
@@ -1375,7 +1447,10 @@ class FlashCausalLM(Model):
        with self._forward_context(
            block_tables=block_tables,
            cu_seqlen_prefill=cu_seqlen_prefill,
-            input_lengths=input_lengths,
+            input_lengths=batch.input_lengths,
+            input_lengths_tensor=input_lengths,
+            prefix_lens=batch.prefix_lens,
+            prefix_lens_tensor=batch.prefix_lens_tensor,
        ):
            input_lengths = Seqlen(input_lengths=input_lengths)
            logits, speculative_logits = self.model.forward(
@@ -1399,19 +1474,32 @@ class FlashCausalLM(Model):
        # Static inputs are potentially padded
        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
-        cuda_graph["block_tables"][
-            : block_tables.shape[0], : block_tables.shape[1]
-        ] = block_tables
+        if ATTENTION == "flashinfer":
+            block_tables = block_tables_to_ragged(
+                block_tables=block_tables,
+                input_lengths=batch.input_lengths,
+                prefix_lens=batch.prefix_lens,
+            )
+            cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
+        else:
+            cuda_graph["block_tables"][
+                : block_tables.shape[0], : block_tables.shape[1]
+            ] = block_tables
        cuda_graph["slots"].fill_(-1)
        cuda_graph["slots"][: slots.shape[0]] = slots
        cuda_graph["input_lengths"].zero_()
-        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
+        cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
+            input_lengths + batch.prefix_lens_tensor
+        )

        state = cuda_graph.get("state")
        with self._forward_context(
            block_tables=block_tables,
            cu_seqlen_prefill=None,
-            input_lengths=input_lengths,
+            input_lengths=batch.input_lengths,
+            input_lengths_tensor=input_lengths,
+            prefix_lens=batch.prefix_lens,
+            prefix_lens_tensor=batch.prefix_lens_tensor,
            state=state,
        ):
            # Replay the graph
@@ -1610,6 +1698,7 @@ class FlashCausalLM(Model):
            batch.read_offsets,
            batch.stopping_criterias,
            batch.all_input_ids,
+            batch.prefix_ids,
            batch.next_token_chooser.do_sample,
            batch.next_token_chooser.seeds,
            batch.top_n_tokens,
@@ -1627,6 +1716,7 @@ class FlashCausalLM(Model):
            read_offset,
            stopping_criteria,
            all_input_ids,
+            prefix_ids,
            do_sample,
            seed,
            top_n_tokens,
@@ -1701,18 +1791,18 @@ class FlashCausalLM(Model):
                out_end_index = batch.prefill_cu_outlens[i + 1]

                # Remove generated token to only have prefill and add nan for first prompt token
-                request_prefill_logprobs = [float("nan")] + prefill_logprobs[
-                    out_start_index : out_end_index - 1
-                ]
+                request_prefill_logprobs = (
+                    [float("nan")] * (len(prefix_ids) + 1)
+                ) + prefill_logprobs[out_start_index : out_end_index - 1]
                prefill_token_ids = all_input_ids[:-1]
                prefill_texts = self.tokenizer.batch_decode(
-                    prefill_token_ids,
+                    prefix_ids + prefill_token_ids,
                    clean_up_tokenization_spaces=False,
                    skip_special_tokens=False,
                )

                prefill_tokens = Tokens(
-                    prefill_token_ids,
+                    prefix_ids + prefill_token_ids,
                    request_prefill_logprobs,
                    prefill_texts,
                    is_special=[],
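With a cached prefix, the returned prefill tokens still cover the whole prompt: `prefix_ids` are prepended to the decoded token ids, and their logprobs are padded with NaN since no logits were computed for them. A minimal illustration with dummy values (the token ids and logprobs are made up):

prefix_ids = [101, 7592, 2088]     # cached prefix tokens (never ran through prefill)
prefill_token_ids = [1010, 2129]   # prompt tokens that did run through prefill
prefill_logprobs = [-0.7]          # logprobs for prefill tokens after the first

# NaN for every cached prefix token plus the usual NaN for the first prompt token.
request_prefill_logprobs = [float("nan")] * (len(prefix_ids) + 1) + prefill_logprobs

tokens_reported = prefix_ids + prefill_token_ids
print(tokens_reported)             # -> [101, 7592, 2088, 1010, 2129]
print(request_prefill_logprobs)    # -> [nan, nan, nan, nan, -0.7]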
@@ -1794,7 +1884,10 @@ class FlashCausalLM(Model):
        *,
        block_tables: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
-        input_lengths: torch.Tensor,
+        input_lengths: List[int],
+        input_lengths_tensor: torch.Tensor,
+        prefix_lens: List[int],
+        prefix_lens_tensor: torch.Tensor,
        state: Optional[Any] = None,
    ) -> ContextManager:
        if ATTENTION != "flashinfer":
@@ -1803,24 +1896,65 @@ class FlashCausalLM(Model):
        from text_generation_server.layers.attention.flash_infer import (
            use_decode_state,
            use_prefill_state,
+            use_prefill_with_paged_kv_state,
        )

+        # has_prefix_lens = any(prefix_len > 0 for prefix_len in prefix_lens)
+
        if cu_seqlen_prefill is not None:
-            return use_prefill_state(
-                state=state if state is not None else self.prefill_state,
-                cu_seqlens=cu_seqlen_prefill,
-                num_heads=self.num_heads,
-                num_kv_heads=self.num_kv_heads,
-                head_size=self.head_size,
-            )
+            if True:  # has_prefix_lens:
+                return use_prefill_with_paged_kv_state(
+                    state=(
+                        state if state is not None else self.prefill_with_paged_kv_state
+                    ),
+                    block_tables=block_tables_to_ragged(
+                        block_tables=block_tables,
+                        input_lengths=input_lengths,
+                        prefix_lens=prefix_lens,
+                    ),
+                    cu_seqlens=cu_seqlen_prefill,
+                    input_lengths=input_lengths_tensor + prefix_lens_tensor,
+                    num_heads=self.num_heads,
+                    num_kv_heads=self.num_kv_heads,
+                    head_size=self.head_size,
+                    page_size=BLOCK_SIZE,
+                )
+            else:
+                return use_prefill_state(
+                    state=state if state is not None else self.prefill_state,
+                    cu_seqlens=cu_seqlen_prefill,
+                    num_heads=self.num_heads,
+                    num_kv_heads=self.num_kv_heads,
+                    head_size=self.head_size,
+                )
        else:
-            assert input_lengths is not None
+            assert input_lengths_tensor is not None
            return use_decode_state(
                state=state if state is not None else self.decode_state,
-                input_lengths=input_lengths,
-                block_tables=block_tables.view(-1),
+                input_lengths=input_lengths_tensor + prefix_lens_tensor,
+                block_tables=block_tables,
                num_heads=self.num_heads,
                num_kv_heads=self.num_kv_heads,
                head_size=self.head_size,
                page_size=BLOCK_SIZE,
            )
+
+
+def block_tables_to_ragged(
+    *, block_tables: torch.Tensor, input_lengths: List[int], prefix_lens: List[int]
+) -> torch.Tensor:
+    """Convert block table to ragged format compatible with FlashInfer."""
+    assert len(input_lengths) == len(prefix_lens)
+
+    total_len = sum(input_lengths) + sum(prefix_lens)
+    block_tables_ragged = torch.empty(
+        total_len, dtype=torch.int32, device=block_tables.device
+    )
+
+    offset = 0
+    for i, (input_length, prefix_len) in enumerate(zip(input_lengths, prefix_lens)):
+        seq_len = prefix_len + input_length
+        block_tables_ragged[offset : offset + seq_len] = block_tables[i][:seq_len]
+        offset += seq_len
+
+    return block_tables_ragged
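A quick usage sketch for `block_tables_to_ragged` as added above. The helper is repeated here so the snippet runs standalone, and the example assumes one token per block (as in the flashinfer branch of globals.py below), so the length lists count both tokens and blocks; the actual numbers are illustrative.

from typing import List

import torch


def block_tables_to_ragged(
    *, block_tables: torch.Tensor, input_lengths: List[int], prefix_lens: List[int]
) -> torch.Tensor:
    """Copy of the helper above, repeated so this example runs standalone."""
    assert len(input_lengths) == len(prefix_lens)
    total_len = sum(input_lengths) + sum(prefix_lens)
    block_tables_ragged = torch.empty(
        total_len, dtype=torch.int32, device=block_tables.device
    )
    offset = 0
    for i, (input_length, prefix_len) in enumerate(zip(input_lengths, prefix_lens)):
        seq_len = prefix_len + input_length
        block_tables_ragged[offset : offset + seq_len] = block_tables[i][:seq_len]
        offset += seq_len
    return block_tables_ragged


# Two sequences with padded block tables of width 4; sequence 0 has a cached prefix of 1.
block_tables = torch.tensor([[7, 8, 9, 0], [3, 4, 0, 0]], dtype=torch.int32)
ragged = block_tables_to_ragged(
    block_tables=block_tables, input_lengths=[2, 2], prefix_lens=[1, 0]
)
print(ragged)  # -> [7, 8, 9, 3, 4]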

View File

@@ -29,7 +29,6 @@ elif ATTENTION == "flashinfer":
else:
    BLOCK_SIZE = 16

cuda_graphs = os.getenv("CUDA_GRAPHS")
if cuda_graphs is not None:
    try: