Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-19 22:02:06 +00:00)
Prefix caching (#2402)
* Prefix caching WIP
* Fixing prefix attention.
* Fixing flashinfer import.
* Fixing black.
* Fixing medusa (still wrong outputs, but functional).
* Just medusa values now.
* Fixing medusa without prefix caching.
* Fixing prefix caching.
* Medusa requires reshaping.
* Removing the logs.
* Remove router.nix
* Fixup:
  - Remove logs
  - Disable VLMs (they do not work)
  - Disable prefix caching when user wants prefill logprobs.
* Update flake.lock

Co-authored-by: Daniël de Kok <me@danieldk.eu>
Parent: 38773453ae
Commit: b70ae0969f
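The idea behind the commit, stated outside the diff: during prefill, tokens whose key/value entries are already cached (the shared prefix) are skipped, so queries are built only for the remaining suffix while keys and values still cover the whole sequence. A minimal standalone sketch of that attention pattern follows, assuming plain PyTorch, toy sizes, and no paged KV cache or flashinfer; none of this code is part of the commit, which delegates the real work to flashinfer's BatchPrefillWithPagedKVCacheWrapper as the hunks below show.

# Sketch only: queries cover the non-cached suffix, keys/values cover prefix + suffix.
import math
import torch

torch.manual_seed(0)
head_dim, prefix_len, suffix_len = 64, 16, 8
total_len = prefix_len + suffix_len

k = torch.randn(total_len, head_dim)   # cached prefix K/V plus the new tokens
v = torch.randn(total_len, head_dim)
q = torch.randn(suffix_len, head_dim)  # queries only for the non-cached suffix

scores = q @ k.T / math.sqrt(head_dim)

# Causal mask: suffix position i (absolute position prefix_len + i) may attend
# to every cached prefix token and to suffix tokens up to itself.
abs_pos = torch.arange(prefix_len, total_len).unsqueeze(1)
key_pos = torch.arange(total_len).unsqueeze(0)
scores = scores.masked_fill(key_pos > abs_pos, float("-inf"))

out = torch.softmax(scores, dim=-1) @ v
print(out.shape)  # torch.Size([8, 64]) -- one output per suffix token only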
@@ -316,10 +316,15 @@ impl State {
                     + self.speculate
                     - 1;

-                match block_allocator
-                    .allocate(tokens, entry.request.input_ids.clone())
-                    .await
-                {
+                // If users wants the prefill logprobs, we cannot reuse the cache.
+                // So no input_ids for the radix tree.
+                let input_ids = if entry.request.decoder_input_details {
+                    None
+                } else {
+                    entry.request.input_ids.clone()
+                };
+
+                match block_allocator.allocate(tokens, input_ids).await {
                     None => {
                         // Entry is over budget
                         // Add it back to the front
@@ -205,6 +205,7 @@ pub struct RadixTrie {
     /// call that a real time lookup would require.
     time: u64,
 }

 impl Default for RadixTrie {
     fn default() -> Self {
         Self::new()
@@ -900,11 +900,11 @@
       ]
     },
     "locked": {
-      "lastModified": 1723515680,
-      "narHash": "sha256-nHdKymsHCVIh0Wdm4MvSgxcTTg34FJIYHRQkQYaSuvk=",
+      "lastModified": 1723602049,
+      "narHash": "sha256-Z/noCSn9WPkv7O77dWKLcBxe4Ub4bWyNzsL5JhjaQfw=",
       "owner": "oxalica",
       "repo": "rust-overlay",
-      "rev": "4ee3d9e9569f70d7bb40f28804d6fe950c81eab3",
+      "rev": "ea0bf33a11a26a62c60123c49d96011da396602c",
       "type": "github"
     },
     "original": {
@@ -84,6 +84,7 @@
       grpcio-status
       grpcio-tools
       hf-transfer
+      ipdb
       loguru
       mamba-ssm
       marlin-kernels
@@ -6,7 +6,12 @@ from .common import Seqlen
 if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
     raise ImportError("`USE_FLASH_ATTENTION` is false.")
 if SYSTEM == "cuda":
-    from .cuda import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
+    from .cuda import (
+        attention,
+        paged_attention,
+        reshape_and_cache,
+        SUPPORTS_WINDOWING,
+    )
 elif SYSTEM == "rocm":
     from .rocm import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
 elif SYSTEM == "ipex":
@@ -76,7 +76,7 @@ def paged_attention(
     # sequences or heads is large, we use V1 since there is enough work
     # to parallelize.
     if ATTENTION == "flashinfer":
-        from text_generation_server.layers.attention.flash_infer import decode_state
+        from text_generation_server.layers.attention.flashinfer import decode_state

         return decode_state.get().forward(
             query.contiguous(),
@@ -221,9 +221,11 @@ SUPPORTS_WINDOWING = V2
 if ATTENTION == "flashinfer":

     def attention(
-        q,
-        k,
-        v,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
         cu_seqlens,
         max_s,
         softmax_scale,

@@ -231,14 +233,15 @@ if ATTENTION == "flashinfer":
         causal=True,
         softcap=0.0,
     ):
-        from text_generation_server.layers.attention.flash_infer import prefill_state
+        assert window_size_left == -1, "Windowing is not supported with flash infer"
+        from text_generation_server.layers.attention.flashinfer import (
+            prefill_with_paged_kv_state,
+        )

-        return prefill_state.get().forward(
-            q,
-            k,
-            v,
+        return prefill_with_paged_kv_state.get().forward(
+            q.contiguous(),
             causal=causal,
-            window_left=window_size_left,
+            paged_kv_cache=(key_cache, value_cache),
             logits_soft_cap=softcap,
             sm_scale=softmax_scale,
         )
@@ -249,6 +252,8 @@ elif V2:
         q,
         k,
         v,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
         cu_seqlens,
         max_s,
         softmax_scale,

@@ -289,6 +294,8 @@ else:
         q,
         k,
         v,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
         cu_seqlens,
         max_s,
         softmax_scale,
@@ -9,6 +9,10 @@ prefill_state: ContextVar[flashinfer.BatchPrefillWithRaggedKVCacheWrapper] = Con
     "prefill_state"
 )

+prefill_with_paged_kv_state: ContextVar[
+    flashinfer.BatchPrefillWithPagedKVCacheWrapper
+] = ContextVar("prefill_with_paged_kv_state")
+
 decode_state: ContextVar[flashinfer.BatchDecodeWithPagedKVCacheWrapper] = ContextVar(
     "decode_state"
 )
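The new prefill_with_paged_kv_state variable is managed with the usual ContextVar set / try / finally / reset pattern (see use_prefill_with_paged_kv_state in the next hunk). A minimal sketch of that pattern using only the standard library; DummyState is a hypothetical stand-in for the flashinfer wrapper and is not part of the commit.

from contextlib import contextmanager
from contextvars import ContextVar

prefill_state: ContextVar["DummyState"] = ContextVar("prefill_state")

class DummyState:
    def begin_forward(self):
        print("begin_forward")
    def end_forward(self):
        print("end_forward")

@contextmanager
def use_state(state: DummyState):
    # Publish the state, make sure it is torn down and unpublished afterwards.
    token = prefill_state.set(state)
    try:
        state.begin_forward()
        yield
    finally:
        state.end_forward()
        if token is not None:
            prefill_state.reset(token)

with use_state(DummyState()):
    # Inside the block, attention code can fetch the active state.
    assert isinstance(prefill_state.get(), DummyState)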
@@ -24,6 +28,78 @@ def get_workspace(device):
     return workspace


+def create_prefill_with_paged_kv_state(
+    *,
+    device: torch.device,
+):
+    """Create a prefill state that uses the KV cache."""
+    workspace_buffer = get_workspace(device)
+    return flashinfer.BatchPrefillWithPagedKVCacheWrapper(
+        workspace_buffer, kv_layout="NHD", use_cuda_graph=False
+    )
+
+
+@contextmanager
+def use_prefill_with_paged_kv_state(
+    *,
+    state: flashinfer.BatchPrefillWithPagedKVCacheWrapper,
+    block_tables: torch.Tensor,
+    cu_seqlens: torch.Tensor,
+    input_lengths: torch.Tensor,
+    num_heads: int,
+    num_kv_heads: int,
+    head_size: int,
+    page_size: int,
+    query_dtype: str = "float16",
+):
+    """
+    Context manager to set the active flashinfer prefill state to the given
+    `state` and parameters. This state will be used by all calls to the
+    `attention` function while the context manager is active.
+    """
+
+    indptr = torch.zeros(
+        input_lengths.shape[0] + 1, device=input_lengths.device, dtype=torch.int32
+    )
+    # Round up to page size and then calculate the cumulative sum to get
+    # the indices into the block table.
+    torch.add(input_lengths, page_size - 1, out=indptr[1:])
+    indptr[1:].div_(page_size, rounding_mode="floor")
+    indptr[1:].cumsum_(-1)
+
+    # Get the lengths of the last page in a block.
+    if page_size == 1:
+        last_page_len = torch.ones(
+            input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device
+        )
+    else:
+        last_page_len = torch.empty(
+            input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device
+        )
+        torch.sub(input_lengths, 1, out=last_page_len)
+        last_page_len.remainder_(page_size)
+        last_page_len += 1
+
+    token = prefill_with_paged_kv_state.set(state)
+    try:
+        state.begin_forward(
+            qo_indptr=cu_seqlens,
+            paged_kv_indptr=indptr,
+            paged_kv_indices=block_tables,
+            paged_kv_last_page_len=last_page_len,
+            num_qo_heads=num_heads,
+            num_kv_heads=num_kv_heads,
+            head_dim=head_size,
+            q_data_type=query_dtype,
+            page_size=page_size,
+        )
+        yield
+    finally:
+        state.end_forward()
+        if token is not None:
+            prefill_with_paged_kv_state.reset(token)
+
+
 def create_prefill_state(
     *,
     device: torch.device,
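The indptr and last_page_len bookkeeping in use_prefill_with_paged_kv_state can be checked in isolation. A minimal sketch with made-up lengths and a page size of 16 (not part of the commit):

import torch

page_size = 16
input_lengths = torch.tensor([5, 16, 17], dtype=torch.int32)

indptr = torch.zeros(input_lengths.shape[0] + 1, dtype=torch.int32)
# ceil(len / page_size), then a cumulative sum -> offsets into the block table.
torch.add(input_lengths, page_size - 1, out=indptr[1:])
indptr[1:].div_(page_size, rounding_mode="floor")
indptr[1:].cumsum_(-1)

# Number of KV entries used on the last page of each sequence.
last_page_len = (input_lengths - 1).remainder(page_size) + 1

print(indptr.tolist())         # [0, 1, 2, 4]
print(last_page_len.tolist())  # [5, 16, 1]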
@@ -32,6 +32,8 @@ class MedusaModel(torch.nn.Module):
         )

     def forward(self, x):
+        if not self.heads:
+            return None
         speculative_logits = torch.stack([head(x) for head in self.heads], dim=1)
         return speculative_logits

@@ -298,6 +298,8 @@ class FlashCohereAttention(torch.nn.Module):
                 query,
                 key,
                 value,
+                kv_cache[0],
+                kv_cache[1],
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,

@@ -337,6 +337,8 @@ class DbrxAttention(torch.nn.Module):
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
+                kv_cache[0],
+                kv_cache[1],
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,

@@ -365,6 +365,8 @@ class DeepseekV2Attention(torch.nn.Module):
                 query,
                 key,
                 value,
+                kv_cache[0],
+                kv_cache[1],
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,

@@ -238,6 +238,8 @@ class FlashGemma2Attention(torch.nn.Module):
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
+                kv_cache[0],
+                kv_cache[1],
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,

@@ -232,6 +232,8 @@ class FlashGemmaAttention(torch.nn.Module):
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
+                kv_cache[0],
+                kv_cache[1],
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,

@@ -232,6 +232,8 @@ class FlashGPT2Attention(torch.nn.Module):
                 query,
                 key,
                 value,
+                kv_cache[0],
+                kv_cache[1],
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,

@@ -220,6 +220,8 @@ class FlashLlamaAttention(torch.nn.Module):
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
+                kv_cache[0],
+                kv_cache[1],
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,

@@ -219,6 +219,8 @@ class MistralAttention(torch.nn.Module):
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
+                kv_cache[0],
+                kv_cache[1],
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
@@ -276,6 +276,8 @@ class MixtralAttention(torch.nn.Module):
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
+                kv_cache[0],
+                kv_cache[1],
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,

@@ -173,6 +173,8 @@ class FlashNeoxAttention(torch.nn.Module):
                 qkv[:, 0],
                 qkv[:, 1],
                 qkv[:, 2],
+                kv_cache[0],
+                kv_cache[1],
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,

@@ -194,6 +194,8 @@ class FlashPhiAttention(torch.nn.Module):
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
+                kv_cache[0],
+                kv_cache[1],
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,

@@ -137,6 +137,8 @@ class Qwen2Attention(torch.nn.Module):
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
+                kv_cache[0],
+                kv_cache[1],
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,

@@ -208,6 +208,8 @@ class FlashRWAttention(torch.nn.Module):
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
+                kv_cache[0],
+                kv_cache[1],
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,

@@ -326,6 +328,8 @@ class FlashRWLargeAttention(torch.nn.Module):
                 query,
                 torch.select(kv, dim=2, index=0),
                 torch.select(kv, dim=2, index=1),
+                kv_cache[0],
+                kv_cache[1],
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,

@@ -293,6 +293,8 @@ class FlashMQAttention(torch.nn.Module):
                 query,
                 torch.select(key_value, dim=1, index=0),
                 torch.select(key_value, dim=1, index=1),
+                kv_cache[0],
+                kv_cache[1],
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,

@@ -242,6 +242,8 @@ class Starcoder2Attention(torch.nn.Module):
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
+                kv_cache[0],
+                kv_cache[1],
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
@@ -43,6 +43,7 @@ from text_generation_server.models.globals import (
     ATTENTION,
     BLOCK_SIZE,
     CUDA_GRAPHS,
+    PREFIX_CACHING,
     get_adapter_to_index,
 )
 from text_generation_server.layers.attention import Seqlen
@@ -138,6 +139,9 @@ class FlashCausalLMBatch(Batch):
     block_tables_tensor: torch.Tensor
     # tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
     slots: torch.Tensor
+    # size [b], containing the number of blocks that can be retrieved from the cache
+    prefix_lens: List[int]
+    prefix_lens_tensor: torch.Tensor

     max_seqlen: int


@@ -146,6 +150,9 @@ class FlashCausalLMBatch(Batch):
     prefill_next_token_indices: Optional[torch.tensor]
     prefill_cu_outlens: Optional[List[int]]

+    # Prefixes
+    prefix_ids: List[List[int]]
+
     # All tokens
     all_input_ids: List[List[int]]
     all_input_ids_tensor: torch.Tensor
@@ -213,6 +220,7 @@ class FlashCausalLMBatch(Batch):
         prefix_offsets = []
         read_offsets = []
         all_input_ids = []
+        prefix_ids = []
         requests_idx_mapping = {}

         all_prefill_logprobs = True

@@ -230,7 +238,7 @@ class FlashCausalLMBatch(Batch):

         # Cumulative length
         cumulative_length = 0
-        cumulative_max_length = 0
+        cumulative_slot_tokens = 0
         prefill_out_cumulative_length = 0

         num_blocks = 0

@@ -240,6 +248,7 @@ class FlashCausalLMBatch(Batch):

         block_tables = []
         slots = []
+        prefix_lens = []

         # Parse batch
         for i, (r, tokenized_input) in enumerate(
@@ -255,6 +264,19 @@ class FlashCausalLMBatch(Batch):
             ):
                 tokenized_input = tokenized_input[1:]

+            orig_input_length = len(tokenized_input)
+
+            if PREFIX_CACHING:
+                prefix_len = r.prefix_len
+                if prefix_len == orig_input_length:
+                    assert prefix_len > 0
+                    prefix_len -= 1
+            else:
+                prefix_len = 0
+
+            prefix_ids.append(tokenized_input[:prefix_len])
+            tokenized_input = tokenized_input[prefix_len:]
+
             input_length = len(tokenized_input)
             input_lengths.append(input_length)

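In plain list terms, the per-request trimming added above splits the tokenized input into a cached prefix and the suffix that still has to be prefilled, never caching the entire input so at least one token remains to produce logits. A hedged sketch with made-up token ids; the helper name is invented for illustration:

def split_prefix(tokenized_input, prefix_len, prefix_caching=True):
    if prefix_caching:
        # Never cache the whole input: keep at least one token to prefill,
        # otherwise there would be nothing to compute logits from.
        if prefix_len == len(tokenized_input):
            assert prefix_len > 0
            prefix_len -= 1
    else:
        prefix_len = 0
    prefix_ids = tokenized_input[:prefix_len]
    suffix_ids = tokenized_input[prefix_len:]
    return prefix_ids, suffix_ids

print(split_prefix([10, 11, 12, 13, 14], prefix_len=3))  # ([10, 11, 12], [13, 14])
print(split_prefix([10, 11, 12, 13, 14], prefix_len=5))  # ([10, 11, 12, 13], [14])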
@@ -264,7 +286,9 @@ class FlashCausalLMBatch(Batch):
             all_input_ids.append(tokenized_input)

             # Position ids
-            request_position_ids = torch.arange(0, input_length, dtype=torch.int32)
+            request_position_ids = torch.arange(
+                prefix_len, orig_input_length, dtype=torch.int32
+            )
             position_ids.append(request_position_ids)

             # Add cumulative lengths of all previous inputs
@@ -288,11 +312,17 @@ class FlashCausalLMBatch(Batch):
             # Remove one as the first token des not have a past
             speculative_length = get_speculate()
             speculative_length = 0 if speculative_length is None else speculative_length
-            total_tokens = input_length + max_new_tokens - 1 + speculative_length

+            # Tokens that need to be mapped to blocks.
+            block_tokens = orig_input_length + max_new_tokens - 1 + speculative_length
+
+            # Tokens that need to be mapped to slots. We don't need slots for the
+            # cached prefix (if present).
+            slot_tokens = input_length + max_new_tokens - 1 + speculative_length
+
             # blocks and slots can be empty (for example in warmup)
             if not r.blocks:
-                needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
+                needed_blocks = math.ceil(block_tokens / BLOCK_SIZE)
                 request_blocks = [
                     b for b in range(num_blocks, num_blocks + needed_blocks)
                 ]
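A quick worked example of the two counts introduced above, with assumed numbers (BLOCK_SIZE 16, no speculation): blocks must cover the whole sequence including the cached prefix, while slots are only needed for tokens that still get prefilled or generated.

import math

BLOCK_SIZE = 16
orig_input_length = 35      # full prompt length
prefix_len = 16             # tokens already served from the cache
input_length = orig_input_length - prefix_len  # tokens that still need prefill
max_new_tokens = 20
speculative_length = 0

# Blocks cover the whole sequence, including the cached prefix.
block_tokens = orig_input_length + max_new_tokens - 1 + speculative_length
needed_blocks = math.ceil(block_tokens / BLOCK_SIZE)

# Slots are only needed for tokens that are not already cached.
slot_tokens = input_length + max_new_tokens - 1 + speculative_length

print(block_tokens, needed_blocks, slot_tokens)  # 54 4 38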
@@ -303,16 +333,20 @@ class FlashCausalLMBatch(Batch):
                 ]
             else:
                 request_blocks = r.blocks
-                request_slots = r.slots
+                request_slots = r.slots[
+                    prefix_len: #: orig_input_length + max_new_tokens + speculative_length
+                ]

             block_tables.append(request_blocks)
-            slots.extend(request_slots[:total_tokens])
+
+            slots.extend(request_slots)
+            prefix_lens.append(prefix_len)
             num_blocks += len(request_blocks)
-            start_slots.append(cumulative_max_length)
+            start_slots.append(cumulative_slot_tokens)

             request_slot_indices = torch.arange(
-                cumulative_max_length,
-                cumulative_max_length + input_length,
+                cumulative_slot_tokens,
+                cumulative_slot_tokens + input_length,
                 dtype=torch.int64,
             )
             slot_indices.append(request_slot_indices)
@@ -348,7 +382,7 @@ class FlashCausalLMBatch(Batch):

             # Update
             cumulative_length += input_length
-            cumulative_max_length += total_tokens
+            cumulative_slot_tokens += slot_tokens
             max_seqlen = max(max_seqlen, input_length)
             max_blocks = max(max_blocks, len(request_blocks))
             max_length = max(
@@ -425,12 +459,14 @@ class FlashCausalLMBatch(Batch):
         )

         slots = torch.tensor(slots, dtype=torch.int64, device=device)

         block_tables_tensor = torch.zeros(
             (len(block_tables), max_blocks), dtype=torch.int32, device="cpu"
         )
         for i, request_blocks in enumerate(block_tables):
             block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)
         block_tables_tensor = block_tables_tensor.to(device)
+        prefix_lens_tensor = torch.tensor(prefix_lens, dtype=torch.int32, device=device)
+
         return cls(
             batch_id=pb.id,
@@ -445,6 +481,8 @@ class FlashCausalLMBatch(Batch):
             block_tables=block_tables,
             block_tables_tensor=block_tables_tensor,
             slots=slots,
+            prefix_lens=prefix_lens,
+            prefix_lens_tensor=prefix_lens_tensor,
             max_seqlen=max_seqlen,
             prefill_head_indices=prefill_head_indices,
             prefill_next_token_indices=prefill_next_token_indices,

@@ -455,6 +493,7 @@ class FlashCausalLMBatch(Batch):
             read_offsets=read_offsets,
             all_input_ids=all_input_ids,
             all_input_ids_tensor=all_input_ids_tensor,
+            prefix_ids=prefix_ids,
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
             top_n_tokens=top_n_tokens,
@@ -510,8 +549,10 @@ class FlashCausalLMBatch(Batch):
         start_slots = []
         block_tables = []
         all_input_ids = []
+        prefix_ids = []

         input_lengths = []
+        prefix_lens = []
         prefix_offsets = []
         read_offsets = []


@@ -533,11 +574,14 @@ class FlashCausalLMBatch(Batch):

             # Get length
             request_input_length = self.input_lengths[idx]
+            prefix_len = self.prefix_lens[idx]
             max_seqlen = max(max_seqlen, request_input_length)

             all_input_ids.append(self.all_input_ids[idx])
+            prefix_ids.append(self.prefix_ids[idx])

             input_lengths.append(request_input_length)
+            prefix_lens.append(prefix_len)
             prefix_offsets.append(self.prefix_offsets[idx])
             read_offsets.append(self.read_offsets[idx])

@@ -582,6 +626,7 @@ class FlashCausalLMBatch(Batch):
         block_tables_tensor = self.block_tables_tensor[indices]
         input_lengths_tensor = self.input_lengths_tensor[indices]
         slots = self.slots[slot_filtering_indices]
+        prefix_lens_tensor = self.prefix_lens_tensor[indices]
         next_token_chooser = self.next_token_chooser.filter(indices)
         top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
         speculative_ids = (

@@ -617,10 +662,13 @@ class FlashCausalLMBatch(Batch):
             prefill_cu_outlens=None,
             input_lengths=input_lengths,
             input_lengths_tensor=input_lengths_tensor,
+            prefix_lens=prefix_lens,
+            prefix_lens_tensor=prefix_lens_tensor,
             prefix_offsets=prefix_offsets,
             read_offsets=read_offsets,
             all_input_ids=all_input_ids,
             all_input_ids_tensor=all_input_ids_tensor,
+            prefix_ids=prefix_ids,
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
             top_n_tokens=top_n_tokens,
@@ -681,6 +729,7 @@ class FlashCausalLMBatch(Batch):
         block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
             (total_batch_size, max_blocks)
         )
+        prefix_lens_tensor = batches[0].prefix_lens_tensor.new_empty(total_batch_size)
         all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
             (total_batch_size, max_length)
         )

@@ -698,7 +747,9 @@ class FlashCausalLMBatch(Batch):

         start_slots = []
         block_tables = []
+        prefix_lens = []
         all_input_ids = []
+        prefix_ids = []

         input_lengths = []
         prefix_offsets = []
@@ -760,10 +811,14 @@ class FlashCausalLMBatch(Batch):
                 start_index:end_index, : batch.block_tables_tensor.shape[1]
             ] = batch.block_tables_tensor[:, :max_blocks]

+            prefix_lens_tensor[start_index:end_index] = batch.prefix_lens_tensor
+
             start_slots.append(batch.start_slots + cumulative_slots)

             block_tables.extend(batch.block_tables)
+            prefix_lens.extend(batch.prefix_lens)
             all_input_ids.extend(batch.all_input_ids)
+            prefix_ids.extend(batch.prefix_ids)

             input_lengths.extend(batch.input_lengths)
             prefix_offsets.extend(batch.prefix_offsets)

@@ -809,6 +864,8 @@ class FlashCausalLMBatch(Batch):
             slot_indices=slot_indices,
             block_tables=block_tables,
             block_tables_tensor=block_tables_tensor,
+            prefix_lens=prefix_lens,
+            prefix_lens_tensor=prefix_lens_tensor,
             slots=slots,
             max_seqlen=max_seqlen,
             prefill_head_indices=None,
@@ -820,6 +877,7 @@ class FlashCausalLMBatch(Batch):
             read_offsets=read_offsets,
             all_input_ids=all_input_ids,
             all_input_ids_tensor=all_input_ids_tensor,
+            prefix_ids=prefix_ids,
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
             top_n_tokens=top_n_tokens,
@@ -970,14 +1028,17 @@ class FlashCausalLM(Model):
         self.kv_cache = []

         if ATTENTION == "flashinfer":
-            from text_generation_server.layers.attention.flash_infer import (
+            from text_generation_server.layers.attention.flashinfer import (
                 create_prefill_state,
                 create_decode_state,
+                create_prefill_with_paged_kv_state,
             )

             self.prefill_state = create_prefill_state(device=device)
+            self.prefill_with_paged_kv_state = create_prefill_with_paged_kv_state(
+                device=device
+            )

-            if not CUDA_GRAPHS:
-                self.decode_state = create_decode_state(
-                    device=device,
-                    num_heads=self.num_heads,
+            self.decode_state = create_decode_state(
+                device=device,
+                num_heads=self.num_heads,
@@ -1074,11 +1135,22 @@ class FlashCausalLM(Model):
         input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
         position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
         slots = torch.arange(bs, dtype=torch.int64, device=self.device)
-        input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
-        block_tables = (
-            torch.arange(max_bt, dtype=torch.int32, device=self.device)
-            .repeat(bs)
-            .reshape((bs, max_bt))
+        input_lengths = [max_s] * bs
+        prefix_lengths = [0] * bs
+        input_lengths_tensor = (
+            torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
+        )
+        prefix_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device)
+        block_tables = torch.arange(
+            max_bt, dtype=torch.int32, device=self.device
+        ).repeat(bs)
+        block_tables = block_tables.reshape((bs, max_bt))
+
+        if ATTENTION == "flashinfer":
+            block_tables = block_tables_to_ragged(
+                block_tables=block_tables,
+                input_lengths=input_lengths,
+                prefix_lens=prefix_lengths,
             )

         self.cuda_graphs[bs] = {
@@ -1087,14 +1159,14 @@ class FlashCausalLM(Model):
             "kv_cache": self.kv_cache,
             "block_tables": block_tables,
             "slots": slots,
-            "input_lengths": input_lengths,
+            "input_lengths": input_lengths_tensor,
         }
-        input_lengths_ = Seqlen(input_lengths=input_lengths)
+        input_lengths_ = Seqlen(input_lengths=input_lengths_tensor)
         graph = torch.cuda.CUDAGraph()
         self.cuda_graphs[bs]["graph"] = graph

         if ATTENTION == "flashinfer":
-            from text_generation_server.layers.attention.flash_infer import (
+            from text_generation_server.layers.attention.flashinfer import (
                 create_decode_state_cuda_graphs,
             )

@@ -1104,7 +1176,7 @@ class FlashCausalLM(Model):
             last_page_len = torch.ones(bs, dtype=torch.int32, device=self.device)
             state = create_decode_state_cuda_graphs(
                 device=input_ids.device,
-                block_tables=block_tables.view(-1),
+                block_tables=block_tables,
                 block_tables_ptr=block_tables_ptr,
                 last_page_len=last_page_len,
                 num_heads=self.num_heads,

@@ -1120,7 +1192,10 @@ class FlashCausalLM(Model):
             block_tables=block_tables,
             cu_seqlen_prefill=None,
             input_lengths=input_lengths,
+            input_lengths_tensor=input_lengths_tensor,
             state=state,
+            prefix_lens=prefix_lengths,
+            prefix_lens_tensor=prefix_lengths_tensor,
         ):
             self.model.forward(
                 input_ids=input_ids,
@@ -1138,7 +1213,7 @@ class FlashCausalLM(Model):
         torch.cuda.synchronize()

         with torch.cuda.graph(graph, pool=MEM_POOL):
-            input_lengths = Seqlen(input_lengths=input_lengths)
+            input_lengths_tensor = Seqlen(input_lengths=input_lengths_tensor)
             logits, speculative_logits = self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,

@@ -1146,7 +1221,7 @@ class FlashCausalLM(Model):
                 kv_cache=self.kv_cache,
                 block_tables=block_tables,
                 slots=slots,
-                input_lengths=input_lengths,
+                input_lengths=input_lengths_tensor,
                 max_s=max_s,
                 prefill_cache_indices=None,
                 lm_head_indices=None,
@@ -1334,6 +1409,9 @@ class FlashCausalLM(Model):
             input_lengths = (
                 input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
             ).view(-1)
+            prefix_lens_tensor = (
+                batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length)
+            ).reshape(-1)

             # Add Copy the block tables for all members
             block_tables = (

@@ -1354,6 +1432,7 @@ class FlashCausalLM(Model):
             block_tables = batch.block_tables_tensor
             slots = batch.slots[batch.slot_indices]
             input_lengths = batch.input_lengths_tensor
+            prefix_lens_tensor = batch.prefix_lens_tensor
             max_s = batch.max_seqlen
             lm_head_indices = batch.prefill_head_indices

@@ -1372,10 +1451,20 @@ class FlashCausalLM(Model):
             cuda_graph = None

         if cu_seqlen_prefill is not None or cuda_graph is None:
+            input_lengths = input_lengths + prefix_lens_tensor
+            if PREFIX_CACHING:
+                block_tables = block_tables_to_ragged(
+                    block_tables=block_tables,
+                    input_lengths=batch.input_lengths,
+                    prefix_lens=batch.prefix_lens,
+                )
             with self._forward_context(
                 block_tables=block_tables,
                 cu_seqlen_prefill=cu_seqlen_prefill,
-                input_lengths=input_lengths,
+                input_lengths=batch.input_lengths,
+                input_lengths_tensor=input_lengths,
+                prefix_lens=batch.prefix_lens,
+                prefix_lens_tensor=prefix_lens_tensor,
             ):
                 input_lengths = Seqlen(input_lengths=input_lengths)
                 logits, speculative_logits = self.model.forward(
@@ -1399,20 +1488,32 @@ class FlashCausalLM(Model):
             # Static inputs are potentially padded
             cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
             cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
+            if ATTENTION == "flashinfer":
+                block_tables = block_tables_to_ragged(
+                    block_tables=block_tables,
+                    input_lengths=batch.input_lengths,
+                    prefix_lens=batch.prefix_lens,
+                )
+                cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
+            else:
                 cuda_graph["block_tables"][
                     : block_tables.shape[0], : block_tables.shape[1]
                 ] = block_tables
             cuda_graph["slots"].fill_(-1)
             cuda_graph["slots"][: slots.shape[0]] = slots
             cuda_graph["input_lengths"].zero_()
-            cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
+            cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
+                input_lengths + prefix_lens_tensor
+            )

-            state = cuda_graph.get("state")
             with self._forward_context(
-                block_tables=block_tables,
+                block_tables=cuda_graph["block_tables"],
                 cu_seqlen_prefill=None,
-                input_lengths=input_lengths,
-                state=state,
+                input_lengths=batch.input_lengths,
+                input_lengths_tensor=cuda_graph["input_lengths"],
+                prefix_lens=batch.prefix_lens,
+                prefix_lens_tensor=prefix_lens_tensor,
+                state=cuda_graph.get("state"),
             ):
                 # Replay the graph
                 cuda_graph["graph"].replay()
@@ -1610,6 +1711,7 @@ class FlashCausalLM(Model):
             batch.read_offsets,
             batch.stopping_criterias,
             batch.all_input_ids,
+            batch.prefix_ids,
             batch.next_token_chooser.do_sample,
             batch.next_token_chooser.seeds,
             batch.top_n_tokens,

@@ -1627,6 +1729,7 @@ class FlashCausalLM(Model):
             read_offset,
             stopping_criteria,
             all_input_ids,
+            prefix_ids,
             do_sample,
             seed,
             top_n_tokens,
@@ -1701,18 +1804,18 @@ class FlashCausalLM(Model):
                 out_end_index = batch.prefill_cu_outlens[i + 1]

                 # Remove generated token to only have prefill and add nan for first prompt token
-                request_prefill_logprobs = [float("nan")] + prefill_logprobs[
-                    out_start_index : out_end_index - 1
-                ]
+                request_prefill_logprobs = (
+                    [float("nan")] * (len(prefix_ids) + 1)
+                ) + prefill_logprobs[out_start_index : out_end_index - 1]
                 prefill_token_ids = all_input_ids[:-1]
                 prefill_texts = self.tokenizer.batch_decode(
-                    prefill_token_ids,
+                    prefix_ids + prefill_token_ids,
                     clean_up_tokenization_spaces=False,
                     skip_special_tokens=False,
                 )

                 prefill_tokens = Tokens(
-                    prefill_token_ids,
+                    prefix_ids + prefill_token_ids,
                     request_prefill_logprobs,
                     prefill_texts,
                     is_special=[],
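With a cached prefix there are no prefill logprobs for the prefix tokens, so NaN placeholders are prepended and the prefix ids are re-attached to the returned token list. A small sketch with made-up ids and values (not part of the commit):

prefix_ids = [101, 102, 103]          # prompt tokens served from the cache
all_input_ids = [104, 105, 106, 900]  # non-cached prompt tokens + first generated token
prefill_logprobs = [-1.2, -0.7]       # logprobs for prompt tokens 105 and 106

# The cached prefix and the first non-cached token get NaN placeholders.
request_prefill_logprobs = [float("nan")] * (len(prefix_ids) + 1) + prefill_logprobs
prefill_token_ids = prefix_ids + all_input_ids[:-1]

assert len(request_prefill_logprobs) == len(prefill_token_ids)  # 6 == 6
print(prefill_token_ids)         # [101, 102, 103, 104, 105, 106]
print(request_prefill_logprobs)  # [nan, nan, nan, nan, -1.2, -0.7]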
@@ -1794,33 +1897,68 @@ class FlashCausalLM(Model):
         *,
         block_tables: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
-        input_lengths: torch.Tensor,
+        input_lengths: List[int],
+        input_lengths_tensor: torch.Tensor,
+        prefix_lens: List[int],
+        prefix_lens_tensor: torch.Tensor,
         state: Optional[Any] = None,
     ) -> ContextManager:
         if ATTENTION != "flashinfer":
             return nullcontext()

-        from text_generation_server.layers.attention.flash_infer import (
+        from text_generation_server.layers.attention.flashinfer import (
             use_decode_state,
-            use_prefill_state,
+            use_prefill_with_paged_kv_state,
         )

+        # has_prefix_lens = any(prefix_len > 0 for prefix_len in prefix_lens)
+
         if cu_seqlen_prefill is not None:
-            return use_prefill_state(
-                state=state if state is not None else self.prefill_state,
+            return use_prefill_with_paged_kv_state(
+                state=(
+                    state if state is not None else self.prefill_with_paged_kv_state
+                ),
+                # block_tables=block_tables_to_ragged(
+                #     block_tables=block_tables,
+                #     input_lengths=input_lengths,
+                #     prefix_lens=prefix_lens,
+                # ),
+                block_tables=block_tables,
                 cu_seqlens=cu_seqlen_prefill,
-                num_heads=self.num_heads,
-                num_kv_heads=self.num_kv_heads,
-                head_size=self.head_size,
-            )
-        else:
-            assert input_lengths is not None
-            return use_decode_state(
-                state=state if state is not None else self.decode_state,
-                input_lengths=input_lengths,
-                block_tables=block_tables.view(-1),
+                input_lengths=input_lengths_tensor,
                 num_heads=self.num_heads,
                 num_kv_heads=self.num_kv_heads,
                 head_size=self.head_size,
                 page_size=BLOCK_SIZE,
             )
+        else:
+            assert input_lengths_tensor is not None
+            return use_decode_state(
+                state=state if state is not None else self.decode_state,
+                input_lengths=input_lengths_tensor,
+                block_tables=block_tables,
+                num_heads=self.num_heads,
+                num_kv_heads=self.num_kv_heads,
+                head_size=self.head_size,
+                page_size=BLOCK_SIZE,
+            )
+
+
+def block_tables_to_ragged(
+    *, block_tables: torch.Tensor, input_lengths: List[int], prefix_lens: List[int]
+) -> torch.Tensor:
+    """Convert block table to ragged format compatible with FlashInfer."""
+    assert len(input_lengths) == len(prefix_lens)
+
+    total_len = sum(input_lengths) + sum(prefix_lens)
+    block_tables_ragged = torch.empty(
+        total_len, dtype=torch.int32, device=block_tables.device
+    )
+
+    offset = 0
+    for i, (input_length, prefix_len) in enumerate(zip(input_lengths, prefix_lens)):
+        seq_len = prefix_len + input_length
+        block_tables_ragged[offset : offset + seq_len] = block_tables[i][:seq_len]
+        offset += seq_len
+
+    return block_tables_ragged
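A quick standalone check of the ragged conversion added above; the helper is restated locally so the snippet runs on its own, and the block ids and lengths are made up:

from typing import List
import torch

def block_tables_to_ragged(
    *, block_tables: torch.Tensor, input_lengths: List[int], prefix_lens: List[int]
) -> torch.Tensor:
    """Flatten a padded (batch, max_blocks) block table into one ragged vector."""
    assert len(input_lengths) == len(prefix_lens)
    total_len = sum(input_lengths) + sum(prefix_lens)
    out = torch.empty(total_len, dtype=torch.int32, device=block_tables.device)
    offset = 0
    for i, (input_length, prefix_len) in enumerate(zip(input_lengths, prefix_lens)):
        seq_len = prefix_len + input_length
        out[offset : offset + seq_len] = block_tables[i][:seq_len]
        offset += seq_len
    return out

block_tables = torch.tensor([[1, 2, 3, 0], [7, 8, 9, 10]], dtype=torch.int32)
ragged = block_tables_to_ragged(
    block_tables=block_tables, input_lengths=[2, 3], prefix_lens=[1, 1]
)
print(ragged.tolist())  # [1, 2, 3, 7, 8, 9, 10]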
@@ -5,9 +5,8 @@ from typing import Dict, Optional

 from text_generation_server.utils.log import log_master

-PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING", False)
-log_master(logger.info, f"Using Attention = {PREFIX_CACHING}")
-
+PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING", "0").lower() in {"1", "true"}
+log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
 ATTENTION = os.getenv("ATTENTION", "flashinfer" if PREFIX_CACHING else "paged")
 _expected = {"paged", "flashdecoding", "flashinfer"}
 assert (
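The one-line change to PREFIX_CACHING matters because os.getenv returns a string (or the default), and any non-empty string is truthy. A short illustration of the pitfall being fixed:

import os

os.environ["USE_PREFIX_CACHING"] = "0"

old_style = os.getenv("USE_PREFIX_CACHING", False)            # -> "0", which is truthy
new_style = os.getenv("USE_PREFIX_CACHING", "0").lower() in {"1", "true"}

print(bool(old_style), new_style)  # True False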
@@ -29,7 +28,6 @@ elif ATTENTION == "flashinfer":
 else:
     BLOCK_SIZE = 16

-
 cuda_graphs = os.getenv("CUDA_GRAPHS")
 if cuda_graphs is not None:
     try:
@@ -11,7 +11,9 @@ from text_generation_server.pb import generate_pb2
 from text_generation_server.models.flash_causal_lm import (
     FlashCausalLMBatch,
     FlashCausalLM,
+    block_tables_to_ragged,
 )
+from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION
 from text_generation_server.utils.log import log_master
 from transformers import AutoProcessor
 from text_generation_server.layers.attention import Seqlen
@@ -254,6 +256,8 @@ class VlmCausalLM(FlashCausalLM):
         trust_remote_code: bool,
         **kwargs,
     ):
+        if PREFIX_CACHING:
+            raise NotImplementedError("Vlm do not work with prefix caching yet")
         if processor_kwargs is None:
             processor_kwargs = {}
         self.processor = processor_class.from_pretrained(
@@ -310,6 +314,9 @@ class VlmCausalLM(FlashCausalLM):
             input_lengths = (
                 input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
             ).view(-1)
+            prefix_lens_tensor = (
+                batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length)
+            ).reshape(-1)

             # Add Copy the block tables for all members
             block_tables = (

@@ -330,6 +337,7 @@ class VlmCausalLM(FlashCausalLM):
             block_tables = batch.block_tables_tensor
             slots = batch.slots[batch.slot_indices]
             input_lengths = batch.input_lengths_tensor
+            prefix_lens_tensor = batch.prefix_lens_tensor
             max_s = batch.max_seqlen
             lm_head_indices = batch.prefill_head_indices

@@ -349,6 +357,21 @@ class VlmCausalLM(FlashCausalLM):
         else:
             cuda_graph = None
         if cu_seqlen_prefill is not None or cuda_graph is None:
+            input_lengths = input_lengths + prefix_lens_tensor
+            if PREFIX_CACHING:
+                block_tables = block_tables_to_ragged(
+                    block_tables=block_tables,
+                    input_lengths=batch.input_lengths,
+                    prefix_lens=batch.prefix_lens,
+                )
+            with self._forward_context(
+                block_tables=block_tables,
+                cu_seqlen_prefill=cu_seqlen_prefill,
+                input_lengths=batch.input_lengths,
+                input_lengths_tensor=input_lengths,
+                prefix_lens=batch.prefix_lens,
+                prefix_lens_tensor=prefix_lens_tensor,
+            ):
                 input_lengths = Seqlen(input_lengths=input_lengths)
                 logits, speculative_logits = self.model.forward(
                     input_ids=input_ids,
@@ -379,13 +402,23 @@ class VlmCausalLM(FlashCausalLM):
             # Static inputs are potentially padded
             cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
             cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
+            if ATTENTION == "flashinfer":
+                block_tables = block_tables_to_ragged(
+                    block_tables=block_tables,
+                    input_lengths=batch.input_lengths,
+                    prefix_lens=batch.prefix_lens,
+                )
+                cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
+            else:
                 cuda_graph["block_tables"][
                     : block_tables.shape[0], : block_tables.shape[1]
                 ] = block_tables
             cuda_graph["slots"].fill_(-1)
             cuda_graph["slots"][: slots.shape[0]] = slots
             cuda_graph["input_lengths"].zero_()
-            cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
+            cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
+                input_lengths + prefix_lens_tensor
+            )

             # Replay the graph
             cuda_graph["graph"].replay()