commit 302c773c99
Mohit Sharma 2025-04-15 13:44:04 +05:30, committed by GitHub
2 changed files with 627 additions and 8 deletions


@@ -158,7 +158,7 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
prefix_caching = Some("0".to_string());
}
match config.model_type.as_deref() {
Some("falcon") | Some("deepseek_v2") => {
Some("falcon") | Some("deepseek_v2") | Some("llama4") => {
// Required because gemma2 needs bfloat16 which is not supported by
// flashinfer?
if attention.is_none() {


@@ -1,5 +1,5 @@
import math
from typing import List, Optional
from typing import List, Optional, Tuple, Dict
import torch
from opentelemetry import trace
@@ -12,8 +12,10 @@ from text_generation_server.utils import initialize_torch_distributed
from text_generation_server.layers.attention import paged_attention, attention, Seqlen
from text_generation_server.layers.attention.kv_cache import KVScales, KVCache
from text_generation_server.models.globals import ATTENTION
from text_generation_server.models.globals import ATTENTION, BLOCK_SIZE, MEM_POOL
from text_generation_server.models.metadata_kernels import block_tables_to_ragged
import torch.nn.functional as F
import numpy as np
tracer = trace.get_tracer(__name__)
@@ -27,6 +29,126 @@ REPLICATED_ATTENTION_MODELS = [
]
def cdiv(a: int, b: int) -> int:
"""Ceiling division."""
return -(a // -b)
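# e.g. cdiv(9, 4) == -(9 // -4) == -(-3) == 3, while cdiv(8, 4) == 2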
# Adapted from: https://github.com/vllm-project/vllm/blob/e1a2c699dda82199e88e433c144eae66f3b31878/vllm/v1/attention/backends/flash_attn.py
def make_local_attention_virtual_batches(
attn_chunk_size: int,
query_start_loc_np: np.ndarray,
seq_lens_np: np.ndarray,
block_table: torch.Tensor,
page_size: int = 0,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.Tensor]:
q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
actual_batch_size = seq_lens_np.shape[0]
# Handle the case where we start in the middle of a local attention block:
# we assume q_seqlens > 0 for all elements. For each batch idx we compute
# the number of tokens that fall outside the first local attention block,
# and for the rest we can simply use a cdiv.
# For example if we have:
# attn_chunk_size = 4
# q_seqlens = [4, 10, 5]
# k_seqlens = [6, 17, 9]
# Then we would get:
# q_tokens_in_first_block = [2, 1, 4]
# local_blocks = [2, 4, 2]
q_tokens_in_first_block = np.minimum(
attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size), q_seqlens
).astype(np.int32)
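# Note: (seq_lens_np % -attn_chunk_size) lies in (-attn_chunk_size, 0], so a
# sequence that exactly fills its last block yields attn_chunk_size here, not 0.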
tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size)
local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block, attn_chunk_size)
# Once we know the number of local blocks for each batch idx we can compute
# the request spans and figure out the number of "virtual" requests we have
# to make.
# For the above example we would get:
# seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1]
#
# First Get batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1])
# (TODO: make a utility to share this code with _prepare_inputs)
# arange step 1. [2, 4, 2] -> [2, 6, 8]
cu_num_blocks = np.cumsum(local_blocks)
virtual_batches = cu_num_blocks[-1]
# arange step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6]
block_offsets = np.repeat(cu_num_blocks - local_blocks, local_blocks)
# arange step 3. [0, 1, 0, 1, 2, 3, 0, 1]
arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets
# also compute reverse arange (i.e. [1, 0, 3, 2, 1, 0, 1, 0])
rarange = np.repeat(local_blocks, local_blocks) - arange - 1
# Then we can compute the seqlens_q_local, handling the fact that the
# first and last blocks could be partial
seqlens_q_local = np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks)
# set the first block since this may be a partial block
seqlens_q_local[arange == 0] = q_tokens_in_first_block
# set the remaining blocks
seqlens_q_local[arange > 0] = np.minimum(
seqlens_q_local - attn_chunk_size * (arange - 1), attn_chunk_size
)[arange > 0]
# convert from q_seqlens to cu_seqlens_q
cu_seqlens_q_local = np.pad(np.cumsum(seqlens_q_local), (1, 0)).astype(np.int32)
# compute the seqlens_k_local,
# basically a full local attention block for all but the last block in each
# batch
# For our example this will be:
# seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1]
seqlens_k_local = np.full(cu_num_blocks[-1], attn_chunk_size, dtype=np.int32)
seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block
if ATTENTION == "flashdecoding":
k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - (
rarange * attn_chunk_size + np.repeat(tokens_in_last_block, local_blocks)
)
# For the example the local attention blocks start at:
# _b0_ _____b1_____ _b2_
# k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8]
block_starts = k_seqstarts_absolute // page_size
assert attn_chunk_size % page_size == 0, (
f"attn_chunk_size {attn_chunk_size} is not "
f"divisible by page_size {page_size}"
)
pages_per_local_batch = attn_chunk_size // page_size
# Create a block_table for the local attention blocks
# For our example, if we have a block-table like this (assuming page_size=2):
# block_table = [
# [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9], < batch 0
# [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], < batch 1
# [20, 21, 22, 23, 24, 25, 26, 27, 28, 29], < batch 2
# ]
# Then for the local batches we would want a block-table like
# block_table_local = [
# [ 0, 1 ], < local-batch 0, (batch 0, starting from k[0])
# [ 2, 3 ], < local-batch 1, (batch 0, starting from k[4])
# [ 12, 13 ], < local-batch 2, (batch 1, starting from k[4])
# [ 14, 15 ], < local-batch 3, (batch 1, starting from k[8])
# [ 16, 17 ], < local-batch 4, (batch 1, starting from k[12])
# [ 18, 19 ], < local-batch 5, (batch 1, starting from k[16])
# [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4])
# [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8])
# ]
block_indices = np.broadcast_to(
np.arange(pages_per_local_batch, dtype=np.int32),
(virtual_batches, pages_per_local_batch),
) + np.expand_dims(block_starts, axis=1)
block_indices = block_indices.flatten().clip(max=block_table.shape[1] - 1)
batch_indices = np.repeat(
np.arange(actual_batch_size, dtype=np.int32),
local_blocks * pages_per_local_batch,
)
block_table_local = block_table[batch_indices, block_indices].view(
virtual_batches, -1
)
else:
block_table_local = block_table
return seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, block_table_local
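# Worked example (a sketch, assuming ATTENTION != "flashdecoding" so the block
# table is returned unchanged), using the numbers from the comments above:
#
#   query_start_loc = np.array([0, 4, 14, 19], dtype=np.int32)  # q_seqlens [4, 10, 5]
#   seq_lens = np.array([6, 17, 9], dtype=np.int32)             # k_seqlens
#   block_table = torch.arange(30, dtype=torch.int32).reshape(3, 10)
#   q_local, cu_q_local, k_local, bt_local = make_local_attention_virtual_batches(
#       4, query_start_loc, seq_lens, block_table
#   )
#   # q_local    -> [2, 2, 1, 4, 4, 1, 4, 1]
#   # cu_q_local -> [0, 2, 4, 5, 9, 13, 14, 18, 19]
#   # k_local    -> [4, 2, 4, 4, 4, 1, 4, 1]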
# # Qwen2VL
# transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_VISION_ATTENTION_CLASSES[
# "tgi"
@@ -51,8 +173,14 @@ def tgi_flash_attention_forward(
sliding_window: Optional[int] = None,
softcap: Optional[float] = None,
use_sdpa: Optional[bool] = False,
seqlen_local: Optional[Seqlen] = None,
block_tables_local: Optional[torch.Tensor] = None,
**kwargs, # This is needed to "absorb" other args passed by Transformers modeling
):
if hasattr(module, "use_rope") and module.use_rope:
seqlen = seqlen_local
block_tables = block_tables_local
kv_cache = kv_cache[module.layer_idx]
query_states = query_states.transpose(1, 2).squeeze(dim=0)
key_states = key_states.transpose(1, 2).squeeze(dim=0)
@@ -313,7 +441,9 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
def get_position_ids(self, input_ids, image_grid_thw, position_ids):
return position_ids
def pre_process_inputs(self, input_ids, position_ids, cu_seqlen_prefill):
def pre_process_inputs(self, **kwargs):
input_ids = kwargs["input_ids"]
position_ids = kwargs["position_ids"]
return {
"input_ids": input_ids.unsqueeze(0),
"position_ids": position_ids.unsqueeze(0),
@@ -364,7 +494,10 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
image_grid_thw: Optional[torch.LongTensor] = None,
pixel_attention_mask=None,
image_sizes: Optional[torch.LongTensor] = None,
seqlen_local: Optional[Seqlen] = None,
block_tables_local: Optional[torch.Tensor] = None,
):
# A value of `None` (i.e. no logit slicing) translates to `0` in Transformers
logits_to_keep = lm_head_indices if lm_head_indices is not None else 0
@@ -372,7 +505,10 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
seqlen=seqlen,
block_tables=block_tables,
)
# This is equivalent to `self.model.forward`, see the monkey patch in __init__
logits = self.model.original_forward(
input_ids=inputs["input_ids"],
@@ -396,6 +532,8 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
attention_mask=inputs.get("attention_mask", None),
use_sdpa=inputs.get("use_sdpa", False),
cache_position=inputs.get("cache_position", None),
seqlen_local=seqlen_local,
block_tables_local=block_tables_local,
).logits
logits = self.post_process_outputs(logits, lm_head_indices)
@@ -480,7 +618,10 @@ class TransformersQwen2VlmCausalLM(TransformersFlashVlmCausalLM):
def post_process_outputs(self, logits, lm_head_indices):
return logits.squeeze(dim=0)[lm_head_indices].unsqueeze(0)
def pre_process_inputs(self, input_ids, position_ids, cu_seqlen_prefill):
def pre_process_inputs(self, **kwargs):
input_ids = kwargs["input_ids"]
position_ids = kwargs["position_ids"]
input_ids = input_ids.unsqueeze(0)
position_ids = position_ids.transpose(0, 1).unsqueeze(1)
return {"input_ids": input_ids, "position_ids": position_ids}
@@ -542,7 +683,11 @@ class TransformersGemma3VlmCausalLM(TransformersFlashVlmCausalLM):
return final_attention_mask
def pre_process_inputs(self, input_ids, position_ids, cu_seqlen_prefill):
def pre_process_inputs(self, **kwargs):
input_ids = kwargs["input_ids"]
position_ids = kwargs["position_ids"]
cu_seqlen_prefill = kwargs["cu_seqlen_prefill"]
inputs = {
"input_ids": input_ids.unsqueeze(0),
"position_ids": position_ids.unsqueeze(0),
@@ -559,8 +704,482 @@ class TransformersGemma3VlmCausalLM(TransformersFlashVlmCausalLM):
class TransformersLlama4VlmCausalLM(TransformersFlashVlmCausalLM):
def pre_process_inputs(self, input_ids, position_ids, cu_seqlen_prefill):
inputs = super().pre_process_inputs(input_ids, position_ids, cu_seqlen_prefill)
def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
max_bs = max(self.cuda_graphs.keys()) if self.cuda_graphs else None
input_lengths = [max_s] * bs
cache_lengths = [0] * bs
if max_bs is None:
input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
config = getattr(self.model, "config", None)
rope_scaling = getattr(config, "rope_scaling", None) if config else None
if (  # mrope has position_ids per section; if so, repeat them n_sections times
isinstance(rope_scaling, dict) and rope_scaling["rope_type"] == "mrope"
):
n_sections = len(self.model.config.rope_scaling["mrope_section"])
position_ids = position_ids.unsqueeze(1).repeat(1, n_sections)
slots = torch.arange(bs, dtype=torch.int64, device=self.device)
input_lengths_tensor = (
torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
)
cache_lengths_tensor = torch.zeros(
bs, dtype=torch.int32, device=self.device
)
block_tables = torch.arange(
max_bt, dtype=torch.int32, device=self.device
).repeat(bs)
block_tables = block_tables.reshape((bs, max_bt))
if ATTENTION == "flashinfer":
block_tables = block_tables_to_ragged(
block_tables=block_tables,
input_lengths=input_lengths,
cache_lengths=cache_lengths,
input_lengths_tensor=input_lengths_tensor,
cache_lengths_tensor=cache_lengths_tensor,
max_current_length=max_s,
)
cu_seqlen_q = torch.arange(
input_lengths_tensor.shape[0] + 1,
device=self.device,
dtype=torch.int32,
)
(
input_lengths_tensor_local,
cache_lengths_tensor_local,
seqlens_q_local,
max_q,
max_k,
block_tables_local,
) = self.get_chunked_attention_seqlen(
cu_seqlen_q,
input_lengths_tensor,
block_tables,
)
self.max_k_local = max_k
else:
if bs > max_bs:
raise RuntimeError(
"Cuda graphs should be generated in decreasing order size to reduce VRAM usage"
)
input_ids = self.cuda_graphs[max_bs]["input_ids"][:bs]
position_ids = self.cuda_graphs[max_bs]["position_ids"][:bs]
if ATTENTION == "flashinfer":
block_tables = self.cuda_graphs[max_bs]["block_tables"][: bs * max_bt]
else:
block_tables = self.cuda_graphs[max_bs]["block_tables"][:bs]
block_tables_local = self.cuda_graphs[max_bs]["block_tables_local"][:bs]
slots = self.cuda_graphs[max_bs]["slots"][:bs]
input_lengths_tensor = self.cuda_graphs[max_bs]["input_lengths"][:bs]
cache_lengths_tensor = self.cuda_graphs[max_bs]["cache_lengths"][:bs]
input_lengths_tensor_local = self.cuda_graphs[max_bs][
"input_lengths_local"
][:bs]
cache_lengths_tensor_local = self.cuda_graphs[max_bs][
"cache_lengths_local"
][:bs]
seqlens_q_local = self.cuda_graphs[max_bs]["seqlens_q_local"][:bs]
if ATTENTION == "flashinfer":
from text_generation_server.layers.attention.flashinfer import (
create_decode_state_cuda_graphs,
)
block_tables_ptr = torch.zeros(
bs + 1, dtype=torch.int32, device=self.device
)
last_page_len = torch.ones(bs, dtype=torch.int32, device=self.device)
state = create_decode_state_cuda_graphs(
device=input_ids.device,
block_tables=block_tables,
block_tables_ptr=block_tables_ptr,
last_page_len=last_page_len,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
)
else:
state = None
graph = torch.cuda.CUDAGraph()
self.cuda_graphs[bs] = {
"input_ids": input_ids,
"position_ids": position_ids,
"kv_cache": self.kv_cache,
"block_tables": block_tables,
"slots": slots,
"input_lengths": input_lengths_tensor,
"cache_lengths": cache_lengths_tensor,
"input_lengths_local": input_lengths_tensor_local,
"cache_lengths_local": cache_lengths_tensor_local,
"seqlens_q_local": seqlens_q_local,
"block_tables_local": block_tables_local,
"state": state,
"graph": graph,
}
torch.cuda.synchronize()
# Run once outside to warmup
with self._forward_context(
block_tables=block_tables,
cu_seqlen_prefill=None,
input_lengths_tensor=input_lengths_tensor,
state=state,
cache_lengths_tensor=cache_lengths_tensor,
):
seqlen = Seqlen(
input_lengths=input_lengths_tensor,
cache_lengths=cache_lengths_tensor,
cu_seqlen_q=None,
max_q=1,
max_k=max_s,
)
# cu_seqlens_q_local = F.pad(
# torch.cumsum(seqlens_q_local, dim=0), (1, 0), value=0
# ).to(torch.int32)
seqlen_local = Seqlen(
input_lengths=input_lengths_tensor_local,
cache_lengths=cache_lengths_tensor_local,
cu_seqlen_q=None,
max_q=1,
max_k=input_lengths_tensor_local.max(),
)
self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=None,
kv_cache=self.kv_cache,
block_tables=block_tables,
slots=slots,
seqlen=seqlen,
max_s=max_s,
prefill_cache_indices=None,
lm_head_indices=None,
seqlen_local=seqlen_local,
block_tables_local=block_tables_local,
)
del seqlen
del seqlen_local
torch.cuda.synchronize()
with torch.cuda.graph(graph, pool=MEM_POOL):
seqlen = Seqlen(
input_lengths=input_lengths_tensor,
cache_lengths=cache_lengths_tensor,
cu_seqlen_q=None,
max_q=1,
max_k=max_s,
)
# cu_seqlens_q_local = F.pad(
# torch.cumsum(seqlens_q_local, dim=0), (1, 0), value=0
# ).to(torch.int32)
seqlen_local = Seqlen(
input_lengths=input_lengths_tensor_local,
cache_lengths=cache_lengths_tensor_local,
cu_seqlen_q=None,
max_q=1,
max_k=input_lengths_tensor_local.max(),
)
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=None,
kv_cache=self.kv_cache,
block_tables=block_tables,
slots=slots,
seqlen=seqlen,
max_s=max_s,
prefill_cache_indices=None,
lm_head_indices=None,
seqlen_local=seqlen_local,
block_tables_local=block_tables_local,
)
self.cuda_graphs[bs]["logits"] = logits
self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
torch.cuda.synchronize()
def get_chunked_attention_seqlen(
self,
cu_seqlen_q,
seq_lens_np,
block_tables,
):
attention_chunk_size = self.model.config.text_config.attention_chunk_size
# seq_lens_np = cu_seqlen_k[1:] - cu_seqlen_k[:-1]
(
seqlens_q_local_np,
virt_q_cu_seqlens_np,
virt_k_seqlens_np,
virt_block_table,
) = make_local_attention_virtual_batches(
attention_chunk_size,
(
cu_seqlen_q.cpu().numpy()
if isinstance(cu_seqlen_q, torch.Tensor)
else cu_seqlen_q
),
seq_lens_np.cpu().numpy(),
block_tables,
BLOCK_SIZE,
)
input_lengths = torch.from_numpy(virt_k_seqlens_np).to(
cu_seqlen_q.device, non_blocking=True
)
cache_lengths = torch.zeros(virt_k_seqlens_np.shape).to(
cu_seqlen_q.device, non_blocking=True
)
seqlens_q_local = torch.from_numpy(seqlens_q_local_np).to(
cu_seqlen_q.device, non_blocking=True
)
max_q = int(seqlens_q_local_np.max())
max_k = int(virt_k_seqlens_np.max())
return (
input_lengths,
cache_lengths,
seqlens_q_local,
max_q,
max_k,
virt_block_table,
)
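# With the worked example above (chunk size 4, q_seqlens [4, 10, 5], k_seqlens
# [6, 17, 9]) this would return, roughly: input_lengths [4, 2, 4, 4, 4, 1, 4, 1],
# all-zero cache_lengths, seqlens_q_local [2, 2, 1, 4, 4, 1, 4, 1],
# max_q = 4 and max_k = 4.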
def pre_process_inputs(self, **kwargs):
input_ids = kwargs["input_ids"]
position_ids = kwargs["position_ids"]
inputs = super().pre_process_inputs(**kwargs)
inputs["cache_position"] = position_ids
inputs["attention_mask"] = torch.zeros((1, 1, 1, 1), device=input_ids.device)
return inputs
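# The (1, 1, 1, 1) zero attention_mask is presumably a placeholder so the
# Transformers modeling code does not build a dense causal mask; the actual
# masking is handled by the paged attention kernels via seqlen/block_tables.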
def forward(
self,
batch: VlmCausalLMBatch,
adapter_data: Optional[Dict[str, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
# Model Forward
if batch.speculative_ids is not None:
input_ids = batch.input_ids
position_ids = batch.position_ids
cu_seqlen_prefill = batch.cu_seqlen_prefill
kv_cache = self.kv_cache
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor
max_s = batch.max_current_length
lm_head_indices = batch.prefill_head_indices
speculative_ids = batch.speculative_ids
B, speculative_length = speculative_ids.shape
new_length = speculative_length + 1
new_input_ids = torch.cat(
[input_ids.unsqueeze(-1), speculative_ids], dim=1
).reshape(-1)
arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
arange_int = arange.to(dtype=torch.int32)
new_position_ids = (
position_ids.unsqueeze(-1).expand(B, new_length) + arange
).view(-1)
slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
input_lengths = (
input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
).view(-1)
cache_lengths_tensor = (
batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length)
).reshape(-1)
# Copy the block tables for all members
block_tables = (
block_tables.unsqueeze(1)
.expand(B, new_length, -1)
.reshape(B * new_length, -1)
.contiguous()
)
max_s = max_s + speculative_length
input_ids = new_input_ids
position_ids = new_position_ids
else:
input_ids = batch.input_ids
position_ids = batch.position_ids
cu_seqlen_prefill = batch.cu_seqlen_prefill
kv_cache = self.kv_cache
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor
cache_lengths_tensor = batch.cache_lengths_tensor
max_s = batch.max_current_length
lm_head_indices = batch.prefill_head_indices
if self.model.config.model_type in {"qwen2_vl", "qwen2_5_vl"}:
if position_ids.dim() == 1 and batch.prefilling:
position_ids = self.model.get_position_ids(
input_ids, batch.image_grid_thw
)
batch.position_ids = position_ids
# Try to find an associated cuda graph
bs = input_ids.shape[0]
sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
if sorted_padded_bs:
# Get associated cuda graph
cuda_graph = self.cuda_graphs[sorted_padded_bs[0]]
else:
cuda_graph = None
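# In decode mode (no prefill) each request contributes exactly one query
# token, so the query cumulative lengths are simply arange(batch_size + 1).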
cu_seqlen_q = (
cu_seqlen_prefill
if cu_seqlen_prefill is not None
else torch.arange(
input_lengths.shape[0] + 1, dtype=torch.int32, device=input_ids.device
)
)
(
input_lengths_tensor_local,
cache_lengths_tensor_local,
seqlens_q_local,
max_q,
max_k,
block_tables_local,
) = self.get_chunked_attention_seqlen(
cu_seqlen_q=cu_seqlen_q,
seq_lens_np=input_lengths + cache_lengths_tensor,
block_tables=block_tables,
)
if cu_seqlen_prefill is not None or cuda_graph is None:
if ATTENTION == "flashinfer":
block_tables = block_tables_to_ragged(
block_tables=block_tables,
input_lengths=batch.input_lengths,
cache_lengths=batch.cache_lengths,
input_lengths_tensor=batch.input_lengths_tensor,
cache_lengths_tensor=batch.cache_lengths_tensor,
max_current_length=batch.max_current_length,
)
raise RuntimeError("Flashinfer for LLama4 is not supported yet")
with self._forward_context(
block_tables=block_tables,
cu_seqlen_prefill=cu_seqlen_prefill,
input_lengths_tensor=input_lengths,
cache_lengths_tensor=cache_lengths_tensor,
):
seqlen = Seqlen(
input_lengths=input_lengths,
cache_lengths=cache_lengths_tensor,
cu_seqlen_q=cu_seqlen_prefill,
max_q=batch.max_input_length,
max_k=batch.max_current_length,
)
cu_seqlens_q_local = F.pad(
torch.cumsum(seqlens_q_local, dim=0), (1, 0), value=0
).to(torch.int32)
seqlen_local = Seqlen(
input_lengths=input_lengths_tensor_local,
cache_lengths=cache_lengths_tensor_local,
cu_seqlen_q=cu_seqlens_q_local,
max_q=max_q,
max_k=max_k,
)
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
seqlen=seqlen,
max_s=max_s,
prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices,
pixel_values=batch.pixel_values,
pixel_attention_mask=batch.pixel_attention_mask,
image_sizes=batch.image_sizes,
image_grid_thw=batch.image_grid_thw,
seqlen_local=seqlen_local,
block_tables_local=block_tables_local,
)
if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None
if batch.pixel_values is not None:
batch.pixel_values = None
if batch.pixel_attention_mask is not None:
batch.pixel_attention_mask = None
if batch.image_sizes is not None:
batch.image_sizes = None
if batch.image_grid_thw is not None:
batch.image_grid_thw = None
return logits, speculative_logits
# Copy inputs to the static inputs of the cuda graph
# Static inputs are potentially padded
cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
if ATTENTION == "flashinfer":
block_tables = block_tables_to_ragged(
block_tables=block_tables,
input_lengths=batch.input_lengths,
cache_lengths=batch.cache_lengths,
input_lengths_tensor=batch.input_lengths_tensor,
cache_lengths_tensor=batch.cache_lengths_tensor,
max_current_length=batch.max_current_length,
)
cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
raise RuntimeError("Flashinfer for LLama4 is not supported yet")
else:
cuda_graph["block_tables"][
: block_tables.shape[0], : block_tables.shape[1]
] = block_tables
cuda_graph["block_tables_local"][
: block_tables_local.shape[0], : block_tables_local.shape[1]
] = block_tables_local
# XXX: This is working only because block 0 is reserved for the healthcheck
# so it doesn't matter if we override it with bogus values.
cuda_graph["slots"].fill_(0)
cuda_graph["slots"][: slots.shape[0]] = slots
cuda_graph["input_lengths"].zero_()
cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
cuda_graph["cache_lengths"].zero_()
cuda_graph["cache_lengths"][
: cache_lengths_tensor.shape[0]
] = cache_lengths_tensor
cuda_graph["input_lengths_local"].zero_()
cuda_graph["input_lengths_local"][
: input_lengths_tensor_local.shape[0]
] = input_lengths_tensor_local
cuda_graph["cache_lengths_local"].zero_()
cuda_graph["cache_lengths_local"][
: cache_lengths_tensor_local.shape[0]
] = cache_lengths_tensor_local
cuda_graph["seqlens_q_local"].zero_()
cuda_graph["seqlens_q_local"][: seqlens_q_local.shape[0]] = seqlens_q_local
with self._forward_context(
block_tables=cuda_graph["block_tables"],
cu_seqlen_prefill=None,
input_lengths_tensor=cuda_graph["input_lengths"],
cache_lengths_tensor=cuda_graph["cache_lengths"],
state=cuda_graph["state"],
):
# Replay the graph
cuda_graph["graph"].replay()
# Slice output to the correct shape
speculative_logits = (
cuda_graph["speculative_logits"][:bs]
if cuda_graph["speculative_logits"] is not None
else None
)
logits = cuda_graph["logits"][:bs]
return logits, speculative_logits