text-generation-inference/server/text_generation_server/layers/attention/flash_infer.py
Daniël de Kok 7830de1566
Add FlashInfer support (#2354)
This change adds support for FlashInfer. FlashInfer can be enabled using
`FLASH_INFER=1` and is currently only implemented in `FlashCausalLM`.
Since this functionality is currently only for testing, FlashInfer is
not installed anywhere yet.

The FlashInfer API is quite different from FlashAttention/vLLM in that
it requires more global bookkeeping:

* A wrapper class needs to be constructed (which we just call *state*).
  Since this is fairly expensive (due to pinned host memory allocation),
  we only do this once in a FlashCausalLM instance or for each CUDA
  Graph size.
* Each model forward call needs to be wrapped in `begin_forward` and
  `end_forward`. This sets up data structures that can be reused for all
  calls to attention for that forward call.

When calling attention, we need access to the state object. To avoid
passing an argument down the call chain (which would require changes to
all models), we use a context variable.

Each model forward call is wrapped using a context manager that does all
the bookkeeping for such a call:

* Set the context variable to the forward call's state.
* Call `begin_forward` on the state.
* Yield.
* Call `end_forward` on the state.
* Reset the context variable.

We cannot use a single shared global variable for this, since e.g. CUDA
Graphs of different sizes each have their own state.
2024-08-09 11:42:00 +02:00

165 lines
4.5 KiB
Python

from typing import Optional
from contextvars import ContextVar
from contextlib import contextmanager
import flashinfer
import torch
# Lazily allocated flashinfer workspace buffer. It stays `None` until
# `get_workspace` is called for the first time.
workspace: Optional[torch.Tensor] = None

# Flashinfer decode wrapper that is active for the current forward call.
# Set and reset by the `use_decode_state` context manager.
decode_state: ContextVar[flashinfer.BatchDecodeWithPagedKVCacheWrapper] = ContextVar(
    "decode_state"
)

# Flashinfer prefill wrapper that is active for the current forward call.
# Set and reset by the `use_prefill_state` context manager.
prefill_state: ContextVar[flashinfer.BatchPrefillWithRaggedKVCacheWrapper] = ContextVar(
    "prefill_state"
)
def get_workspace(device):
    """Return the shared flashinfer workspace, allocating it on first use.

    The workspace is a 128 MiB ``uint8`` tensor created on ``device`` the
    first time this function is called. Every later call returns that same
    tensor; ``device`` is not consulted again once the buffer exists.
    """
    global workspace
    if workspace is not None:
        return workspace

    size_in_bytes = 128 * 1024 * 1024
    workspace = torch.empty(size_in_bytes, dtype=torch.uint8, device=device)
    return workspace
def create_prefill_state(
    *,
    device: torch.device,
):
    """
    Build a flashinfer prefill state (a ragged-KV-cache prefill wrapper).

    The wrapper is backed by the shared workspace returned by
    `get_workspace(device)`, uses the "NHD" KV layout, and is created with
    CUDA Graph support disabled.
    """
    return flashinfer.BatchPrefillWithRaggedKVCacheWrapper(
        get_workspace(device),
        kv_layout="NHD",
        use_cuda_graph=False,
    )
@contextmanager
def use_prefill_state(
    *,
    state: flashinfer.BatchPrefillWithRaggedKVCacheWrapper,
    cu_seqlens: torch.Tensor,
    num_heads: int,
    num_kv_heads: int,
    head_size: int,
    query_dtype: str = "float16",
):
    """
    Context manager to set the active flashinfer prefill state to the given
    `state` and parameters. This state will be used by all calls to the
    `attention` function while the context manager is active.

    Args:
        state: The prefill wrapper to activate for the duration of the block.
        cu_seqlens: Passed to `begin_forward` as both `qo_indptr` and
            `kv_indptr` (presumably cumulative sequence lengths -- confirm
            with the caller).
        num_heads: Number of query/output heads.
        num_kv_heads: Number of key/value heads.
        head_size: Forwarded to `begin_forward` as `head_dim`.
        query_dtype: Forwarded to `begin_forward` as `q_data_type`.

    On exit, `state.end_forward()` is called and the `prefill_state` context
    variable is restored to its previous value. The context variable is
    restored even when `end_forward` itself raises.
    """
    # `ContextVar.set` always returns a token, so it can be used for the
    # reset unconditionally.
    token = prefill_state.set(state)
    try:
        state.begin_forward(
            qo_indptr=cu_seqlens,
            kv_indptr=cu_seqlens,
            num_qo_heads=num_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_size,
            q_data_type=query_dtype,
        )
        yield
    finally:
        # Nested try/finally: a failure in `end_forward` must not leave the
        # context variable pointing at this state.
        try:
            state.end_forward()
        finally:
            prefill_state.reset(token)
def create_decode_state(
    *,
    device: torch.device,
    num_heads: int,
    num_kv_heads: int,
):
    """
    Build a flashinfer decode state (a paged-KV-cache decode wrapper).

    The wrapper is backed by the shared workspace returned by
    `get_workspace(device)`, uses the "NHD" KV layout, and is created with
    CUDA Graph support disabled.
    """
    # Tensor cores are requested only when the integer ratio of query heads
    # to KV heads is greater than 4.
    heads_per_kv_head = num_heads // num_kv_heads
    wants_tensor_cores = heads_per_kv_head > 4

    return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
        get_workspace(device),
        kv_layout="NHD",
        use_cuda_graph=False,
        use_tensor_cores=wants_tensor_cores,
    )
def create_decode_state_cuda_graphs(
    *,
    device: torch.device,
    block_tables: torch.Tensor,
    block_tables_ptr: torch.Tensor,
    last_page_len: torch.Tensor,
    num_heads: int,
    num_kv_heads: int,
):
    """
    Build a flashinfer decode state that can be used with CUDA Graphs.

    Unlike `create_decode_state`, the wrapper is created with CUDA Graph
    support enabled and is handed the `block_tables`, `block_tables_ptr`,
    and `last_page_len` tensors up front. They are used in CUDA Graphs and
    are therefore stored as part of the state.
    """
    # Tensor cores are requested only when the integer ratio of query heads
    # to KV heads is greater than 4.
    heads_per_kv_head = num_heads // num_kv_heads
    wants_tensor_cores = heads_per_kv_head > 4

    return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
        get_workspace(device),
        kv_layout="NHD",
        use_cuda_graph=True,
        paged_kv_indices_buffer=block_tables,
        paged_kv_indptr_buffer=block_tables_ptr,
        paged_kv_last_page_len_buffer=last_page_len,
        use_tensor_cores=wants_tensor_cores,
    )
@contextmanager
def use_decode_state(
    *,
    state: flashinfer.BatchDecodeWithPagedKVCacheWrapper,
    input_lengths: torch.Tensor,
    block_tables: torch.Tensor,
    num_heads: int,
    num_kv_heads: int,
    head_size: int,
    page_size: int,
    query_dtype: str = "float16",
):
    """
    Context manager to set the active flashinfer decoding state to the given
    `state` and parameters. This state will be used by all calls to the
    `paged_attention` function while the context manager is active.

    Args:
        state: The decode wrapper to activate for the duration of the block.
        input_lengths: One length per sequence; used to derive the page
            index pointers and the length of each sequence's last page.
        block_tables: Forwarded to `begin_forward` as `indices`.
        num_heads: Number of query/output heads.
        num_kv_heads: Number of key/value heads.
        head_size: Forwarded to `begin_forward` as `head_dim`.
        page_size: Number of positions per page.
        query_dtype: Forwarded to `begin_forward` as `q_data_type`.

    On exit, `state.end_forward()` is called and the `decode_state` context
    variable is restored to its previous value. The context variable is
    restored even when `end_forward` itself raises.
    """
    # indptr[0] stays 0; indptr[i + 1] becomes the running total of pages
    # needed by sequences 0..i.
    indptr = torch.zeros(
        input_lengths.shape[0] + 1, device=input_lengths.device, dtype=torch.int32
    )
    # Round up to page size and then calculate the cumulative sum to get
    # the indices into the block table.
    torch.add(input_lengths, page_size - 1, out=indptr[1:])
    indptr[1:].div_(page_size, rounding_mode="floor")
    indptr[1:].cumsum_(-1)

    # Get the lengths of the last page in a block:
    # ((length - 1) % page_size) + 1, computed in place.
    last_page_len = torch.empty(
        input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device
    )
    torch.sub(input_lengths, 1, out=last_page_len)
    last_page_len.remainder_(page_size)
    last_page_len += 1

    # `ContextVar.set` always returns a token, so it can be used for the
    # reset unconditionally.
    token = decode_state.set(state)
    try:
        state.begin_forward(
            indptr=indptr,
            indices=block_tables,
            last_page_len=last_page_len,
            num_qo_heads=num_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_size,
            page_size=page_size,
            q_data_type=query_dtype,
        )
        yield
    finally:
        # Nested try/finally: a failure in `end_forward` must not leave the
        # context variable pointing at this state.
        try:
            state.end_forward()
        finally:
            decode_state.reset(token)