diff --git a/server/text_generation_server/models/transformers_flash_vlm.py b/server/text_generation_server/models/transformers_flash_vlm.py index 461bbc2d..b20eae62 100644 --- a/server/text_generation_server/models/transformers_flash_vlm.py +++ b/server/text_generation_server/models/transformers_flash_vlm.py @@ -1,5 +1,5 @@ import math -from typing import List, Optional +from typing import List, Optional, Tuple, Dict import torch from opentelemetry import trace @@ -12,7 +12,8 @@ from text_generation_server.utils import initialize_torch_distributed from text_generation_server.layers.attention import paged_attention, attention, Seqlen from text_generation_server.layers.attention.kv_cache import KVScales, KVCache -from text_generation_server.models.globals import ATTENTION, BLOCK_SIZE +from text_generation_server.models.globals import ATTENTION, BLOCK_SIZE, MEM_POOL +from text_generation_server.models.metadata_kernels import block_tables_to_ragged import torch.nn.functional as F import numpy as np @@ -172,13 +173,13 @@ def tgi_flash_attention_forward( sliding_window: Optional[int] = None, softcap: Optional[float] = None, use_sdpa: Optional[bool] = False, - local_seqlen: Optional[Seqlen] = None, - local_block_tables: Optional[torch.Tensor] = None, + seqlen_local: Optional[Seqlen] = None, + block_tables_local: Optional[torch.Tensor] = None, **kwargs, # This is needed to "absorb" other args passed by Transformers modeling ): if hasattr(module, "use_rope") and module.use_rope: - seqlen = local_seqlen - block_tables = local_block_tables + seqlen = seqlen_local + block_tables = block_tables_local kv_cache = kv_cache[module.layer_idx] query_states = query_states.transpose(1, 2).squeeze(dim=0) @@ -493,7 +494,10 @@ class TransformersFlashVlmCausalLM(VlmCausalLM): image_grid_thw: Optional[torch.LongTensor] = None, pixel_attention_mask=None, image_sizes: Optional[torch.LongTensor] = None, + seqlen_local: Optional[Seqlen] = None, + block_tables_local: Optional[torch.Tensor] = None, ): + # A value of `None` (i.e. no logit slicing) translates to `0` in Transformers logits_to_keep = lm_head_indices if lm_head_indices is not None else 0 @@ -505,13 +509,6 @@ class TransformersFlashVlmCausalLM(VlmCausalLM): block_tables=block_tables, ) - if cu_seqlen_prefill is not None: - from loguru import logger - - logger.info( - f"input_ids: {input_ids.shape}, position_ids:{inputs.get('local_seqlen', None)}" - ) - # This is equivalent to `self.model.forward`, see the monkey patch in __init__ logits = self.model.original_forward( input_ids=inputs["input_ids"], @@ -535,8 +532,8 @@ class TransformersFlashVlmCausalLM(VlmCausalLM): attention_mask=inputs.get("attention_mask", None), use_sdpa=inputs.get("use_sdpa", False), cache_position=inputs.get("cache_position", None), - local_seqlen=inputs.get("local_seqlen", None), - local_block_tables=inputs.get("local_block_tables", None), + seqlen_local=seqlen_local, + block_tables_local=block_tables_local, ).logits logits = self.post_process_outputs(logits, lm_head_indices) @@ -707,48 +704,481 @@ class TransformersGemma3VlmCausalLM(TransformersFlashVlmCausalLM): class TransformersLlama4VlmCausalLM(TransformersFlashVlmCausalLM): - def pre_process_inputs(self, **kwargs): - input_ids = kwargs["input_ids"] - position_ids = kwargs["position_ids"] - seqlen = kwargs["seqlen"] - block_tables = kwargs["block_tables"] + def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int): + max_bs = max(self.cuda_graphs.keys()) if self.cuda_graphs else None + input_lengths = [max_s] * bs + cache_lengths = [0] * bs + if max_bs is None: + input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device) + position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device) + config = getattr(self.model, "config", None) + rope_scaling = getattr(config, "rope_scaling", None) if config else None + if ( # mrope have position_ids per section, if so repeat n times + isinstance(rope_scaling, dict) and rope_scaling["rope_type"] == "mrope" + ): + n_sections = len(self.model.config.rope_scaling["mrope_section"]) + position_ids = position_ids.unsqueeze(1).repeat(1, n_sections) + slots = torch.arange(bs, dtype=torch.int64, device=self.device) + input_lengths_tensor = ( + torch.ones(bs, dtype=torch.int32, device=self.device) * max_s + ) + cache_lengths_tensor = torch.zeros( + bs, dtype=torch.int32, device=self.device + ) + block_tables = torch.arange( + max_bt, dtype=torch.int32, device=self.device + ).repeat(bs) + block_tables = block_tables.reshape((bs, max_bt)) + if ATTENTION == "flashinfer": + block_tables = block_tables_to_ragged( + block_tables=block_tables, + input_lengths=input_lengths, + cache_lengths=cache_lengths, + input_lengths_tensor=input_lengths_tensor, + cache_lengths_tensor=cache_lengths_tensor, + max_current_length=max_s, + ) - inputs = super().pre_process_inputs(**kwargs) - inputs["cache_position"] = position_ids - inputs["attention_mask"] = torch.zeros((1, 1, 1, 1), device=input_ids.device) - # from loguru import logger + cu_seqlen_q = torch.arange( + input_lengths_tensor.shape[0] + 1, + device=self.device, + dtype=torch.int32, + ) + + ( + input_lengths_tensor_local, + cache_lengths_tensor_local, + seqlens_q_local, + max_q, + max_k, + block_tables_local, + ) = self.get_chunked_attention_seqlen( + cu_seqlen_q, + input_lengths_tensor, + block_tables, + ) + self.max_k_local = max_k + else: + if bs > max_bs: + raise RuntimeError( + "Cuda graphs should be generated in decreasing order size to reduce VRAM usage" + ) + input_ids = self.cuda_graphs[max_bs]["input_ids"][:bs] + position_ids = self.cuda_graphs[max_bs]["position_ids"][:bs] + if ATTENTION == "flashinfer": + block_tables = self.cuda_graphs[max_bs]["block_tables"][: bs * max_bt] + else: + block_tables = self.cuda_graphs[max_bs]["block_tables"][:bs] + slots = self.cuda_graphs[max_bs]["slots"][:bs] + input_lengths_tensor = self.cuda_graphs[max_bs]["input_lengths"][:bs] + cache_lengths_tensor = self.cuda_graphs[max_bs]["cache_lengths"][:bs] + + input_lengths_tensor_local = self.cuda_graphs[max_bs][ + "input_lengths_local" + ][:bs] + cache_lengths_tensor_local = self.cuda_graphs[max_bs][ + "cache_lengths_local" + ][:bs] + seqlens_q_local = self.cuda_graphs[max_bs]["seqlens_q_local"][:bs] + + if ATTENTION == "flashinfer": + from text_generation_server.layers.attention.flashinfer import ( + create_decode_state_cuda_graphs, + ) + + block_tables_ptr = torch.zeros( + bs + 1, dtype=torch.int32, device=self.device + ) + last_page_len = torch.ones(bs, dtype=torch.int32, device=self.device) + state = create_decode_state_cuda_graphs( + device=input_ids.device, + block_tables=block_tables, + block_tables_ptr=block_tables_ptr, + last_page_len=last_page_len, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + ) + else: + state = None + + graph = torch.cuda.CUDAGraph() + self.cuda_graphs[bs] = { + "input_ids": input_ids, + "position_ids": position_ids, + "kv_cache": self.kv_cache, + "block_tables": block_tables, + "slots": slots, + "input_lengths": input_lengths_tensor, + "cache_lengths": cache_lengths_tensor, + "input_lengths_local": input_lengths_tensor_local, + "cache_lengths_local": cache_lengths_tensor_local, + "seqlens_q_local": seqlens_q_local, + "block_tables_local": block_tables_local, + "state": state, + "graph": graph, + } + + torch.cuda.synchronize() + # Run once outside to warmup + with self._forward_context( + block_tables=block_tables, + cu_seqlen_prefill=None, + input_lengths_tensor=input_lengths_tensor, + state=state, + cache_lengths_tensor=cache_lengths_tensor, + ): + seqlen = Seqlen( + input_lengths=input_lengths_tensor, + cache_lengths=cache_lengths_tensor, + cu_seqlen_q=None, + max_q=1, + max_k=max_s, + ) + # cu_seqlens_q_local = F.pad( + # torch.cumsum(seqlens_q_local, dim=0), (1, 0), value=0 + # ).to(torch.int32) + seqlen_local = Seqlen( + input_lengths=input_lengths_tensor_local, + cache_lengths=cache_lengths_tensor_local, + cu_seqlen_q=None, + max_q=1, + max_k=input_lengths_tensor_local.max(), + ) + self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=None, + kv_cache=self.kv_cache, + block_tables=block_tables, + slots=slots, + seqlen=seqlen, + max_s=max_s, + prefill_cache_indices=None, + lm_head_indices=None, + seqlen_local=seqlen_local, + block_tables_local=block_tables_local, + ) + del seqlen + del seqlen_local + + torch.cuda.synchronize() + + with torch.cuda.graph(graph, pool=MEM_POOL): + seqlen = Seqlen( + input_lengths=input_lengths_tensor, + cache_lengths=cache_lengths_tensor, + cu_seqlen_q=None, + max_q=1, + max_k=max_s, + ) + # cu_seqlens_q_local = F.pad( + # torch.cumsum(seqlens_q_local, dim=0), (1, 0), value=0 + # ).to(torch.int32) + seqlen_local = Seqlen( + input_lengths=input_lengths_tensor_local, + cache_lengths=cache_lengths_tensor_local, + cu_seqlen_q=None, + max_q=1, + max_k=input_lengths_tensor_local.max(), + ) + logits, speculative_logits = self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=None, + kv_cache=self.kv_cache, + block_tables=block_tables, + slots=slots, + seqlen=seqlen, + max_s=max_s, + prefill_cache_indices=None, + lm_head_indices=None, + seqlen_local=seqlen_local, + block_tables_local=block_tables_local, + ) + self.cuda_graphs[bs]["logits"] = logits + self.cuda_graphs[bs]["speculative_logits"] = speculative_logits + torch.cuda.synchronize() + + def get_chunked_attention_seqlen( + self, + cu_seqlen_q, + seq_lens_np, + block_tables, + ): + attention_chunk_size = self.model.config.text_config.attention_chunk_size + # seq_lens_np = cu_seqlen_k[1:] - cu_seqlen_k[:-1] - # logger.info(f"input_ids: {input_ids.shape}, position_ids: {position_ids.shape}") - cu_seqlen_k = seqlen.cu_seqlen_k - cu_seqlen_q = seqlen.cu_seqlen_q - seq_lens_np = cu_seqlen_k[1:] - cu_seqlen_k[:-1] ( seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, virt_block_table, ) = make_local_attention_virtual_batches( - self.model.config.text_config.attention_chunk_size, - cu_seqlen_q.cpu().numpy(), + attention_chunk_size, + ( + cu_seqlen_q.cpu().numpy() + if isinstance(cu_seqlen_q, torch.Tensor) + else cu_seqlen_q + ), seq_lens_np.cpu().numpy(), block_tables, BLOCK_SIZE, ) - local_seqlen = Seqlen( - input_lengths=torch.from_numpy(virt_k_seqlens_np).to( - input_ids.device, non_blocking=True - ), - cache_lengths=torch.zeros(virt_k_seqlens_np.shape).to( - input_ids.device, non_blocking=True - ), - cu_seqlen_q=torch.from_numpy(virt_q_cu_seqlens_np).to( - input_ids.device, non_blocking=True - ), - max_q=int(seqlens_q_local_np.max()), - max_k=int(virt_k_seqlens_np.max()), + + input_lengths = torch.from_numpy(virt_k_seqlens_np).to( + cu_seqlen_q.device, non_blocking=True + ) + cache_lengths = torch.zeros(virt_k_seqlens_np.shape).to( + cu_seqlen_q.device, non_blocking=True + ) + seqlens_q_local = torch.from_numpy(seqlens_q_local_np).to( + cu_seqlen_q.device, non_blocking=True ) - inputs["local_seqlen"] = local_seqlen - inputs["local_block_tables"] = virt_block_table + max_q = int(seqlens_q_local_np.max()) + max_k = int(virt_k_seqlens_np.max()) + + return ( + input_lengths, + cache_lengths, + seqlens_q_local, + max_q, + max_k, + virt_block_table, + ) + + def pre_process_inputs(self, **kwargs): + input_ids = kwargs["input_ids"] + position_ids = kwargs["position_ids"] + + inputs = super().pre_process_inputs(**kwargs) + inputs["cache_position"] = position_ids + inputs["attention_mask"] = torch.zeros((1, 1, 1, 1), device=input_ids.device) return inputs + + def forward( + self, + batch: VlmCausalLMBatch, + adapter_data: Optional[Dict[str, torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + # Model Forward + if batch.speculative_ids is not None: + input_ids = batch.input_ids + position_ids = batch.position_ids + cu_seqlen_prefill = batch.cu_seqlen_prefill + kv_cache = self.kv_cache + block_tables = batch.block_tables_tensor + slots = batch.slots[batch.slot_indices] + input_lengths = batch.input_lengths_tensor + max_s = batch.max_current_length + lm_head_indices = batch.prefill_head_indices + + speculative_ids = batch.speculative_ids + + B, speculative_length = speculative_ids.shape + new_length = speculative_length + 1 + new_input_ids = torch.cat( + [input_ids.unsqueeze(-1), speculative_ids], dim=1 + ).reshape(-1) + arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0) + arange_int = arange.to(dtype=torch.int32) + new_position_ids = ( + position_ids.unsqueeze(-1).expand(B, new_length) + arange + ).view(-1) + slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1) + input_lengths = ( + input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int + ).view(-1) + cache_lengths_tensor = ( + batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length) + ).reshape(-1) + + # Add Copy the block tables for all members + block_tables = ( + block_tables.unsqueeze(1) + .expand(B, new_length, -1) + .reshape(B * new_length, -1) + .contiguous() + ) + max_s = max_s + speculative_length + + input_ids = new_input_ids + position_ids = new_position_ids + else: + input_ids = batch.input_ids + position_ids = batch.position_ids + cu_seqlen_prefill = batch.cu_seqlen_prefill + kv_cache = self.kv_cache + block_tables = batch.block_tables_tensor + slots = batch.slots[batch.slot_indices] + input_lengths = batch.input_lengths_tensor + cache_lengths_tensor = batch.cache_lengths_tensor + max_s = batch.max_current_length + lm_head_indices = batch.prefill_head_indices + + if self.model.config.model_type in {"qwen2_vl", "qwen2_5_vl"}: + if position_ids.dim() == 1 and batch.prefilling: + position_ids = self.model.get_position_ids( + input_ids, batch.image_grid_thw + ) + batch.position_ids = position_ids + + # Try to find an associated cuda graph + bs = input_ids.shape[0] + sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs]) + if sorted_padded_bs: + # Get associated cuda graph + cuda_graph = self.cuda_graphs[sorted_padded_bs[0]] + else: + cuda_graph = None + + cu_seqlen_q = ( + cu_seqlen_prefill + if cu_seqlen_prefill is not None + else torch.arange( + input_lengths.shape[0] + 1, dtype=torch.int32, device=input_ids.device + ) + ) + ( + input_lengths_tensor_local, + cache_lengths_tensor_local, + seqlens_q_local, + max_q, + max_k, + block_tables_local, + ) = self.get_chunked_attention_seqlen( + cu_seqlen_q=cu_seqlen_q, + seq_lens_np=input_lengths + cache_lengths_tensor, + block_tables=block_tables, + ) + + if cu_seqlen_prefill is not None or cuda_graph is None: + if ATTENTION == "flashinfer": + block_tables = block_tables_to_ragged( + block_tables=block_tables, + input_lengths=batch.input_lengths, + cache_lengths=batch.cache_lengths, + input_lengths_tensor=batch.input_lengths_tensor, + cache_lengths_tensor=batch.cache_lengths_tensor, + max_current_length=batch.max_current_length, + ) + raise RuntimeError("Flashinfer for LLama4 is not supported yet") + with self._forward_context( + block_tables=block_tables, + cu_seqlen_prefill=cu_seqlen_prefill, + input_lengths_tensor=input_lengths, + cache_lengths_tensor=cache_lengths_tensor, + ): + seqlen = Seqlen( + input_lengths=input_lengths, + cache_lengths=cache_lengths_tensor, + cu_seqlen_q=cu_seqlen_prefill, + max_q=batch.max_input_length, + max_k=batch.max_current_length, + ) + + cu_seqlens_q_local = F.pad( + torch.cumsum(seqlens_q_local, dim=0), (1, 0), value=0 + ).to(torch.int32) + seqlen_local = Seqlen( + input_lengths=input_lengths_tensor_local, + cache_lengths=cache_lengths_tensor_local, + cu_seqlen_q=cu_seqlens_q_local, + max_q=max_q, + max_k=max_k, + ) + + logits, speculative_logits = self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + block_tables=block_tables, + slots=slots, + seqlen=seqlen, + max_s=max_s, + prefill_cache_indices=batch.prefill_cache_indices, + lm_head_indices=lm_head_indices, + pixel_values=batch.pixel_values, + pixel_attention_mask=batch.pixel_attention_mask, + image_sizes=batch.image_sizes, + image_grid_thw=batch.image_grid_thw, + seqlen_local=seqlen_local, + block_tables_local=block_tables_local, + ) + if batch.prefill_cache_indices is not None: + batch.prefill_cache_indices = None + if batch.pixel_values is not None: + batch.pixel_values = None + if batch.pixel_attention_mask is not None: + batch.pixel_attention_mask = None + if batch.image_sizes is not None: + batch.image_sizes = None + if batch.image_grid_thw is not None: + batch.image_grid_thw = None + return logits, speculative_logits + + # Copy inputs to the static inputs of the cuda graph + # Static inputs are potentially padded + cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids + cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids + if ATTENTION == "flashinfer": + block_tables = block_tables_to_ragged( + block_tables=block_tables, + input_lengths=batch.input_lengths, + cache_lengths=batch.cache_lengths, + input_lengths_tensor=batch.input_lengths_tensor, + cache_lengths_tensor=batch.cache_lengths_tensor, + max_current_length=batch.max_current_length, + ) + cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables + raise RuntimeError("Flashinfer for LLama4 is not supported yet") + else: + cuda_graph["block_tables"][ + : block_tables.shape[0], : block_tables.shape[1] + ] = block_tables + + cuda_graph["block_tables_local"][ + : block_tables_local.shape[0], : block_tables_local.shape[1] + ] = block_tables_local + + # XXX: This is working only because block 0 is reserved for the healthcheck + # so it doesn't matter if we override it with bogus values. + cuda_graph["slots"].fill_(0) + cuda_graph["slots"][: slots.shape[0]] = slots + cuda_graph["input_lengths"].zero_() + cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths + cuda_graph["cache_lengths"].zero_() + cuda_graph["cache_lengths"][ + : cache_lengths_tensor.shape[0] + ] = cache_lengths_tensor + cuda_graph["input_lengths_local"].zero_() + cuda_graph["input_lengths_local"][ + : input_lengths_tensor_local.shape[0] + ] = input_lengths_tensor_local + cuda_graph["cache_lengths_local"].zero_() + cuda_graph["cache_lengths_local"][ + : cache_lengths_tensor_local.shape[0] + ] = cache_lengths_tensor_local + cuda_graph["seqlens_q_local"].zero_() + cuda_graph["seqlens_q_local"][: seqlens_q_local.shape[0]] = seqlens_q_local + + with self._forward_context( + block_tables=cuda_graph["block_tables"], + cu_seqlen_prefill=None, + input_lengths_tensor=cuda_graph["input_lengths"], + cache_lengths_tensor=cuda_graph["cache_lengths"], + state=cuda_graph["state"], + ): + # Replay the graph + cuda_graph["graph"].replay() + + # Slice output to the correct shape + speculative_logits = ( + cuda_graph["speculative_logits"][:bs] + if cuda_graph["speculative_logits"] is not None + else None + ) + logits = cuda_graph["logits"][:bs] + return logits, speculative_logits