This change doesn't switch `forward` to `run` yet, since it requires that we have access to the softmax scale and the logit softcap outside the model.
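For reference, a minimal sketch of what that switch would involve, assuming flashinfer's newer API in which `plan()` accepts `sm_scale` and `logits_soft_cap` while `run()` does not (the value names below are hypothetical and would have to be threaded in from the model config):

    state.plan(
        # ... the existing paged-KV planning arguments ...
        sm_scale=softmax_scale,         # hypothetical: the model's softmax scale
        logits_soft_cap=logit_softcap,  # hypothetical: the model's logit softcap, if any
    )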
from typing import Optional

from contextvars import ContextVar
from contextlib import contextmanager

import flashinfer
import torch

prefill_state: ContextVar[flashinfer.BatchPrefillWithRaggedKVCacheWrapper] = ContextVar(
    "prefill_state"
)

prefill_with_paged_kv_state: ContextVar[
    flashinfer.BatchPrefillWithPagedKVCacheWrapper
] = ContextVar("prefill_with_paged_kv_state")

decode_state: ContextVar[flashinfer.BatchDecodeWithPagedKVCacheWrapper] = ContextVar(
    "decode_state"
)
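
# Note (descriptive comment, not in the original module): the context managers
# below publish a planned wrapper through these ContextVars; the attention and
# paged-attention call sites presumably retrieve the active wrapper with e.g.
# `decode_state.get()` while the corresponding context manager is active.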

workspace: Optional[torch.Tensor] = None


def get_workspace(device):
    """Get shared flashinfer workspace."""
    global workspace
    if workspace is None:
        workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
    return workspace


def create_prefill_with_paged_kv_state(
    *,
    device: torch.device,
):
    """Create a prefill state that uses the KV cache."""
    workspace_buffer = get_workspace(device)
    return flashinfer.BatchPrefillWithPagedKVCacheWrapper(
        workspace_buffer, kv_layout="NHD", use_cuda_graph=False
    )


@contextmanager
def use_prefill_with_paged_kv_state(
    *,
    state: flashinfer.BatchPrefillWithPagedKVCacheWrapper,
    block_tables: torch.Tensor,
    cu_seqlens: torch.Tensor,
    input_lengths: torch.Tensor,
    num_heads: int,
    num_kv_heads: int,
    head_size: int,
    page_size: int,
    kv_dtype: torch.dtype,
    q_dtype: torch.dtype,
    window_left: int,
):
    """
    Context manager to set the active flashinfer prefill state to the given
    `state` and parameters. This state will be used by all calls to the
    `attention` function while the context manager is active.
    """

    indptr = torch.zeros(
        input_lengths.shape[0] + 1, device=input_lengths.device, dtype=torch.int32
    )
    # Round up to page size and then calculate the cumulative sum to get
    # the indices into the block table.
    torch.add(input_lengths, page_size - 1, out=indptr[1:])
    indptr[1:].div_(page_size, rounding_mode="floor")
    indptr[1:].cumsum_(-1)
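    # Worked example (illustrative note, not in the original): with
    # page_size=16 and input_lengths=[5, 40], the add/floor-div above gives
    # [1, 3] pages per sequence; the in-place cumsum then yields
    # indptr=[0, 1, 4], i.e. sequence 0 uses block_tables[0:1] and
    # sequence 1 uses block_tables[1:4].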

    # Get the lengths of the last page in a block.
    if page_size == 1:
        last_page_len = torch.ones(
            input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device
        )
    else:
        last_page_len = torch.empty(
            input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device
        )
        torch.sub(input_lengths, 1, out=last_page_len)
        last_page_len.remainder_(page_size)
        last_page_len += 1

    token = prefill_with_paged_kv_state.set(state)
    try:
        state.plan(
            qo_indptr=cu_seqlens,
            paged_kv_indptr=indptr,
            paged_kv_indices=block_tables,
            paged_kv_last_page_len=last_page_len,
            num_qo_heads=num_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_size,
            kv_data_type=kv_dtype,
            q_data_type=q_dtype,
            page_size=page_size,
            window_left=-1 if window_left is None else window_left,
        )
        yield
    finally:
        if token is not None:
            prefill_with_paged_kv_state.reset(token)


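# Illustrative usage sketch (not part of the original module): how a caller
# might plan and run a paged-KV prefill pass. `model` and `model_inputs` are
# hypothetical stand-ins for the server's batching code; `plan_args` carries
# the keyword arguments of the context manager above (block_tables,
# cu_seqlens, input_lengths, head/page geometry, dtypes, window_left).
def _example_prefill_forward(model, device, model_inputs, **plan_args):
    state = create_prefill_with_paged_kv_state(device=device)
    with use_prefill_with_paged_kv_state(state=state, **plan_args):
        # While active, attention layers can fetch the planned wrapper via
        # `prefill_with_paged_kv_state.get()`.
        return model(**model_inputs)

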
def create_prefill_state(
    *,
    device: torch.device,
):
    """Create a prefill state."""
    workspace_buffer = get_workspace(device)
    return flashinfer.BatchPrefillWithRaggedKVCacheWrapper(
        workspace_buffer, kv_layout="NHD", use_cuda_graph=False
    )


def create_decode_state(
    *,
    device: torch.device,
    num_heads: int,
    num_kv_heads: int,
):
    """Create a decode state."""
    workspace_buffer = get_workspace(device)
    num_groups = num_heads // num_kv_heads
    return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
        workspace_buffer,
        kv_layout="NHD",
        use_cuda_graph=False,
        # Taken from https://github.com/flashinfer-ai/flashinfer/blob/33ef95700981ba70f4cab63b8931e562bc795b21/python/flashinfer/decode.py#L57-L60
        use_tensor_cores=num_groups not in [1, 2, 4, 8],
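        # (Descriptive note, added: per the linked heuristic, the tensor-core
        # decode path is enabled only for GQA group sizes that the plain
        # CUDA-core decode kernel is not specialized for.)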
    )


def create_decode_state_cuda_graphs(
    *,
    device: torch.device,
    block_tables: torch.Tensor,
    block_tables_ptr: torch.Tensor,
    last_page_len: torch.Tensor,
    num_heads: int,
    num_kv_heads: int,
):
    """
    Create a decode state for use with CUDA Graphs. `block_tables`,
    `block_tables_ptr`, and `last_page_len` are used in CUDA Graphs and are
    therefore stored as part of the state.
    """
    workspace_buffer = get_workspace(device)
    num_groups = num_heads // num_kv_heads
    return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
        workspace_buffer,
        kv_layout="NHD",
        use_cuda_graph=True,
        paged_kv_indices_buffer=block_tables,
        paged_kv_indptr_buffer=block_tables_ptr,
        paged_kv_last_page_len_buffer=last_page_len,
        # Taken from https://github.com/flashinfer-ai/flashinfer/blob/33ef95700981ba70f4cab63b8931e562bc795b21/python/flashinfer/decode.py#L57-L60
        use_tensor_cores=num_groups not in [1, 2, 4, 8],
    )


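# Illustrative sketch (added note; names and shapes are assumptions, not from
# the original module): for CUDA graph capture these buffers are typically
# pre-allocated once at maximum capacity and then updated in place on every
# decode step, e.g.:
#
#     block_tables = torch.empty(max_pages, dtype=torch.int32, device=device)
#     block_tables_ptr = torch.empty(max_batch_size + 1, dtype=torch.int32, device=device)
#     last_page_len = torch.empty(max_batch_size, dtype=torch.int32, device=device)
#     state = create_decode_state_cuda_graphs(
#         device=device,
#         block_tables=block_tables,
#         block_tables_ptr=block_tables_ptr,
#         last_page_len=last_page_len,
#         num_heads=num_heads,
#         num_kv_heads=num_kv_heads,
#     )

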
@contextmanager
def use_decode_state(
    *,
    state: flashinfer.BatchDecodeWithPagedKVCacheWrapper,
    input_lengths: torch.Tensor,
    block_tables: torch.Tensor,
    num_heads: int,
    num_kv_heads: int,
    head_size: int,
    page_size: int,
    kv_cache_dtype: torch.dtype,
    q_dtype: torch.dtype,
    window_left: int,
):
    """
    Context manager to set the active flashinfer decoding state to the given
    `state` and parameters. This state will be used by all calls to the
    `paged_attention` function while the context manager is active.
    """
    indptr = torch.zeros(
        input_lengths.shape[0] + 1, device=input_lengths.device, dtype=torch.int32
    )
    # Round up to page size and then calculate the cumulative sum to get
    # the indices into the block table.
    torch.add(input_lengths, page_size - 1, out=indptr[1:])
    indptr[1:].div_(page_size, rounding_mode="floor")
    indptr[1:].cumsum_(-1)

    # Get the lengths of the last page in a block.
    last_page_len = torch.empty(
        input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device
    )
    torch.sub(input_lengths, 1, out=last_page_len)
    last_page_len.remainder_(page_size)
    last_page_len += 1

    token = decode_state.set(state)

    try:
        state.plan(
            indptr=indptr,
            indices=block_tables,
            last_page_len=last_page_len,
            num_qo_heads=num_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_size,
            page_size=page_size,
            data_type=kv_cache_dtype,
            q_data_type=q_dtype,
            window_left=-1 if window_left is None else window_left,
        )
        yield
    finally:
        if token is not None:
            decode_state.reset(token)
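

# Illustrative usage sketch (added note, not part of the original module): the
# decode path mirrors the prefill one. `model` and `model_inputs` are
# hypothetical stand-ins; `plan_args` carries the keyword arguments of
# `use_decode_state` above (input_lengths, block_tables, head/page geometry,
# dtypes, window_left).
def _example_decode_forward(model, decode_wrapper, model_inputs, **plan_args):
    with use_decode_state(state=decode_wrapper, **plan_args):
        # While active, `paged_attention` can fetch the planned wrapper via
        # `decode_state.get()`.
        return model(**model_inputs)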