Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-21 23:12:07 +00:00)

Commit d2f8caff2b: support cuda graphs
Parent: 3f343cdb6f
@@ -1,5 +1,5 @@
import math
from typing import List, Optional
from typing import List, Optional, Tuple, Dict

import torch
from opentelemetry import trace
@@ -12,7 +12,8 @@ from text_generation_server.utils import initialize_torch_distributed

from text_generation_server.layers.attention import paged_attention, attention, Seqlen
from text_generation_server.layers.attention.kv_cache import KVScales, KVCache
from text_generation_server.models.globals import ATTENTION, BLOCK_SIZE
from text_generation_server.models.globals import ATTENTION, BLOCK_SIZE, MEM_POOL
from text_generation_server.models.metadata_kernels import block_tables_to_ragged
import torch.nn.functional as F
import numpy as np

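Note: this hunk pulls in MEM_POOL alongside ATTENTION and BLOCK_SIZE so the CUDA graph capture added below can share one memory pool across all captured batch sizes. The actual definition lives in text_generation_server/models/globals.py; a plausible minimal sketch (an assumption, not the verbatim source) looks like:

import torch

# Assumed shape of the global: a single graph pool handle shared by every capture,
# so graphs for different batch sizes reuse the same CUDA memory pool.
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
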
@@ -172,13 +173,13 @@ def tgi_flash_attention_forward(
    sliding_window: Optional[int] = None,
    softcap: Optional[float] = None,
    use_sdpa: Optional[bool] = False,
    local_seqlen: Optional[Seqlen] = None,
    local_block_tables: Optional[torch.Tensor] = None,
    seqlen_local: Optional[Seqlen] = None,
    block_tables_local: Optional[torch.Tensor] = None,
    **kwargs,  # This is needed to "absorb" other args passed by Transformers modeling
):
    if hasattr(module, "use_rope") and module.use_rope:
        seqlen = local_seqlen
        block_tables = local_block_tables
        seqlen = seqlen_local
        block_tables = block_tables_local

    kv_cache = kv_cache[module.layer_idx]
    query_states = query_states.transpose(1, 2).squeeze(dim=0)
@@ -493,7 +494,10 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
        image_grid_thw: Optional[torch.LongTensor] = None,
        pixel_attention_mask=None,
        image_sizes: Optional[torch.LongTensor] = None,
        seqlen_local: Optional[Seqlen] = None,
        block_tables_local: Optional[torch.Tensor] = None,
    ):

        # A value of `None` (i.e. no logit slicing) translates to `0` in Transformers
        logits_to_keep = lm_head_indices if lm_head_indices is not None else 0

@@ -505,13 +509,6 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
            block_tables=block_tables,
        )

        if cu_seqlen_prefill is not None:
            from loguru import logger

            logger.info(
                f"input_ids: {input_ids.shape}, position_ids:{inputs.get('local_seqlen', None)}"
            )

        # This is equivalent to `self.model.forward`, see the monkey patch in __init__
        logits = self.model.original_forward(
            input_ids=inputs["input_ids"],
@@ -535,8 +532,8 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
            attention_mask=inputs.get("attention_mask", None),
            use_sdpa=inputs.get("use_sdpa", False),
            cache_position=inputs.get("cache_position", None),
            local_seqlen=inputs.get("local_seqlen", None),
            local_block_tables=inputs.get("local_block_tables", None),
            seqlen_local=seqlen_local,
            block_tables_local=block_tables_local,
        ).logits

        logits = self.post_process_outputs(logits, lm_head_indices)
@@ -707,48 +704,481 @@ class TransformersGemma3VlmCausalLM(TransformersFlashVlmCausalLM):


class TransformersLlama4VlmCausalLM(TransformersFlashVlmCausalLM):
    def pre_process_inputs(self, **kwargs):
        input_ids = kwargs["input_ids"]
        position_ids = kwargs["position_ids"]
        seqlen = kwargs["seqlen"]
        block_tables = kwargs["block_tables"]
    def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
        max_bs = max(self.cuda_graphs.keys()) if self.cuda_graphs else None
        input_lengths = [max_s] * bs
        cache_lengths = [0] * bs
        if max_bs is None:
            input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
            position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
            config = getattr(self.model, "config", None)
            rope_scaling = getattr(config, "rope_scaling", None) if config else None
            if (  # mrope have position_ids per section, if so repeat n times
                isinstance(rope_scaling, dict) and rope_scaling["rope_type"] == "mrope"
            ):
                n_sections = len(self.model.config.rope_scaling["mrope_section"])
                position_ids = position_ids.unsqueeze(1).repeat(1, n_sections)
            slots = torch.arange(bs, dtype=torch.int64, device=self.device)
            input_lengths_tensor = (
                torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
            )
            cache_lengths_tensor = torch.zeros(
                bs, dtype=torch.int32, device=self.device
            )
            block_tables = torch.arange(
                max_bt, dtype=torch.int32, device=self.device
            ).repeat(bs)
            block_tables = block_tables.reshape((bs, max_bt))
            if ATTENTION == "flashinfer":
                block_tables = block_tables_to_ragged(
                    block_tables=block_tables,
                    input_lengths=input_lengths,
                    cache_lengths=cache_lengths,
                    input_lengths_tensor=input_lengths_tensor,
                    cache_lengths_tensor=cache_lengths_tensor,
                    max_current_length=max_s,
                )

        inputs = super().pre_process_inputs(**kwargs)
        inputs["cache_position"] = position_ids
        inputs["attention_mask"] = torch.zeros((1, 1, 1, 1), device=input_ids.device)
        # from loguru import logger
            cu_seqlen_q = torch.arange(
                input_lengths_tensor.shape[0] + 1,
                device=self.device,
                dtype=torch.int32,
            )

            (
                input_lengths_tensor_local,
                cache_lengths_tensor_local,
                seqlens_q_local,
                max_q,
                max_k,
                block_tables_local,
            ) = self.get_chunked_attention_seqlen(
                cu_seqlen_q,
                input_lengths_tensor,
                block_tables,
            )
            self.max_k_local = max_k
        else:
            if bs > max_bs:
                raise RuntimeError(
                    "Cuda graphs should be generated in decreasing order size to reduce VRAM usage"
                )
            input_ids = self.cuda_graphs[max_bs]["input_ids"][:bs]
            position_ids = self.cuda_graphs[max_bs]["position_ids"][:bs]
            if ATTENTION == "flashinfer":
                block_tables = self.cuda_graphs[max_bs]["block_tables"][: bs * max_bt]
            else:
                block_tables = self.cuda_graphs[max_bs]["block_tables"][:bs]
            slots = self.cuda_graphs[max_bs]["slots"][:bs]
            input_lengths_tensor = self.cuda_graphs[max_bs]["input_lengths"][:bs]
            cache_lengths_tensor = self.cuda_graphs[max_bs]["cache_lengths"][:bs]

            input_lengths_tensor_local = self.cuda_graphs[max_bs][
                "input_lengths_local"
            ][:bs]
            cache_lengths_tensor_local = self.cuda_graphs[max_bs][
                "cache_lengths_local"
            ][:bs]
            seqlens_q_local = self.cuda_graphs[max_bs]["seqlens_q_local"][:bs]

        if ATTENTION == "flashinfer":
            from text_generation_server.layers.attention.flashinfer import (
                create_decode_state_cuda_graphs,
            )

            block_tables_ptr = torch.zeros(
                bs + 1, dtype=torch.int32, device=self.device
            )
            last_page_len = torch.ones(bs, dtype=torch.int32, device=self.device)
            state = create_decode_state_cuda_graphs(
                device=input_ids.device,
                block_tables=block_tables,
                block_tables_ptr=block_tables_ptr,
                last_page_len=last_page_len,
                num_heads=self.num_heads,
                num_kv_heads=self.num_kv_heads,
            )
        else:
            state = None

        graph = torch.cuda.CUDAGraph()
        self.cuda_graphs[bs] = {
            "input_ids": input_ids,
            "position_ids": position_ids,
            "kv_cache": self.kv_cache,
            "block_tables": block_tables,
            "slots": slots,
            "input_lengths": input_lengths_tensor,
            "cache_lengths": cache_lengths_tensor,
            "input_lengths_local": input_lengths_tensor_local,
            "cache_lengths_local": cache_lengths_tensor_local,
            "seqlens_q_local": seqlens_q_local,
            "block_tables_local": block_tables_local,
            "state": state,
            "graph": graph,
        }

        torch.cuda.synchronize()
        # Run once outside to warmup
        with self._forward_context(
            block_tables=block_tables,
            cu_seqlen_prefill=None,
            input_lengths_tensor=input_lengths_tensor,
            state=state,
            cache_lengths_tensor=cache_lengths_tensor,
        ):
            seqlen = Seqlen(
                input_lengths=input_lengths_tensor,
                cache_lengths=cache_lengths_tensor,
                cu_seqlen_q=None,
                max_q=1,
                max_k=max_s,
            )
            # cu_seqlens_q_local = F.pad(
            #     torch.cumsum(seqlens_q_local, dim=0), (1, 0), value=0
            # ).to(torch.int32)
            seqlen_local = Seqlen(
                input_lengths=input_lengths_tensor_local,
                cache_lengths=cache_lengths_tensor_local,
                cu_seqlen_q=None,
                max_q=1,
                max_k=input_lengths_tensor_local.max(),
            )
            self.model.forward(
                input_ids=input_ids,
                position_ids=position_ids,
                cu_seqlen_prefill=None,
                kv_cache=self.kv_cache,
                block_tables=block_tables,
                slots=slots,
                seqlen=seqlen,
                max_s=max_s,
                prefill_cache_indices=None,
                lm_head_indices=None,
                seqlen_local=seqlen_local,
                block_tables_local=block_tables_local,
            )
            del seqlen
            del seqlen_local

        torch.cuda.synchronize()

        with torch.cuda.graph(graph, pool=MEM_POOL):
            seqlen = Seqlen(
                input_lengths=input_lengths_tensor,
                cache_lengths=cache_lengths_tensor,
                cu_seqlen_q=None,
                max_q=1,
                max_k=max_s,
            )
            # cu_seqlens_q_local = F.pad(
            #     torch.cumsum(seqlens_q_local, dim=0), (1, 0), value=0
            # ).to(torch.int32)
            seqlen_local = Seqlen(
                input_lengths=input_lengths_tensor_local,
                cache_lengths=cache_lengths_tensor_local,
                cu_seqlen_q=None,
                max_q=1,
                max_k=input_lengths_tensor_local.max(),
            )
            logits, speculative_logits = self.model.forward(
                input_ids=input_ids,
                position_ids=position_ids,
                cu_seqlen_prefill=None,
                kv_cache=self.kv_cache,
                block_tables=block_tables,
                slots=slots,
                seqlen=seqlen,
                max_s=max_s,
                prefill_cache_indices=None,
                lm_head_indices=None,
                seqlen_local=seqlen_local,
                block_tables_local=block_tables_local,
            )
            self.cuda_graphs[bs]["logits"] = logits
            self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
        torch.cuda.synchronize()

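Note: cuda_graph_warmup above follows the standard PyTorch capture recipe: allocate static input tensors for a given batch size, run one eager forward to warm up kernels and workspaces, then record a decode step with torch.cuda.graph into the shared pool, keeping the output tensors so later calls only copy new inputs into the static buffers and replay. A minimal generic sketch of that pattern (the function and variable names here are illustrative, not TGI APIs):

import torch

def capture_decode_step(f, static_inp, pool=None):
    # Warmup: run eagerly once on a side stream so allocator/cuBLAS workspaces exist.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        f(static_inp)
    torch.cuda.current_stream().wait_stream(s)

    # Capture: every tensor touched here must be the exact same storage at replay time.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, pool=pool):
        static_out = f(static_inp)
    return graph, static_out

# Usage: copy real inputs into static_inp, call graph.replay(), then read static_out.
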
    def get_chunked_attention_seqlen(
        self,
        cu_seqlen_q,
        seq_lens_np,
        block_tables,
    ):
        attention_chunk_size = self.model.config.text_config.attention_chunk_size
        # seq_lens_np = cu_seqlen_k[1:] - cu_seqlen_k[:-1]

        # logger.info(f"input_ids: {input_ids.shape}, position_ids: {position_ids.shape}")
        cu_seqlen_k = seqlen.cu_seqlen_k
        cu_seqlen_q = seqlen.cu_seqlen_q
        seq_lens_np = cu_seqlen_k[1:] - cu_seqlen_k[:-1]
        (
            seqlens_q_local_np,
            virt_q_cu_seqlens_np,
            virt_k_seqlens_np,
            virt_block_table,
        ) = make_local_attention_virtual_batches(
            self.model.config.text_config.attention_chunk_size,
            cu_seqlen_q.cpu().numpy(),
            attention_chunk_size,
            (
                cu_seqlen_q.cpu().numpy()
                if isinstance(cu_seqlen_q, torch.Tensor)
                else cu_seqlen_q
            ),
            seq_lens_np.cpu().numpy(),
            block_tables,
            BLOCK_SIZE,
        )
        local_seqlen = Seqlen(
            input_lengths=torch.from_numpy(virt_k_seqlens_np).to(
                input_ids.device, non_blocking=True
            ),
            cache_lengths=torch.zeros(virt_k_seqlens_np.shape).to(
                input_ids.device, non_blocking=True
            ),
            cu_seqlen_q=torch.from_numpy(virt_q_cu_seqlens_np).to(
                input_ids.device, non_blocking=True
            ),
            max_q=int(seqlens_q_local_np.max()),
            max_k=int(virt_k_seqlens_np.max()),

        input_lengths = torch.from_numpy(virt_k_seqlens_np).to(
            cu_seqlen_q.device, non_blocking=True
        )
        cache_lengths = torch.zeros(virt_k_seqlens_np.shape).to(
            cu_seqlen_q.device, non_blocking=True
        )
        seqlens_q_local = torch.from_numpy(seqlens_q_local_np).to(
            cu_seqlen_q.device, non_blocking=True
        )

        inputs["local_seqlen"] = local_seqlen
        inputs["local_block_tables"] = virt_block_table
        max_q = int(seqlens_q_local_np.max())
        max_k = int(virt_k_seqlens_np.max())

        return (
            input_lengths,
            cache_lengths,
            seqlens_q_local,
            max_q,
            max_k,
            virt_block_table,
        )

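Note: get_chunked_attention_seqlen exists because Llama 4's RoPE layers use chunked local attention: each sequence is cut into attention_chunk_size windows that behave like independent "virtual" batches, and make_local_attention_virtual_batches returns the per-window query/key lengths plus a remapped block table. For the decode-only case (one new token per request) the bookkeeping reduces to the sketch below; this is a simplified illustration under that assumption, not the helper's actual implementation:

import torch
import torch.nn.functional as F

def decode_local_window_lengths(seq_lens: torch.Tensor, chunk_size: int):
    # Decode step: one query token per request, so every local q length is 1.
    seqlens_q_local = torch.ones_like(seq_lens)
    # The new token at position p only attends inside its own chunk,
    # i.e. to (p % chunk_size) + 1 keys, with p = seq_len - 1.
    seqlens_k_local = ((seq_lens - 1) % chunk_size) + 1
    # Cumulative query offsets, zero-padded in front, as used for cu_seqlen_q.
    cu_seqlens_q_local = F.pad(
        torch.cumsum(seqlens_q_local, dim=0), (1, 0), value=0
    ).to(torch.int32)
    return seqlens_q_local, seqlens_k_local, cu_seqlens_q_local

# e.g. seq_lens = [5, 8192, 8193] with chunk_size = 8192
# -> local k lengths [5, 8192, 1]: the 8193rd token starts a fresh chunk.
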
    def pre_process_inputs(self, **kwargs):
        input_ids = kwargs["input_ids"]
        position_ids = kwargs["position_ids"]

        inputs = super().pre_process_inputs(**kwargs)
        inputs["cache_position"] = position_ids
        inputs["attention_mask"] = torch.zeros((1, 1, 1, 1), device=input_ids.device)

        return inputs

    def forward(
        self,
        batch: VlmCausalLMBatch,
        adapter_data: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # Model Forward
        if batch.speculative_ids is not None:
            input_ids = batch.input_ids
            position_ids = batch.position_ids
            cu_seqlen_prefill = batch.cu_seqlen_prefill
            kv_cache = self.kv_cache
            block_tables = batch.block_tables_tensor
            slots = batch.slots[batch.slot_indices]
            input_lengths = batch.input_lengths_tensor
            max_s = batch.max_current_length
            lm_head_indices = batch.prefill_head_indices

            speculative_ids = batch.speculative_ids

            B, speculative_length = speculative_ids.shape
            new_length = speculative_length + 1
            new_input_ids = torch.cat(
                [input_ids.unsqueeze(-1), speculative_ids], dim=1
            ).reshape(-1)
            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
            arange_int = arange.to(dtype=torch.int32)
            new_position_ids = (
                position_ids.unsqueeze(-1).expand(B, new_length) + arange
            ).view(-1)
            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
            input_lengths = (
                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
            ).view(-1)
            cache_lengths_tensor = (
                batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length)
            ).reshape(-1)

            # Add Copy the block tables for all members
            block_tables = (
                block_tables.unsqueeze(1)
                .expand(B, new_length, -1)
                .reshape(B * new_length, -1)
                .contiguous()
            )
            max_s = max_s + speculative_length

            input_ids = new_input_ids
            position_ids = new_position_ids
        else:
            input_ids = batch.input_ids
            position_ids = batch.position_ids
            cu_seqlen_prefill = batch.cu_seqlen_prefill
            kv_cache = self.kv_cache
            block_tables = batch.block_tables_tensor
            slots = batch.slots[batch.slot_indices]
            input_lengths = batch.input_lengths_tensor
            cache_lengths_tensor = batch.cache_lengths_tensor
            max_s = batch.max_current_length
            lm_head_indices = batch.prefill_head_indices

        if self.model.config.model_type in {"qwen2_vl", "qwen2_5_vl"}:
            if position_ids.dim() == 1 and batch.prefilling:
                position_ids = self.model.get_position_ids(
                    input_ids, batch.image_grid_thw
                )
                batch.position_ids = position_ids

        # Try to find an associated cuda graph
        bs = input_ids.shape[0]
        sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
        if sorted_padded_bs:
            # Get associated cuda graph
            cuda_graph = self.cuda_graphs[sorted_padded_bs[0]]
        else:
            cuda_graph = None

        cu_seqlen_q = (
            cu_seqlen_prefill
            if cu_seqlen_prefill is not None
            else torch.arange(
                input_lengths.shape[0] + 1, dtype=torch.int32, device=input_ids.device
            )
        )
        (
            input_lengths_tensor_local,
            cache_lengths_tensor_local,
            seqlens_q_local,
            max_q,
            max_k,
            block_tables_local,
        ) = self.get_chunked_attention_seqlen(
            cu_seqlen_q=cu_seqlen_q,
            seq_lens_np=input_lengths + cache_lengths_tensor,
            block_tables=block_tables,
        )

        if cu_seqlen_prefill is not None or cuda_graph is None:
            if ATTENTION == "flashinfer":
                block_tables = block_tables_to_ragged(
                    block_tables=block_tables,
                    input_lengths=batch.input_lengths,
                    cache_lengths=batch.cache_lengths,
                    input_lengths_tensor=batch.input_lengths_tensor,
                    cache_lengths_tensor=batch.cache_lengths_tensor,
                    max_current_length=batch.max_current_length,
                )
                raise RuntimeError("Flashinfer for LLama4 is not supported yet")
            with self._forward_context(
                block_tables=block_tables,
                cu_seqlen_prefill=cu_seqlen_prefill,
                input_lengths_tensor=input_lengths,
                cache_lengths_tensor=cache_lengths_tensor,
            ):
                seqlen = Seqlen(
                    input_lengths=input_lengths,
                    cache_lengths=cache_lengths_tensor,
                    cu_seqlen_q=cu_seqlen_prefill,
                    max_q=batch.max_input_length,
                    max_k=batch.max_current_length,
                )

                cu_seqlens_q_local = F.pad(
                    torch.cumsum(seqlens_q_local, dim=0), (1, 0), value=0
                ).to(torch.int32)
                seqlen_local = Seqlen(
                    input_lengths=input_lengths_tensor_local,
                    cache_lengths=cache_lengths_tensor_local,
                    cu_seqlen_q=cu_seqlens_q_local,
                    max_q=max_q,
                    max_k=max_k,
                )

                logits, speculative_logits = self.model.forward(
                    input_ids=input_ids,
                    position_ids=position_ids,
                    cu_seqlen_prefill=cu_seqlen_prefill,
                    kv_cache=kv_cache,
                    block_tables=block_tables,
                    slots=slots,
                    seqlen=seqlen,
                    max_s=max_s,
                    prefill_cache_indices=batch.prefill_cache_indices,
                    lm_head_indices=lm_head_indices,
                    pixel_values=batch.pixel_values,
                    pixel_attention_mask=batch.pixel_attention_mask,
                    image_sizes=batch.image_sizes,
                    image_grid_thw=batch.image_grid_thw,
                    seqlen_local=seqlen_local,
                    block_tables_local=block_tables_local,
                )
                if batch.prefill_cache_indices is not None:
                    batch.prefill_cache_indices = None
                if batch.pixel_values is not None:
                    batch.pixel_values = None
                if batch.pixel_attention_mask is not None:
                    batch.pixel_attention_mask = None
                if batch.image_sizes is not None:
                    batch.image_sizes = None
                if batch.image_grid_thw is not None:
                    batch.image_grid_thw = None
                return logits, speculative_logits

        # Copy inputs to the static inputs of the cuda graph
        # Static inputs are potentially padded
        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
        if ATTENTION == "flashinfer":
            block_tables = block_tables_to_ragged(
                block_tables=block_tables,
                input_lengths=batch.input_lengths,
                cache_lengths=batch.cache_lengths,
                input_lengths_tensor=batch.input_lengths_tensor,
                cache_lengths_tensor=batch.cache_lengths_tensor,
                max_current_length=batch.max_current_length,
            )
            cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
            raise RuntimeError("Flashinfer for LLama4 is not supported yet")
        else:
            cuda_graph["block_tables"][
                : block_tables.shape[0], : block_tables.shape[1]
            ] = block_tables

        cuda_graph["block_tables_local"][
            : block_tables_local.shape[0], : block_tables_local.shape[1]
        ] = block_tables_local

        # XXX: This is working only because block 0 is reserved for the healthcheck
        # so it doesn't matter if we override it with bogus values.
        cuda_graph["slots"].fill_(0)
        cuda_graph["slots"][: slots.shape[0]] = slots
        cuda_graph["input_lengths"].zero_()
        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
        cuda_graph["cache_lengths"].zero_()
        cuda_graph["cache_lengths"][
            : cache_lengths_tensor.shape[0]
        ] = cache_lengths_tensor
        cuda_graph["input_lengths_local"].zero_()
        cuda_graph["input_lengths_local"][
            : input_lengths_tensor_local.shape[0]
        ] = input_lengths_tensor_local
        cuda_graph["cache_lengths_local"].zero_()
        cuda_graph["cache_lengths_local"][
            : cache_lengths_tensor_local.shape[0]
        ] = cache_lengths_tensor_local
        cuda_graph["seqlens_q_local"].zero_()
        cuda_graph["seqlens_q_local"][: seqlens_q_local.shape[0]] = seqlens_q_local

        with self._forward_context(
            block_tables=cuda_graph["block_tables"],
            cu_seqlen_prefill=None,
            input_lengths_tensor=cuda_graph["input_lengths"],
            cache_lengths_tensor=cuda_graph["cache_lengths"],
            state=cuda_graph["state"],
        ):
            # Replay the graph
            cuda_graph["graph"].replay()

        # Slice output to the correct shape
        speculative_logits = (
            cuda_graph["speculative_logits"][:bs]
            if cuda_graph["speculative_logits"] is not None
            else None
        )
        logits = cuda_graph["logits"][:bs]
        return logits, speculative_logits
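
Note: the decode path above picks the smallest captured graph whose static batch size covers the current batch, copies the real inputs into that graph's padded buffers, replays it, and slices the logits back down to bs. The selection is equivalent to this small helper (a sketch mirroring the sorted_padded_bs logic, not a TGI API):

from typing import Optional

def pick_graph_bucket(cuda_graphs: dict, bs: int) -> Optional[int]:
    # Smallest captured batch size that can hold the current batch, else None
    # (None means: fall back to the eager forward path).
    candidates = sorted(k for k in cuda_graphs.keys() if k >= bs)
    return candidates[0] if candidates else None

# pick_graph_bucket({1: ..., 2: ..., 4: ...}, bs=3) -> 4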