from typing import Optional
from contextvars import ContextVar
from contextlib import contextmanager

import flashinfer
import torch

prefill_state: ContextVar[flashinfer.BatchPrefillWithRaggedKVCacheWrapper] = ContextVar(
    "prefill_state"
)

decode_state: ContextVar[flashinfer.BatchDecodeWithPagedKVCacheWrapper] = ContextVar(
    "decode_state"
)

workspace: Optional[torch.Tensor] = None


def get_workspace(device):
    """Get shared flashinfer workspace."""
    global workspace
    if workspace is None:
        workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
    return workspace


def create_prefill_state(
    *,
    device: torch.device,
):
    """Create a prefill state."""
    workspace_buffer = get_workspace(device)
    return flashinfer.BatchPrefillWithRaggedKVCacheWrapper(
        workspace_buffer, kv_layout="NHD", use_cuda_graph=False
    )


@contextmanager
def use_prefill_state(
    *,
    state: flashinfer.BatchPrefillWithRaggedKVCacheWrapper,
    cu_seqlens: torch.Tensor,
    num_heads: int,
    num_kv_heads: int,
    head_size: int,
    query_dtype: str = "float16",
):
    """
    Context manager to set the active flashinfer prefill state to the given
    `state` and parameters. This state will be used by all calls to the
    `attention` function while the context manager is active.
    """

    token = prefill_state.set(state)
    try:
        state.begin_forward(
            qo_indptr=cu_seqlens,
            kv_indptr=cu_seqlens,
            num_qo_heads=num_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_size,
            q_data_type=query_dtype,
        )
        yield
    finally:
        state.end_forward()
        if token is not None:
            prefill_state.reset(token)


def create_decode_state(
    *,
    device: torch.device,
    num_heads: int,
    num_kv_heads: int,
):
    """Create a decode state."""
    workspace_buffer = get_workspace(device)
    return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
        workspace_buffer,
        kv_layout="NHD",
        use_cuda_graph=False,
        use_tensor_cores=num_heads // num_kv_heads > 4,
    )


def create_decode_state_cuda_graphs(
    *,
    device: torch.device,
    block_tables: torch.Tensor,
    block_tables_ptr: torch.Tensor,
    last_page_len: torch.Tensor,
    num_heads: int,
    num_kv_heads: int,
):
    """
    Create a decode state for use with CUDA Graphs. `block_tables`,
    `block_tables_ptr`, and `last_page_len` are used in CUDA Graphs and are
    therefore stored as part of the state.
    """
    workspace_buffer = get_workspace(device)
    return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
        workspace_buffer,
        kv_layout="NHD",
        use_cuda_graph=True,
        paged_kv_indices_buffer=block_tables,
        paged_kv_indptr_buffer=block_tables_ptr,
        paged_kv_last_page_len_buffer=last_page_len,
        use_tensor_cores=num_heads // num_kv_heads > 4,
    )


@contextmanager
def use_decode_state(
    *,
    state: flashinfer.BatchDecodeWithPagedKVCacheWrapper,
    input_lengths: torch.Tensor,
    block_tables: torch.Tensor,
    num_heads: int,
    num_kv_heads: int,
    head_size: int,
    page_size: int,
    query_dtype: str = "float16",
):
    """
    Context manager to set the active flashinfer decoding state to the given
    `state` and parameters. This state will be used by all calls to the
    `paged_attention` function while the context manager is active.
    """
    indptr = torch.zeros(
        input_lengths.shape[0] + 1, device=input_lengths.device, dtype=torch.int32
    )
    # Round up to page size and then calculate the cumulative sum to get
    # the indices into the block table.
    torch.add(input_lengths, page_size - 1, out=indptr[1:])
    indptr[1:].div_(page_size, rounding_mode="floor")
    indptr[1:].cumsum_(-1)

    # Get the lengths of the last page in a block.
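    # For example, with page_size=16 a 17-token sequence spans two pages and
    # its last page holds (17 - 1) % 16 + 1 = 1 token, while a 32-token
    # sequence ends with a full page of (32 - 1) % 16 + 1 = 16 tokens.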
    last_page_len = torch.empty(
        input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device
    )
    torch.sub(input_lengths, 1, out=last_page_len)
    last_page_len.remainder_(page_size)
    last_page_len += 1

    token = decode_state.set(state)
    try:
        state.begin_forward(
            indptr=indptr,
            indices=block_tables,
            last_page_len=last_page_len,
            num_qo_heads=num_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_size,
            page_size=page_size,
            q_data_type=query_dtype,
        )
        yield
    finally:
        state.end_forward()
        if token is not None:
            decode_state.reset(token)
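

# A minimal usage sketch, kept as a comment so this module stays import-only.
# The `attention`/`paged_attention` consumers named in the docstrings above
# are assumed to read the active wrapper via `prefill_state.get()` /
# `decode_state.get()`; the head counts and sizes below are illustrative
# assumptions, not values prescribed by this module.
#
#     # Prefill over a ragged batch described by `cu_seqlens`:
#     pf = create_prefill_state(device=device)
#     with use_prefill_state(
#         state=pf, cu_seqlens=cu_seqlens,
#         num_heads=32, num_kv_heads=8, head_size=128,
#     ):
#         ...  # calls to `attention` use prefill_state.get()
#
#     # Decode against a paged KV cache:
#     dec = create_decode_state(device=device, num_heads=32, num_kv_heads=8)
#     with use_decode_state(
#         state=dec,
#         input_lengths=input_lengths,  # int32, one entry per sequence
#         block_tables=block_tables,    # flattened page indices
#         num_heads=32, num_kv_heads=8, head_size=128, page_size=16,
#     ):
#         ...  # calls to `paged_attention` use decode_state.get()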