Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-04-23 16:02:10 +00:00
This change adds support for FlashInfer. FlashInfer can be enabled with `FLASH_INFER=1` and is currently only implemented in `FlashCausalLM`. Since this functionality is currently only for testing, FlashInfer is not installed anywhere yet.

The FlashInfer API is quite different from FlashAttention/vLLM in that it requires more global bookkeeping:

* A wrapper class needs to be constructed (which we just call *state*). Since this is fairly expensive (due to pinned host memory allocation), we only do this once per `FlashCausalLM` instance or for each CUDA Graph size.
* Each model forward call needs to be wrapped in `begin_forward` and `end_forward`. This sets up data structures that can be reused for all attention calls within that forward call.

When calling attention, we need access to the state object. To avoid passing an argument down the call chain (which would require changes to all models), we use a context variable. Each model forward call is wrapped in a context manager that does all the bookkeeping for such a call:

* Set the context variable to the forward call's state.
* Call `begin_forward` on the state.
* Yield.
* Call `end_forward` on the state.
* Reset the context variable.

We cannot use a single shared global variable for this, since e.g. CUDA Graphs of different sizes each have their own state.
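A rough sketch of the intended call pattern, using the helpers from the new module shown below; the model forward call, input tensors, and the concrete head counts are placeholders, not actual TGI code:

```python
import torch

device = torch.device("cuda")

# Built once per FlashCausalLM instance (or once per CUDA Graph size), since
# allocating the pinned-memory workspace is expensive.
state = create_prefill_state(device=device)

# Two sequences of lengths 3 and 5 -> cumulative sequence lengths [0, 3, 8].
cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32, device=device)

# Each forward call is wrapped so that begin_forward/end_forward bracket all
# attention calls for that batch and the context variable points at `state`.
with use_prefill_state(
    state=state,
    cu_seqlens=cu_seqlens,
    num_heads=32,
    num_kv_heads=8,
    head_size=128,
):
    hidden_states = model(input_ids, position_ids)  # placeholder forward call
```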
165 lines · 4.5 KiB · Python
from typing import Optional
from contextvars import ContextVar
from contextlib import contextmanager

import flashinfer
import torch

prefill_state: ContextVar[flashinfer.BatchPrefillWithRaggedKVCacheWrapper] = ContextVar(
    "prefill_state"
)

decode_state: ContextVar[flashinfer.BatchDecodeWithPagedKVCacheWrapper] = ContextVar(
    "decode_state"
)

workspace: Optional[torch.Tensor] = None


def get_workspace(device):
    """Get shared flashinfer workspace."""
    global workspace
    if workspace is None:
        workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
    return workspace


def create_prefill_state(
    *,
    device: torch.device,
):
    """Create a prefill state."""
    workspace_buffer = get_workspace(device)
    return flashinfer.BatchPrefillWithRaggedKVCacheWrapper(
        workspace_buffer, kv_layout="NHD", use_cuda_graph=False
    )


@contextmanager
def use_prefill_state(
    *,
    state: flashinfer.BatchPrefillWithRaggedKVCacheWrapper,
    cu_seqlens: torch.Tensor,
    num_heads: int,
    num_kv_heads: int,
    head_size: int,
    query_dtype: str = "float16",
):
    """
    Context manager to set the active flashinfer prefill state to the given
    `state` and parameters. This state will be used by all calls to the
    `attention` function while the context manager is active.
    """

    token = prefill_state.set(state)
    try:
        state.begin_forward(
            qo_indptr=cu_seqlens,
            kv_indptr=cu_seqlens,
            num_qo_heads=num_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_size,
            q_data_type=query_dtype,
        )
        yield
    finally:
        state.end_forward()
        if token is not None:
            prefill_state.reset(token)
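
# Usage sketch: while `use_prefill_state` is active, attention code can fetch
# the wrapper from the context variable instead of receiving it as an
# argument, roughly as follows (the exact wrapper call is assumed here; the
# real invocation lives in the attention layer, not in this module):
#
#     state = prefill_state.get()
#     attn_output = state.forward(query, key, value, causal=True)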


def create_decode_state(
    *,
    device: torch.device,
    num_heads: int,
    num_kv_heads: int,
):
    """Create a decode state."""
    workspace_buffer = get_workspace(device)
    return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
        workspace_buffer,
        kv_layout="NHD",
        use_cuda_graph=False,
        use_tensor_cores=num_heads // num_kv_heads > 4,
    )


def create_decode_state_cuda_graphs(
    *,
    device: torch.device,
    block_tables: torch.Tensor,
    block_tables_ptr: torch.Tensor,
    last_page_len: torch.Tensor,
    num_heads: int,
    num_kv_heads: int,
):
    """
    Create a decode state for use with CUDA Graphs. `block_tables`,
    `block_tables_ptr`, and `last_page_len` are used in CUDA Graphs and are
    therefore stored as part of the state.
    """
    workspace_buffer = get_workspace(device)
    return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
        workspace_buffer,
        kv_layout="NHD",
        use_cuda_graph=True,
        paged_kv_indices_buffer=block_tables,
        paged_kv_indptr_buffer=block_tables_ptr,
        paged_kv_last_page_len_buffer=last_page_len,
        use_tensor_cores=num_heads // num_kv_heads > 4,
    )
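
# Usage sketch (buffer shapes are illustrative, not taken from this file): for
# a CUDA Graph captured at batch size `bs`, the caller preallocates buffers
# that stay alive, at fixed addresses, for the lifetime of the graph, e.g.:
#
#     indptr = torch.zeros(bs + 1, dtype=torch.int32, device=device)
#     last_page_len = torch.ones(bs, dtype=torch.int32, device=device)
#     graph_state = create_decode_state_cuda_graphs(
#         device=device,
#         block_tables=block_tables.view(-1),
#         block_tables_ptr=indptr,
#         last_page_len=last_page_len,
#         num_heads=num_heads,
#         num_kv_heads=num_kv_heads,
#     )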


@contextmanager
def use_decode_state(
    *,
    state: flashinfer.BatchDecodeWithPagedKVCacheWrapper,
    input_lengths: torch.Tensor,
    block_tables: torch.Tensor,
    num_heads: int,
    num_kv_heads: int,
    head_size: int,
    page_size: int,
    query_dtype: str = "float16",
):
    """
    Context manager to set the active flashinfer decoding state to the given
    `state` and parameters. This state will be used by all calls to the
    `paged_attention` function while the context manager is active.
    """
    indptr = torch.zeros(
        input_lengths.shape[0] + 1, device=input_lengths.device, dtype=torch.int32
    )
    # Round up to page size and then calculate the cumulative sum to get
    # the indices into the block table.
    torch.add(input_lengths, page_size - 1, out=indptr[1:])
    indptr[1:].div_(page_size, rounding_mode="floor")
    indptr[1:].cumsum_(-1)

    # Get the lengths of the last page in a block.
    last_page_len = torch.empty(
        input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device
    )
    torch.sub(input_lengths, 1, out=last_page_len)
    last_page_len.remainder_(page_size)
    last_page_len += 1
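
    # Worked example (assumed values): with input_lengths = [5, 17] and
    # page_size = 16, the per-sequence page counts are [1, 2], so after the
    # cumulative sum indptr = [0, 1, 3]. The last pages hold
    # [(5 - 1) % 16 + 1, (17 - 1) % 16 + 1] = [5, 1] tokens.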

    token = decode_state.set(state)

    try:
        state.begin_forward(
            indptr=indptr,
            indices=block_tables,
            last_page_len=last_page_len,
            num_qo_heads=num_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_size,
            page_size=page_size,
            q_data_type=query_dtype,
        )
        yield
    finally:
        state.end_forward()
        if token is not None:
            decode_state.reset(token)