Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-04-20 06:12:07 +00:00

Merge 2a10a28d08 into 449cee49ca

This commit is contained in:
commit 302c773c99
@@ -158,7 +158,7 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
             prefix_caching = Some("0".to_string());
         }
         match config.model_type.as_deref() {
-            Some("falcon") | Some("deepseek_v2") => {
+            Some("falcon") | Some("deepseek_v2") | Some("llama4") => {
                 // Required because gemma2 needs bfloat16 which is not supported by
                 // flashinfer ?
                 if attention.is_none() {
@@ -1,5 +1,5 @@
 import math
-from typing import List, Optional
+from typing import List, Optional, Tuple, Dict

 import torch
 from opentelemetry import trace
@@ -12,8 +12,10 @@ from text_generation_server.utils import initialize_torch_distributed

 from text_generation_server.layers.attention import paged_attention, attention, Seqlen
 from text_generation_server.layers.attention.kv_cache import KVScales, KVCache
-from text_generation_server.models.globals import ATTENTION
+from text_generation_server.models.globals import ATTENTION, BLOCK_SIZE, MEM_POOL
+from text_generation_server.models.metadata_kernels import block_tables_to_ragged
 import torch.nn.functional as F
+import numpy as np

 tracer = trace.get_tracer(__name__)

@@ -27,6 +29,126 @@ REPLICATED_ATTENTION_MODELS = [
 ]


+def cdiv(a: int, b: int) -> int:
+    """Ceiling division."""
+    return -(a // -b)
+
+
+# Adapted from: https://github.com/vllm-project/vllm/blob/e1a2c699dda82199e88e433c144eae66f3b31878/vllm/v1/attention/backends/flash_attn.py
+def make_local_attention_virtual_batches(
+    attn_chunk_size: int,
+    query_start_loc_np: np.ndarray,
+    seq_lens_np: np.ndarray,
+    block_table: torch.Tensor,
+    page_size: int = 0,
+) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.Tensor]:
+    q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
+    actual_batch_size = seq_lens_np.shape[0]
+    # Handle if we are starting in the middle of a local attention block,
+    # we assume q_seqlens > 0 (for all elements), for each batch idx we compute
+    # the number of tokens that are not in the first local attention block and
+    # then we can simply use a cdiv for the rest.
+    # For example if we have:
+    #   attn_chunk_size = 4
+    #   q_seqlens = [4, 10, 5]
+    #   k_seqlens = [6, 17, 9]
+    # Then we would get:
+    #   q_tokens_in_first_block = [2, 1, 4]
+    #   local_blocks = [2, 4, 2]
+    q_tokens_in_first_block = np.minimum(
+        attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size), q_seqlens
+    ).astype(np.int32)
+    tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size)
+    local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block, attn_chunk_size)
+
+    # Once we know the number of local blocks we can compute the request spans
+    # for each batch idx, we can figure out the number of "virtual" requests we
+    # have to make,
+    # For the above example we would get:
+    #   seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1]
+    #
+    # First get batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1])
+    #   (TODO: make a utility to share this code with _prepare_inputs)
+    # arange step 1. [2, 4, 2] -> [2, 6, 8]
+    cu_num_blocks = np.cumsum(local_blocks)
+    virtual_batches = cu_num_blocks[-1]
+    # arange step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6]
+    block_offsets = np.repeat(cu_num_blocks - local_blocks, local_blocks)
+    # arange step 3. [0, 1, 0, 1, 2, 3, 0, 1]
+    arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets
+    # also compute reverse arange (i.e. [1, 0, 3, 2, 1, 0, 1, 0])
+    rarange = np.repeat(local_blocks, local_blocks) - arange - 1
+    # Then we can compute the seqlens_q_local, handling the fact that the
+    # first and last blocks could be partial
+    seqlens_q_local = np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks)
+    # set the first block since this may be a partial block
+    seqlens_q_local[arange == 0] = q_tokens_in_first_block
+    # set the remaining blocks
+    seqlens_q_local[arange > 0] = np.minimum(
+        seqlens_q_local - attn_chunk_size * (arange - 1), attn_chunk_size
+    )[arange > 0]
+
+    # convert from q_seqlens to cu_seqlens_q
+    cu_seqlens_q_local = np.pad(np.cumsum(seqlens_q_local), (1, 0)).astype(np.int32)
+
+    # compute the seqlens_k_local,
+    # basically a full local attention block for all but the last block in each
+    # batch
+    # For our example this will be:
+    #   seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1]
+    seqlens_k_local = np.full(cu_num_blocks[-1], attn_chunk_size, dtype=np.int32)
+    seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block
+
+    if ATTENTION == "flashdecoding":
+        k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - (
+            rarange * attn_chunk_size + np.repeat(tokens_in_last_block, local_blocks)
+        )
+
+        # For the example the local attention blocks start at:
+        #                           _b0_  _____b1_____  _b2_
+        #   k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8]
+        block_starts = k_seqstarts_absolute // page_size
+        assert attn_chunk_size % page_size == 0, (
+            f"attn_chunk_size {attn_chunk_size} is not "
+            f"divisible by page_size {page_size}"
+        )
+        pages_per_local_batch = attn_chunk_size // page_size
+
+        # Create a block_table for the local attention blocks
+        # For our example if we have a block-table like (assuming page_size=2):
+        #   block_table = [
+        #     [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9],  < batch 0
+        #     [10, 11, 12, 13, 14, 15, 16, 17, 18, 19],  < batch 1
+        #     [20, 21, 22, 23, 24, 25, 26, 27, 28, 29],  < batch 2
+        #   ]
+        # Then for the local batches we would want a block-table like
+        #   block_table_local = [
+        #     [ 0,  1 ],  < local-batch 0, (batch 0, starting from k[0])
+        #     [ 2,  3 ],  < local-batch 1, (batch 0, starting from k[4])
+        #     [12, 13 ],  < local-batch 2, (batch 1, starting from k[4])
+        #     [14, 15 ],  < local-batch 3, (batch 1, starting from k[8])
+        #     [16, 17 ],  < local-batch 4, (batch 1, starting from k[12])
+        #     [18, 19 ],  < local-batch 5, (batch 1, starting from k[16])
+        #     [22, 23 ],  < local-batch 6, (batch 2, starting from k[4])
+        #     [24, 25 ],  < local-batch 7, (batch 2, starting from k[8])
+        #   ]
+        block_indices = np.broadcast_to(
+            np.arange(pages_per_local_batch, dtype=np.int32),
+            (virtual_batches, pages_per_local_batch),
+        ) + np.expand_dims(block_starts, axis=1)
+        block_indices = block_indices.flatten().clip(max=block_table.shape[1] - 1)
+        batch_indices = np.repeat(
+            np.arange(actual_batch_size, dtype=np.int32),
+            local_blocks * pages_per_local_batch,
+        )
+        block_table_local = block_table[batch_indices, block_indices].view(
+            virtual_batches, -1
+        )
+    else:
+        block_table_local = block_table
+    return seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, block_table_local
+
+
 # # Qwen2VL
 # transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_VISION_ATTENTION_CLASSES[
 #     "tgi"
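
Note: the virtual-batch arithmetic above can be checked by hand. Below is a self-contained sketch (plain numpy, values lifted from the comments; the page_size=2 and the 3x10 block table are the assumptions made in the example comments, not extra inputs) that reproduces the batched arange and the local block-table gather:

    import numpy as np

    def cdiv(a, b):
        # Ceiling division; also works elementwise on numpy arrays.
        return -(a // -b)

    attn_chunk_size = 4
    q_seqlens = np.array([4, 10, 5])
    k_seqlens = np.array([6, 17, 9])

    # Query tokens that land in each sequence's first (possibly partial) chunk.
    q_tokens_in_first_block = np.minimum(
        attn_chunk_size - ((k_seqlens - q_seqlens) % attn_chunk_size), q_seqlens
    )
    local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block, attn_chunk_size)
    print(q_tokens_in_first_block.tolist())  # [2, 1, 4]
    print(local_blocks.tolist())             # [2, 4, 2]

    # Batched arange: [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1]
    cu_num_blocks = np.cumsum(local_blocks)                           # [2, 6, 8]
    block_offsets = np.repeat(cu_num_blocks - local_blocks, local_blocks)
    arange = np.arange(cu_num_blocks[-1]) - block_offsets
    print(arange.tolist())  # [0, 1, 0, 1, 2, 3, 0, 1]

    # Local block-table gather with page_size=2 (pages_per_local_batch=2).
    page_size = 2
    block_table = np.arange(30).reshape(3, 10)
    tokens_in_last_block = attn_chunk_size + (k_seqlens % -attn_chunk_size)
    rarange = np.repeat(local_blocks, local_blocks) - arange - 1
    k_seqstarts = np.repeat(k_seqlens, local_blocks) - (
        rarange * attn_chunk_size + np.repeat(tokens_in_last_block, local_blocks)
    )
    block_starts = k_seqstarts // page_size      # [0, 2, 2, 4, 6, 8, 2, 4]
    pages = attn_chunk_size // page_size
    block_indices = (np.arange(pages) + block_starts[:, None]).reshape(-1)
    batch_indices = np.repeat(np.arange(3), local_blocks * pages)
    print(block_table[batch_indices, block_indices].reshape(-1, pages))
    # [[ 0  1] [ 2  3] [12 13] [14 15] [16 17] [18 19] [22 23] [24 25]]

The printed rows match the block_table_local laid out in the function's comments.
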
@@ -51,8 +173,14 @@ def tgi_flash_attention_forward(
     sliding_window: Optional[int] = None,
     softcap: Optional[float] = None,
     use_sdpa: Optional[bool] = False,
+    seqlen_local: Optional[Seqlen] = None,
+    block_tables_local: Optional[torch.Tensor] = None,
     **kwargs,  # This is needed to "absorb" other args passed by Transformers modeling
 ):
+    if hasattr(module, "use_rope") and module.use_rope:
+        seqlen = seqlen_local
+        block_tables = block_tables_local
+
     kv_cache = kv_cache[module.layer_idx]
     query_states = query_states.transpose(1, 2).squeeze(dim=0)
     key_states = key_states.transpose(1, 2).squeeze(dim=0)
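
Note: in Llama4 the layers that apply RoPE are the chunked-attention layers (the NoPE layers attend globally), which is why the hook above keys the swap on module.use_rope. A minimal sketch of that dispatch, with a hypothetical stand-in Layer type (not TGI's module class):

    from dataclasses import dataclass

    @dataclass
    class Layer:
        use_rope: bool  # stand-in for the attribute checked in the hook above

    def pick_metadata(layer, seqlen, block_tables, seqlen_local, block_tables_local):
        # Chunked (RoPE) layers see the virtual-batch metadata; global (NoPE)
        # layers keep the full sequence lengths and block tables.
        if getattr(layer, "use_rope", False):
            return seqlen_local, block_tables_local
        return seqlen, block_tables

    print(pick_metadata(Layer(use_rope=True), "global", "bt", "local", "bt_local"))
    # ('local', 'bt_local')
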
@@ -313,7 +441,9 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
     def get_position_ids(self, input_ids, image_grid_thw, position_ids):
         return position_ids

-    def pre_process_inputs(self, input_ids, position_ids, cu_seqlen_prefill):
+    def pre_process_inputs(self, **kwargs):
+        input_ids = kwargs["input_ids"]
+        position_ids = kwargs["position_ids"]
         return {
             "input_ids": input_ids.unsqueeze(0),
             "position_ids": position_ids.unsqueeze(0),
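
Note: moving from a fixed positional signature to **kwargs lets each subclass pull out only the keys it needs while silently absorbing the rest. A self-contained sketch of the pattern (class names here are stand-ins, not TGI's):

    import torch

    class Base:
        def pre_process_inputs(self, **kwargs):
            input_ids = kwargs["input_ids"]
            position_ids = kwargs["position_ids"]
            return {
                "input_ids": input_ids.unsqueeze(0),
                "position_ids": position_ids.unsqueeze(0),
            }

    class Llama4Like(Base):
        def pre_process_inputs(self, **kwargs):
            inputs = super().pre_process_inputs(**kwargs)
            inputs["cache_position"] = kwargs["position_ids"]  # extra key, same bundle
            return inputs

    m = Llama4Like()
    out = m.pre_process_inputs(
        input_ids=torch.tensor([1, 2, 3]),
        position_ids=torch.tensor([0, 1, 2]),
        cu_seqlen_prefill=None,  # unused by this subclass, absorbed by **kwargs
    )
    print(out["input_ids"].shape)  # torch.Size([1, 3])
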
@@ -364,7 +494,10 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
         image_grid_thw: Optional[torch.LongTensor] = None,
         pixel_attention_mask=None,
         image_sizes: Optional[torch.LongTensor] = None,
+        seqlen_local: Optional[Seqlen] = None,
+        block_tables_local: Optional[torch.Tensor] = None,
     ):
+
         # A value of `None` (i.e. no logit slicing) translates to `0` in Transformers
         logits_to_keep = lm_head_indices if lm_head_indices is not None else 0

@@ -372,7 +505,10 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
             input_ids=input_ids,
             position_ids=position_ids,
             cu_seqlen_prefill=cu_seqlen_prefill,
+            seqlen=seqlen,
+            block_tables=block_tables,
         )

         # This is equivalent to `self.model.forward`, see the monkey patch in __init__
         logits = self.model.original_forward(
             input_ids=inputs["input_ids"],
@@ -396,6 +532,8 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
             attention_mask=inputs.get("attention_mask", None),
             use_sdpa=inputs.get("use_sdpa", False),
             cache_position=inputs.get("cache_position", None),
+            seqlen_local=seqlen_local,
+            block_tables_local=block_tables_local,
         ).logits

         logits = self.post_process_outputs(logits, lm_head_indices)
@@ -480,7 +618,10 @@ class TransformersQwen2VlmCausalLM(TransformersFlashVlmCausalLM):
     def post_process_outputs(self, logits, lm_head_indices):
         return logits.squeeze(dim=0)[lm_head_indices].unsqueeze(0)

-    def pre_process_inputs(self, input_ids, position_ids, cu_seqlen_prefill):
+    def pre_process_inputs(self, **kwargs):
+        input_ids = kwargs["input_ids"]
+        position_ids = kwargs["position_ids"]
+
         input_ids = input_ids.unsqueeze(0)
         position_ids = position_ids.transpose(0, 1).unsqueeze(1)
         return {"input_ids": input_ids, "position_ids": position_ids}
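
Note: the transpose/unsqueeze above prepares Qwen2-VL's multi-section (mrope) position ids for the batched Transformers forward. A quick shape check, under the assumption that one (t, h, w) row is produced per token:

    import torch

    seq_len = 5
    position_ids = torch.zeros(seq_len, 3, dtype=torch.long)  # (seq_len, 3 mrope sections)
    batched = position_ids.transpose(0, 1).unsqueeze(1)
    print(batched.shape)  # torch.Size([3, 1, 5]) -> (sections, batch, seq_len)
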
@@ -542,7 +683,11 @@ class TransformersGemma3VlmCausalLM(TransformersFlashVlmCausalLM):

         return final_attention_mask

-    def pre_process_inputs(self, input_ids, position_ids, cu_seqlen_prefill):
+    def pre_process_inputs(self, **kwargs):
+        input_ids = kwargs["input_ids"]
+        position_ids = kwargs["position_ids"]
+        cu_seqlen_prefill = kwargs["cu_seqlen_prefill"]
+
         inputs = {
             "input_ids": input_ids.unsqueeze(0),
             "position_ids": position_ids.unsqueeze(0),
@@ -559,8 +704,482 @@ class TransformersGemma3VlmCausalLM(TransformersFlashVlmCausalLM):


 class TransformersLlama4VlmCausalLM(TransformersFlashVlmCausalLM):
-    def pre_process_inputs(self, input_ids, position_ids, cu_seqlen_prefill):
-        inputs = super().pre_process_inputs(input_ids, position_ids, cu_seqlen_prefill)
+    def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
+        max_bs = max(self.cuda_graphs.keys()) if self.cuda_graphs else None
+        input_lengths = [max_s] * bs
+        cache_lengths = [0] * bs
+        if max_bs is None:
+            input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
+            position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
+            config = getattr(self.model, "config", None)
+            rope_scaling = getattr(config, "rope_scaling", None) if config else None
+            if (  # mrope models have position_ids per section; if so, repeat n times
+                isinstance(rope_scaling, dict) and rope_scaling["rope_type"] == "mrope"
+            ):
+                n_sections = len(self.model.config.rope_scaling["mrope_section"])
+                position_ids = position_ids.unsqueeze(1).repeat(1, n_sections)
+            slots = torch.arange(bs, dtype=torch.int64, device=self.device)
+            input_lengths_tensor = (
+                torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
+            )
+            cache_lengths_tensor = torch.zeros(
+                bs, dtype=torch.int32, device=self.device
+            )
+            block_tables = torch.arange(
+                max_bt, dtype=torch.int32, device=self.device
+            ).repeat(bs)
+            block_tables = block_tables.reshape((bs, max_bt))
+            if ATTENTION == "flashinfer":
+                block_tables = block_tables_to_ragged(
+                    block_tables=block_tables,
+                    input_lengths=input_lengths,
+                    cache_lengths=cache_lengths,
+                    input_lengths_tensor=input_lengths_tensor,
+                    cache_lengths_tensor=cache_lengths_tensor,
+                    max_current_length=max_s,
+                )
+
+            cu_seqlen_q = torch.arange(
+                input_lengths_tensor.shape[0] + 1,
+                device=self.device,
+                dtype=torch.int32,
+            )
+
+            (
+                input_lengths_tensor_local,
+                cache_lengths_tensor_local,
+                seqlens_q_local,
+                max_q,
+                max_k,
+                block_tables_local,
+            ) = self.get_chunked_attention_seqlen(
+                cu_seqlen_q,
+                input_lengths_tensor,
+                block_tables,
+            )
+            self.max_k_local = max_k
+        else:
+            if bs > max_bs:
+                raise RuntimeError(
+                    "Cuda graphs should be generated in decreasing order size to reduce VRAM usage"
+                )
+            input_ids = self.cuda_graphs[max_bs]["input_ids"][:bs]
+            position_ids = self.cuda_graphs[max_bs]["position_ids"][:bs]
+            if ATTENTION == "flashinfer":
+                block_tables = self.cuda_graphs[max_bs]["block_tables"][: bs * max_bt]
+            else:
+                block_tables = self.cuda_graphs[max_bs]["block_tables"][:bs]
+            block_tables_local = self.cuda_graphs[max_bs]["block_tables_local"][:bs]
+            slots = self.cuda_graphs[max_bs]["slots"][:bs]
+            input_lengths_tensor = self.cuda_graphs[max_bs]["input_lengths"][:bs]
+            cache_lengths_tensor = self.cuda_graphs[max_bs]["cache_lengths"][:bs]
+
+            input_lengths_tensor_local = self.cuda_graphs[max_bs][
+                "input_lengths_local"
+            ][:bs]
+            cache_lengths_tensor_local = self.cuda_graphs[max_bs][
+                "cache_lengths_local"
+            ][:bs]
+            seqlens_q_local = self.cuda_graphs[max_bs]["seqlens_q_local"][:bs]
+
+        if ATTENTION == "flashinfer":
+            from text_generation_server.layers.attention.flashinfer import (
+                create_decode_state_cuda_graphs,
+            )
+
+            block_tables_ptr = torch.zeros(
+                bs + 1, dtype=torch.int32, device=self.device
+            )
+            last_page_len = torch.ones(bs, dtype=torch.int32, device=self.device)
+            state = create_decode_state_cuda_graphs(
+                device=input_ids.device,
+                block_tables=block_tables,
+                block_tables_ptr=block_tables_ptr,
+                last_page_len=last_page_len,
+                num_heads=self.num_heads,
+                num_kv_heads=self.num_kv_heads,
+            )
+        else:
+            state = None
+
+        graph = torch.cuda.CUDAGraph()
+        self.cuda_graphs[bs] = {
+            "input_ids": input_ids,
+            "position_ids": position_ids,
+            "kv_cache": self.kv_cache,
+            "block_tables": block_tables,
+            "slots": slots,
+            "input_lengths": input_lengths_tensor,
+            "cache_lengths": cache_lengths_tensor,
+            "input_lengths_local": input_lengths_tensor_local,
+            "cache_lengths_local": cache_lengths_tensor_local,
+            "seqlens_q_local": seqlens_q_local,
+            "block_tables_local": block_tables_local,
+            "state": state,
+            "graph": graph,
+        }
+
+        torch.cuda.synchronize()
+        # Run once outside to warmup
+        with self._forward_context(
+            block_tables=block_tables,
+            cu_seqlen_prefill=None,
+            input_lengths_tensor=input_lengths_tensor,
+            state=state,
+            cache_lengths_tensor=cache_lengths_tensor,
+        ):
+            seqlen = Seqlen(
+                input_lengths=input_lengths_tensor,
+                cache_lengths=cache_lengths_tensor,
+                cu_seqlen_q=None,
+                max_q=1,
+                max_k=max_s,
+            )
+            # cu_seqlens_q_local = F.pad(
+            #     torch.cumsum(seqlens_q_local, dim=0), (1, 0), value=0
+            # ).to(torch.int32)
+            seqlen_local = Seqlen(
+                input_lengths=input_lengths_tensor_local,
+                cache_lengths=cache_lengths_tensor_local,
+                cu_seqlen_q=None,
+                max_q=1,
+                max_k=input_lengths_tensor_local.max(),
+            )
+            self.model.forward(
+                input_ids=input_ids,
+                position_ids=position_ids,
+                cu_seqlen_prefill=None,
+                kv_cache=self.kv_cache,
+                block_tables=block_tables,
+                slots=slots,
+                seqlen=seqlen,
+                max_s=max_s,
+                prefill_cache_indices=None,
+                lm_head_indices=None,
+                seqlen_local=seqlen_local,
+                block_tables_local=block_tables_local,
+            )
+            del seqlen
+            del seqlen_local
+
+        torch.cuda.synchronize()
+
+        with torch.cuda.graph(graph, pool=MEM_POOL):
+            seqlen = Seqlen(
+                input_lengths=input_lengths_tensor,
+                cache_lengths=cache_lengths_tensor,
+                cu_seqlen_q=None,
+                max_q=1,
+                max_k=max_s,
+            )
+            # cu_seqlens_q_local = F.pad(
+            #     torch.cumsum(seqlens_q_local, dim=0), (1, 0), value=0
+            # ).to(torch.int32)
+            seqlen_local = Seqlen(
+                input_lengths=input_lengths_tensor_local,
+                cache_lengths=cache_lengths_tensor_local,
+                cu_seqlen_q=None,
+                max_q=1,
+                max_k=input_lengths_tensor_local.max(),
+            )
+            logits, speculative_logits = self.model.forward(
+                input_ids=input_ids,
+                position_ids=position_ids,
+                cu_seqlen_prefill=None,
+                kv_cache=self.kv_cache,
+                block_tables=block_tables,
+                slots=slots,
+                seqlen=seqlen,
+                max_s=max_s,
+                prefill_cache_indices=None,
+                lm_head_indices=None,
+                seqlen_local=seqlen_local,
+                block_tables_local=block_tables_local,
+            )
+            self.cuda_graphs[bs]["logits"] = logits
+            self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
+        torch.cuda.synchronize()
+
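
Note: cuda_graph_warmup follows the standard torch capture recipe: allocate static tensors, run once eagerly, capture into a shared memory pool (MEM_POOL above), then replay after refilling the static inputs. A self-contained toy version (a single linear layer standing in for the TGI model; requires a CUDA device):

    import torch

    assert torch.cuda.is_available()
    model = torch.nn.Linear(16, 16, device="cuda")
    static_in = torch.zeros(4, 16, device="cuda")
    pool = torch.cuda.graph_pool_handle()  # analogous to MEM_POOL above

    with torch.no_grad():
        model(static_in)  # warmup outside the graph (autotuning, allocator state)
    torch.cuda.synchronize()

    graph = torch.cuda.CUDAGraph()
    with torch.no_grad(), torch.cuda.graph(graph, pool=pool):
        static_out = model(static_in)  # recorded, not executed

    static_in.copy_(torch.randn(4, 16, device="cuda"))  # refill static inputs
    graph.replay()  # rerun the recorded kernels on the new data
    print(static_out.sum().item())
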
+    def get_chunked_attention_seqlen(
+        self,
+        cu_seqlen_q,
+        seq_lens_np,
+        block_tables,
+    ):
+        attention_chunk_size = self.model.config.text_config.attention_chunk_size
+        # seq_lens_np = cu_seqlen_k[1:] - cu_seqlen_k[:-1]
+
+        (
+            seqlens_q_local_np,
+            virt_q_cu_seqlens_np,
+            virt_k_seqlens_np,
+            virt_block_table,
+        ) = make_local_attention_virtual_batches(
+            attention_chunk_size,
+            (
+                cu_seqlen_q.cpu().numpy()
+                if isinstance(cu_seqlen_q, torch.Tensor)
+                else cu_seqlen_q
+            ),
+            seq_lens_np.cpu().numpy(),
+            block_tables,
+            BLOCK_SIZE,
+        )
+
+        input_lengths = torch.from_numpy(virt_k_seqlens_np).to(
+            cu_seqlen_q.device, non_blocking=True
+        )
+        cache_lengths = torch.zeros(virt_k_seqlens_np.shape).to(
+            cu_seqlen_q.device, non_blocking=True
+        )
+        seqlens_q_local = torch.from_numpy(seqlens_q_local_np).to(
+            cu_seqlen_q.device, non_blocking=True
+        )
+
+        max_q = int(seqlens_q_local_np.max())
+        max_k = int(virt_k_seqlens_np.max())
+
+        return (
+            input_lengths,
+            cache_lengths,
+            seqlens_q_local,
+            max_q,
+            max_k,
+            virt_block_table,
+        )
+
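
Note: during decode, get_chunked_attention_seqlen degenerates nicely. With one query token per sequence, local_blocks is 1 everywhere, so each sequence becomes a single virtual batch whose k-length is the occupancy of its last chunk; local attention cost at decode time is therefore bounded by attention_chunk_size regardless of total sequence length. A quick numpy check with the example values used earlier:

    import numpy as np

    attention_chunk_size = 4
    k_seqlens = np.array([6, 17, 9])     # total tokens per sequence (cache + new)
    q_seqlens = np.ones_like(k_seqlens)  # decode: one new token per sequence

    # Occupancy of each sequence's last chunk == local k-length of its single
    # virtual batch during decode.
    tokens_in_last_block = attention_chunk_size + (k_seqlens % -attention_chunk_size)
    print(tokens_in_last_block.tolist())  # [2, 1, 1] -> always <= attention_chunk_size
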
+    def pre_process_inputs(self, **kwargs):
+        input_ids = kwargs["input_ids"]
+        position_ids = kwargs["position_ids"]
+
+        inputs = super().pre_process_inputs(**kwargs)
         inputs["cache_position"] = position_ids
         inputs["attention_mask"] = torch.zeros((1, 1, 1, 1), device=input_ids.device)
+
         return inputs
+
+    def forward(
+        self,
+        batch: VlmCausalLMBatch,
+        adapter_data: Optional[Dict[str, torch.Tensor]] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        # Model Forward
+        if batch.speculative_ids is not None:
+            input_ids = batch.input_ids
+            position_ids = batch.position_ids
+            cu_seqlen_prefill = batch.cu_seqlen_prefill
+            kv_cache = self.kv_cache
+            block_tables = batch.block_tables_tensor
+            slots = batch.slots[batch.slot_indices]
+            input_lengths = batch.input_lengths_tensor
+            max_s = batch.max_current_length
+            lm_head_indices = batch.prefill_head_indices
+
+            speculative_ids = batch.speculative_ids
+
+            B, speculative_length = speculative_ids.shape
+            new_length = speculative_length + 1
+            new_input_ids = torch.cat(
+                [input_ids.unsqueeze(-1), speculative_ids], dim=1
+            ).reshape(-1)
+            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
+            arange_int = arange.to(dtype=torch.int32)
+            new_position_ids = (
+                position_ids.unsqueeze(-1).expand(B, new_length) + arange
+            ).view(-1)
+            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
+            input_lengths = (
+                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
+            ).view(-1)
+            cache_lengths_tensor = (
+                batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length)
+            ).reshape(-1)
+
+            # Copy the block tables for all members
+            block_tables = (
+                block_tables.unsqueeze(1)
+                .expand(B, new_length, -1)
+                .reshape(B * new_length, -1)
+                .contiguous()
+            )
+            max_s = max_s + speculative_length
+
+            input_ids = new_input_ids
+            position_ids = new_position_ids
+        else:
+            input_ids = batch.input_ids
+            position_ids = batch.position_ids
+            cu_seqlen_prefill = batch.cu_seqlen_prefill
+            kv_cache = self.kv_cache
+            block_tables = batch.block_tables_tensor
+            slots = batch.slots[batch.slot_indices]
+            input_lengths = batch.input_lengths_tensor
+            cache_lengths_tensor = batch.cache_lengths_tensor
+            max_s = batch.max_current_length
+            lm_head_indices = batch.prefill_head_indices
+
+        if self.model.config.model_type in {"qwen2_vl", "qwen2_5_vl"}:
+            if position_ids.dim() == 1 and batch.prefilling:
+                position_ids = self.model.get_position_ids(
+                    input_ids, batch.image_grid_thw
+                )
+                batch.position_ids = position_ids
+
+        # Try to find an associated cuda graph
+        bs = input_ids.shape[0]
+        sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
+        if sorted_padded_bs:
+            # Get associated cuda graph
+            cuda_graph = self.cuda_graphs[sorted_padded_bs[0]]
+        else:
+            cuda_graph = None
+
+        cu_seqlen_q = (
+            cu_seqlen_prefill
+            if cu_seqlen_prefill is not None
+            else torch.arange(
+                input_lengths.shape[0] + 1, dtype=torch.int32, device=input_ids.device
+            )
+        )
+        (
+            input_lengths_tensor_local,
+            cache_lengths_tensor_local,
+            seqlens_q_local,
+            max_q,
+            max_k,
+            block_tables_local,
+        ) = self.get_chunked_attention_seqlen(
+            cu_seqlen_q=cu_seqlen_q,
+            seq_lens_np=input_lengths + cache_lengths_tensor,
+            block_tables=block_tables,
+        )
+
+        if cu_seqlen_prefill is not None or cuda_graph is None:
+            if ATTENTION == "flashinfer":
+                block_tables = block_tables_to_ragged(
+                    block_tables=block_tables,
+                    input_lengths=batch.input_lengths,
+                    cache_lengths=batch.cache_lengths,
+                    input_lengths_tensor=batch.input_lengths_tensor,
+                    cache_lengths_tensor=batch.cache_lengths_tensor,
+                    max_current_length=batch.max_current_length,
+                )
+                raise RuntimeError("Flashinfer for LLama4 is not supported yet")
+            with self._forward_context(
+                block_tables=block_tables,
+                cu_seqlen_prefill=cu_seqlen_prefill,
+                input_lengths_tensor=input_lengths,
+                cache_lengths_tensor=cache_lengths_tensor,
+            ):
+                seqlen = Seqlen(
+                    input_lengths=input_lengths,
+                    cache_lengths=cache_lengths_tensor,
+                    cu_seqlen_q=cu_seqlen_prefill,
+                    max_q=batch.max_input_length,
+                    max_k=batch.max_current_length,
+                )
+
+                cu_seqlens_q_local = F.pad(
+                    torch.cumsum(seqlens_q_local, dim=0), (1, 0), value=0
+                ).to(torch.int32)
+                seqlen_local = Seqlen(
+                    input_lengths=input_lengths_tensor_local,
+                    cache_lengths=cache_lengths_tensor_local,
+                    cu_seqlen_q=cu_seqlens_q_local,
+                    max_q=max_q,
+                    max_k=max_k,
+                )
+
+                logits, speculative_logits = self.model.forward(
+                    input_ids=input_ids,
+                    position_ids=position_ids,
+                    cu_seqlen_prefill=cu_seqlen_prefill,
+                    kv_cache=kv_cache,
+                    block_tables=block_tables,
+                    slots=slots,
+                    seqlen=seqlen,
+                    max_s=max_s,
+                    prefill_cache_indices=batch.prefill_cache_indices,
+                    lm_head_indices=lm_head_indices,
+                    pixel_values=batch.pixel_values,
+                    pixel_attention_mask=batch.pixel_attention_mask,
+                    image_sizes=batch.image_sizes,
+                    image_grid_thw=batch.image_grid_thw,
+                    seqlen_local=seqlen_local,
+                    block_tables_local=block_tables_local,
+                )
+                if batch.prefill_cache_indices is not None:
+                    batch.prefill_cache_indices = None
+                if batch.pixel_values is not None:
+                    batch.pixel_values = None
+                if batch.pixel_attention_mask is not None:
+                    batch.pixel_attention_mask = None
+                if batch.image_sizes is not None:
+                    batch.image_sizes = None
+                if batch.image_grid_thw is not None:
+                    batch.image_grid_thw = None
+                return logits, speculative_logits
+
+        # Copy inputs to the static inputs of the cuda graph
+        # Static inputs are potentially padded
+        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
+        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
+        if ATTENTION == "flashinfer":
+            block_tables = block_tables_to_ragged(
+                block_tables=block_tables,
+                input_lengths=batch.input_lengths,
+                cache_lengths=batch.cache_lengths,
+                input_lengths_tensor=batch.input_lengths_tensor,
+                cache_lengths_tensor=batch.cache_lengths_tensor,
+                max_current_length=batch.max_current_length,
+            )
+            cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
+            raise RuntimeError("Flashinfer for LLama4 is not supported yet")
+        else:
+            cuda_graph["block_tables"][
+                : block_tables.shape[0], : block_tables.shape[1]
+            ] = block_tables
+
+        cuda_graph["block_tables_local"][
+            : block_tables_local.shape[0], : block_tables_local.shape[1]
+        ] = block_tables_local
+
+        # XXX: This is working only because block 0 is reserved for the healthcheck
+        # so it doesn't matter if we override it with bogus values.
+        cuda_graph["slots"].fill_(0)
+        cuda_graph["slots"][: slots.shape[0]] = slots
+        cuda_graph["input_lengths"].zero_()
+        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
+        cuda_graph["cache_lengths"].zero_()
+        cuda_graph["cache_lengths"][
+            : cache_lengths_tensor.shape[0]
+        ] = cache_lengths_tensor
+        cuda_graph["input_lengths_local"].zero_()
+        cuda_graph["input_lengths_local"][
+            : input_lengths_tensor_local.shape[0]
+        ] = input_lengths_tensor_local
+        cuda_graph["cache_lengths_local"].zero_()
+        cuda_graph["cache_lengths_local"][
+            : cache_lengths_tensor_local.shape[0]
+        ] = cache_lengths_tensor_local
+        cuda_graph["seqlens_q_local"].zero_()
+        cuda_graph["seqlens_q_local"][: seqlens_q_local.shape[0]] = seqlens_q_local
+
+        with self._forward_context(
+            block_tables=cuda_graph["block_tables"],
+            cu_seqlen_prefill=None,
+            input_lengths_tensor=cuda_graph["input_lengths"],
+            cache_lengths_tensor=cuda_graph["cache_lengths"],
+            state=cuda_graph["state"],
+        ):
+            # Replay the graph
+            cuda_graph["graph"].replay()
+
+        # Slice output to the correct shape
+        speculative_logits = (
+            cuda_graph["speculative_logits"][:bs]
+            if cuda_graph["speculative_logits"] is not None
+            else None
+        )
+        logits = cuda_graph["logits"][:bs]
+        return logits, speculative_logits
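
Note: the speculative branch at the top of forward() flattens a batch of B sequences into B * (speculative_length + 1) rows before calling the model. A self-contained sketch of that expansion with toy values:

    import torch

    B, speculative_length = 2, 3
    input_ids = torch.tensor([10, 20])  # last verified token per sequence
    speculative_ids = torch.tensor([[11, 12, 13], [21, 22, 23]])
    position_ids = torch.tensor([5, 9])

    new_length = speculative_length + 1
    # Each sequence's token is concatenated with its speculative ids ...
    new_input_ids = torch.cat(
        [input_ids.unsqueeze(-1), speculative_ids], dim=1
    ).reshape(-1)
    # ... and positions advance by an arange over the speculated window.
    arange = torch.arange(new_length).unsqueeze(0)
    new_position_ids = (
        position_ids.unsqueeze(-1).expand(B, new_length) + arange
    ).view(-1)

    print(new_input_ids)     # tensor([10, 11, 12, 13, 20, 21, 22, 23])
    print(new_position_ids)  # tensor([ 5,  6,  7,  8,  9, 10, 11, 12])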