From e5618d6e40b7bca3b0c89ef22b2335df22863558 Mon Sep 17 00:00:00 2001
From: Pedro Cuenca
Date: Wed, 9 Apr 2025 16:36:06 +0000
Subject: [PATCH] add chunked attn support

---
 .../models/transformers_flash_vlm.py          | 189 +++++++++++++++++-
 1 file changed, 183 insertions(+), 6 deletions(-)

diff --git a/server/text_generation_server/models/transformers_flash_vlm.py b/server/text_generation_server/models/transformers_flash_vlm.py
index a7beb68b3..dae0e3346 100644
--- a/server/text_generation_server/models/transformers_flash_vlm.py
+++ b/server/text_generation_server/models/transformers_flash_vlm.py
@@ -12,8 +12,9 @@ from text_generation_server.utils import initialize_torch_distributed
 from text_generation_server.layers.attention import paged_attention, attention, Seqlen
 from text_generation_server.layers.attention.kv_cache import KVScales, KVCache
-from text_generation_server.models.globals import ATTENTION
+from text_generation_server.models.globals import ATTENTION, BLOCK_SIZE
 import torch.nn.functional as F
+import numpy as np
 
 tracer = trace.get_tracer(__name__)
 
@@ -27,6 +28,124 @@ REPLICATED_ATTENTION_MODELS = [
 ]
 
 
+def cdiv(a: int, b: int) -> int:
+    """Ceiling division."""
+    return -(a // -b)
+
+
+# Adapted from: https://github.com/vllm-project/vllm/blob/e1a2c699dda82199e88e433c144eae66f3b31878/vllm/v1/attention/backends/flash_attn.py
+def make_local_attention_virtual_batches(
+    attn_chunk_size: int,
+    query_start_loc_np: np.ndarray,
+    seq_lens_np: np.ndarray,
+    block_table: torch.Tensor,
+    page_size: int = 0,
+) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.Tensor]:
+    q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
+    actual_batch_size = seq_lens_np.shape[0]
+    # Handle the case where we start in the middle of a local attention block:
+    # we assume q_seqlens > 0 (for all elements). For each batch idx we compute
+    # the number of tokens that are not in the first local attention block and
+    # then we can simply use a cdiv for the rest.
+    # For example if we have:
+    #   attn_chunk_size = 4
+    #   q_seqlens = [4, 10, 5]
+    #   k_seqlens = [6, 17, 9]
+    # Then we would get:
+    #   q_tokens_in_first_block = [2, 1, 4]
+    #   local_blocks = [2, 4, 2]
+    q_tokens_in_first_block = np.minimum(
+        attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size), q_seqlens
+    ).astype(np.int32)
+    tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size)
+    local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block, attn_chunk_size)
+
+    # Once we know the number of local blocks we can compute the request spans
+    # for each batch idx and figure out the number of "virtual" requests we
+    # have to make.
+    # For the above example we would get:
+    #   seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1]
+    #
+    # First get the batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1])
+    #   (TODO: make a utility to share this code with _prepare_inputs)
+    # arange step 1. [2, 4, 2] -> [2, 6, 8]
+    cu_num_blocks = np.cumsum(local_blocks)
+    virtual_batches = cu_num_blocks[-1]
+    # arange step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6]
+    block_offsets = np.repeat(cu_num_blocks - local_blocks, local_blocks)
+    # arange step 3. [0, 1, 0, 1, 2, 3, 0, 1]
+    arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets
+    # also compute reverse arange (i.e. [1, 0, 3, 2, 1, 0, 1, 0])
+    rarange = np.repeat(local_blocks, local_blocks) - arange - 1
+    # Then we can compute the seqlens_q_local, handling the fact that the
+    # first and last blocks could be partial
+    seqlens_q_local = np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks)
+    # set the first block since this may be a partial block
+    seqlens_q_local[arange == 0] = q_tokens_in_first_block
+    # set the remaining blocks
+    seqlens_q_local[arange > 0] = np.minimum(
+        seqlens_q_local - attn_chunk_size * (arange - 1), attn_chunk_size
+    )[arange > 0]
+
+    # convert from q_seqlens to cu_seqlens_q
+    cu_seqlens_q_local = np.pad(np.cumsum(seqlens_q_local), (1, 0)).astype(np.int32)
+
+    # compute the seqlens_k_local,
+    # basically a full local attention block for all but the last block in each
+    # batch
+    # For our example this will be:
+    #   seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1]
+    seqlens_k_local = np.full(cu_num_blocks[-1], attn_chunk_size, dtype=np.int32)
+    seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block
+
+    k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - (
+        rarange * attn_chunk_size + np.repeat(tokens_in_last_block, local_blocks)
+    )
+
+    # For the example the local attention blocks start at:
+    #                           _b0_  _____b1_____  _b2_
+    #   k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8]
+    block_starts = k_seqstarts_absolute // page_size
+    assert attn_chunk_size % page_size == 0, (
+        f"attn_chunk_size {attn_chunk_size} is not "
+        f"divisible by page_size {page_size}"
+    )
+    pages_per_local_batch = attn_chunk_size // page_size
+
+    # Create a block_table for the local attention blocks
+    # For our example, if we have a block-table like (assuming page_size=2):
+    #   block_table = [
+    #     [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9], < batch 0
+    #     [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], < batch 1
+    #     [20, 21, 22, 23, 24, 25, 26, 27, 28, 29], < batch 2
+    #   ]
+    # Then for the local batches we would want a block-table like
+    #   block_table_local = [
+    #     [ 0,  1 ], < local-batch 0, (batch 0, starting from k[0])
+    #     [ 2,  3 ], < local-batch 1, (batch 0, starting from k[4])
+    #     [12, 13 ], < local-batch 2, (batch 1, starting from k[4])
+    #     [14, 15 ], < local-batch 3, (batch 1, starting from k[8])
+    #     [16, 17 ], < local-batch 4, (batch 1, starting from k[12])
+    #     [18, 19 ], < local-batch 5, (batch 1, starting from k[16])
+    #     [22, 23 ], < local-batch 6, (batch 2, starting from k[4])
+    #     [24, 25 ], < local-batch 7, (batch 2, starting from k[8])
+    #   ]
+    block_indices = np.broadcast_to(
+        np.arange(pages_per_local_batch, dtype=np.int32),
+        (virtual_batches, pages_per_local_batch),
+    ) + np.expand_dims(block_starts, axis=1)
+    block_indices = block_indices.flatten().clip(max=block_table.shape[1] - 1)
+    batch_indices = np.repeat(
+        np.arange(actual_batch_size, dtype=np.int32),
+        local_blocks * pages_per_local_batch,
+    )
+    block_table_local = block_table[batch_indices, block_indices].view(
+        virtual_batches, -1
+    )
+
+    return seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, block_table_local
+
+
 # # Qwen2VL
 # transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_VISION_ATTENTION_CLASSES[
 #     "tgi"
@@ -51,8 +170,13 @@ def tgi_flash_attention_forward(
     sliding_window: Optional[int] = None,
     softcap: Optional[float] = None,
     use_sdpa: Optional[bool] = False,
+    local_seqlen: Optional[Seqlen] = None,
+    local_block_tables: Optional[torch.Tensor] = None,
     **kwargs,  # This is needed to "absorb" other args passed by Transformers modeling
 ):
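+    # Chunked-attention layers (module.use_rope, as in Llama 4) only attend within
+    # a local chunk, so use the virtual-batch seqlens and block tables prepared in
+    # pre_process_inputs instead of the global ones.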
+    if module.use_rope:
+        seqlen = local_seqlen
+        block_tables = local_block_tables
     kv_cache = kv_cache[module.layer_idx]
     query_states = query_states.transpose(1, 2).squeeze(dim=0)
     key_states = key_states.transpose(1, 2).squeeze(dim=0)
@@ -313,7 +437,9 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
     def get_position_ids(self, input_ids, image_grid_thw, position_ids):
         return position_ids
 
-    def pre_process_inputs(self, input_ids, position_ids, cu_seqlen_prefill):
+    def pre_process_inputs(self, **kwargs):
+        input_ids = kwargs["input_ids"]
+        position_ids = kwargs["position_ids"]
         return {
             "input_ids": input_ids.unsqueeze(0),
             "position_ids": position_ids.unsqueeze(0),
@@ -372,6 +498,8 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
             input_ids=input_ids,
             position_ids=position_ids,
             cu_seqlen_prefill=cu_seqlen_prefill,
+            seqlen=seqlen,
+            block_tables=block_tables,
         )
         # This is equivalent to `self.model.forward`, see the monkey patch in __init__
         logits = self.model.original_forward(
@@ -396,6 +524,8 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
             attention_mask=inputs.get("attention_mask", None),
             use_sdpa=inputs.get("use_sdpa", False),
             cache_position=inputs.get("cache_position", None),
+            local_seqlen=inputs.get("local_seqlen", None),
+            local_block_tables=inputs.get("local_block_tables", None),
         ).logits
 
         logits = self.post_process_outputs(logits, lm_head_indices)
@@ -480,7 +610,10 @@ class TransformersQwen2VlmCausalLM(TransformersFlashVlmCausalLM):
     def post_process_outputs(self, logits, lm_head_indices):
         return logits.squeeze(dim=0)[lm_head_indices].unsqueeze(0)
 
-    def pre_process_inputs(self, input_ids, position_ids, cu_seqlen_prefill):
+    def pre_process_inputs(self, **kwargs):
+        input_ids = kwargs["input_ids"]
+        position_ids = kwargs["position_ids"]
+
         input_ids = input_ids.unsqueeze(0)
         position_ids = position_ids.transpose(0, 1).unsqueeze(1)
         return {"input_ids": input_ids, "position_ids": position_ids}
@@ -542,7 +675,11 @@ class TransformersGemma3VlmCausalLM(TransformersFlashVlmCausalLM):
 
         return final_attention_mask
 
-    def pre_process_inputs(self, input_ids, position_ids, cu_seqlen_prefill):
+    def pre_process_inputs(self, **kwargs):
+        input_ids = kwargs["input_ids"]
+        position_ids = kwargs["position_ids"]
+        cu_seqlen_prefill = kwargs["cu_seqlen_prefill"]
+
         inputs = {
             "input_ids": input_ids.unsqueeze(0),
             "position_ids": position_ids.unsqueeze(0),
@@ -559,8 +696,48 @@ class TransformersGemma3VlmCausalLM(TransformersFlashVlmCausalLM):
 
 
 class TransformersLlama4VlmCausalLM(TransformersFlashVlmCausalLM):
-    def pre_process_inputs(self, input_ids, position_ids, cu_seqlen_prefill):
-        inputs = super().pre_process_inputs(input_ids, position_ids, cu_seqlen_prefill)
+    def pre_process_inputs(self, **kwargs):
+        input_ids = kwargs["input_ids"]
+        position_ids = kwargs["position_ids"]
+        seqlen = kwargs["seqlen"]
+        block_tables = kwargs["block_tables"]
+
+        inputs = super().pre_process_inputs(**kwargs)
         inputs["cache_position"] = position_ids
         inputs["attention_mask"] = torch.zeros((1, 1, 1, 1), device=input_ids.device)
+        from loguru import logger
+
+        logger.info(f"input_ids: {input_ids.shape}, position_ids: {position_ids.shape}")
+        cu_seqlen_k = seqlen.cu_seqlen_k
+        cu_seqlen_q = seqlen.cu_seqlen_q
+        seq_lens_np = cu_seqlen_k[1:] - cu_seqlen_k[:-1]
+        (
+            seqlens_q_local_np,
+            virt_q_cu_seqlens_np,
+            virt_k_seqlens_np,
+            virt_block_table,
+        ) = make_local_attention_virtual_batches(
+            self.model.config.text_config.attention_chunk_size,
+            cu_seqlen_q.cpu().numpy(),
+            seq_lens_np.cpu().numpy(),
+            block_tables,
+            BLOCK_SIZE,
+        )
+        local_seqlen = Seqlen(
+            input_lengths=torch.from_numpy(virt_k_seqlens_np).to(
+                input_ids.device, non_blocking=True
+            ),
+            cache_lengths=torch.zeros(virt_k_seqlens_np.shape).to(
+                input_ids.device, non_blocking=True
+            ),
+            cu_seqlen_q=torch.from_numpy(virt_q_cu_seqlens_np).to(
+                input_ids.device, non_blocking=True
+            ),
+            max_q=int(seqlens_q_local_np.max()),
+            max_k=int(virt_k_seqlens_np.max()),
+        )
+
+        inputs["local_seqlen"] = local_seqlen
+        inputs["local_block_tables"] = virt_block_table
+
         return inputs
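
For reference, a minimal standalone numpy sketch (not part of the patch) that mirrors the first half of make_local_attention_virtual_batches and reproduces the worked example from its comments (attn_chunk_size=4, q_seqlens=[4, 10, 5], k_seqlens=[6, 17, 9]):

    import numpy as np

    attn_chunk_size = 4
    q_seqlens = np.array([4, 10, 5])  # new query tokens per request
    k_seqlens = np.array([6, 17, 9])  # total keys per request (cache + new)

    def cdiv(a, b):
        return -(a // -b)

    # Query tokens of each request that fall in its first (possibly partial) local block.
    q_tokens_in_first_block = np.minimum(
        attn_chunk_size - ((k_seqlens - q_seqlens) % attn_chunk_size), q_seqlens
    )
    # Keys in the last (possibly partial) local block of each request.
    tokens_in_last_block = attn_chunk_size + (k_seqlens % -attn_chunk_size)
    # Number of local attention blocks ("virtual batches") per request.
    local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block, attn_chunk_size)

    print(q_tokens_in_first_block)  # [2 1 4]
    print(tokens_in_last_block)     # [2 1 1]
    print(local_blocks)             # [2 4 2] -> 8 virtual batches in total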
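
And a matching sketch of the block-table remapping, using the toy block table from the comments and the page_size=2 assumed there (in the patch, page_size is BLOCK_SIZE):

    import numpy as np
    import torch

    attn_chunk_size, page_size = 4, 2
    pages_per_local_batch = attn_chunk_size // page_size  # 2 pages per virtual batch

    # Values taken from the worked example above.
    local_blocks = np.array([2, 4, 2])
    k_seqstarts_absolute = np.array([0, 4, 4, 8, 12, 16, 4, 8])
    virtual_batches = int(local_blocks.sum())  # 8

    # Toy block table: request i owns pages 10*i .. 10*i+9.
    block_table = torch.arange(30).reshape(3, 10)

    block_starts = k_seqstarts_absolute // page_size
    block_indices = np.broadcast_to(
        np.arange(pages_per_local_batch), (virtual_batches, pages_per_local_batch)
    ) + np.expand_dims(block_starts, axis=1)
    block_indices = block_indices.flatten().clip(max=block_table.shape[1] - 1)
    batch_indices = np.repeat(np.arange(3), local_blocks * pages_per_local_batch)

    block_table_local = block_table[
        torch.from_numpy(batch_indices), torch.from_numpy(block_indices)
    ].view(virtual_batches, -1)
    print(block_table_local)
    # [[ 0,  1], [ 2,  3], [12, 13], [14, 15], [16, 17], [18, 19], [22, 23], [24, 25]]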