From 33a7ec57e2ca1c8d11baa3faf5a121ff06c01398 Mon Sep 17 00:00:00 2001 From: Mohit Sharma Date: Thu, 10 Apr 2025 14:59:39 +0000 Subject: [PATCH 1/5] add fix --- .../models/transformers_flash_causal_lm.py | 543 ++++++++++++++++-- 1 file changed, 505 insertions(+), 38 deletions(-) diff --git a/server/text_generation_server/models/transformers_flash_causal_lm.py b/server/text_generation_server/models/transformers_flash_causal_lm.py index 77659dd0..461bbc2d 100644 --- a/server/text_generation_server/models/transformers_flash_causal_lm.py +++ b/server/text_generation_server/models/transformers_flash_causal_lm.py @@ -3,20 +3,157 @@ from typing import List, Optional import torch from opentelemetry import trace -from transformers import AutoTokenizer, AutoModelForCausalLM +from transformers import AutoTokenizer, AutoProcessor import transformers.modeling_utils from text_generation_server.models.flash_causal_lm import FlashCausalLM +from text_generation_server.models.vlm_causal_lm import VlmCausalLM, VlmCausalLMBatch from text_generation_server.utils import initialize_torch_distributed from text_generation_server.layers.attention import paged_attention, attention, Seqlen from text_generation_server.layers.attention.kv_cache import KVScales, KVCache -from text_generation_server.models.globals import ATTENTION - +from text_generation_server.models.globals import ATTENTION, BLOCK_SIZE +import torch.nn.functional as F +import numpy as np tracer = trace.get_tracer(__name__) +# The base TP plan of these models has replicated q/k/v. This means that each process will see the full states, +# hence we should not divide the number of heads by the world size. This is a known waste of VRAM (the cache +# will be fully replicated on each process) and GPU communication (additional all-gather operations), however due +# to internal constraints it was not (yet?) possible to circumvent +REPLICATED_ATTENTION_MODELS = [ + "olmo2", + "phi3", +] + +def cdiv(a: int, b: int) -> int: + """Ceiling division.""" + return -(a // -b) + + +# Adapted from: https://github.com/vllm-project/vllm/blob/e1a2c699dda82199e88e433c144eae66f3b31878/vllm/v1/attention/backends/flash_attn.py +def make_local_attention_virtual_batches( + attn_chunk_size: int, + query_start_loc_np: np.ndarray, + seq_lens_np: np.ndarray, + block_table: torch.Tensor, + page_size: int = 0, +) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.Tensor]: + q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1] + actual_batch_size = seq_lens_np.shape[0] + # Handle if we are starting in the middle of a local attention block, + # we assume q_seqlens > 0 (for all elements), for each batch idx we compute + # the number of tokens that are not in the first local attention block and + # then we can simply use a cdiv for the rest. + # For example if we have: + # attn_chunk_size = 4 + # q_seqlens = [4, 10, 5] + # k_seqlens = [6, 17, 9] + # Then we would get: + # new_tokens_in_first_block = [2, 1, 4] + # local_blocks = [2, 4, 2] + q_tokens_in_first_block = np.minimum( + attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size), q_seqlens + ).astype(np.int32) + tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size) + local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block, attn_chunk_size) + + # Once we know the number of local blocks we can compute the request spans + # for each batch idx, we can figure out the number of "virtual" requests we + # have to make, + # For the above example we would get: + # seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1] + # + # First Get batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1]) + # (TODO: max a utility to share this code with _prepare_inputs) + # arange step 1. [2, 4, 2] -> [2, 6, 8] + cu_num_blocks = np.cumsum(local_blocks) + virtual_batches = cu_num_blocks[-1] + # arange step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6] + block_offsets = np.repeat(cu_num_blocks - local_blocks, local_blocks) + # arange step 3. [0, 1, 0, 1, 2, 3, 0, 1] + arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets + # also compute reverse arange (i.e. [1, 0, 3, 2, 1, 0, 1, 0]) + rarange = np.repeat(local_blocks, local_blocks) - arange - 1 + # Then we can compute the seqlens_q_local, handling the fact that the + # first and last blocks could be partial + seqlens_q_local = np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks) + # set the first block since this may be a partial block + seqlens_q_local[arange == 0] = q_tokens_in_first_block + # set the remaining blocks + seqlens_q_local[arange > 0] = np.minimum( + seqlens_q_local - attn_chunk_size * (arange - 1), attn_chunk_size + )[arange > 0] + + # convert from q_seqlens to cu_seqlens_q + cu_seqlens_q_local = np.pad(np.cumsum(seqlens_q_local), (1, 0)).astype(np.int32) + + # compute the seqlens_k_local, + # basically a full local attention block for all but the last block in each + # batch + # For our example this will be: + # seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1] + seqlens_k_local = np.full(cu_num_blocks[-1], attn_chunk_size, dtype=np.int32) + seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block + + if ATTENTION == "flashdecoding": + k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - ( + rarange * attn_chunk_size + np.repeat(tokens_in_last_block, local_blocks) + ) + + # For the example the local attention blocks start at: + # _b0_ _____b1_____ _b2_ + # k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8] + block_starts = k_seqstarts_absolute // page_size + assert attn_chunk_size % page_size == 0, ( + f"attn_chunk_size {attn_chunk_size} is not " + f"divisible by page_size {page_size}" + ) + pages_per_local_batch = attn_chunk_size // page_size + + # Create a block_table for the local attention blocks + # For out example if we have a block-table like (assuming page_size=2): + # block_table = [ + # [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9], < batch 0 + # [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], < batch 1 + # [20, 21, 22, 23, 24, 25, 26, 27, 28, 29], < batch 2 + # ] + # Then for the local batches we would want a block-table like + # block_table_local = [ + # [ 0, 1 ], < local-batch 0, (batch 0, starting from k[0]) + # [ 2, 3 ], < local-batch 1, (batch 0, starting from k[4]) + # [ 12, 13 ], < local-batch 2, (batch 1, starting from k[4]) + # [ 14, 15 ], < local-batch 3, (batch 1, starting from k[8]) + # [ 16, 17 ], < local-batch 4, (batch 1, starting from k[12]) + # [ 18, 19 ], < local-batch 5, (batch 1, starting from k[16]) + # [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4]) + # [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8]) + # ] + block_indices = np.broadcast_to( + np.arange(pages_per_local_batch, dtype=np.int32), + (virtual_batches, pages_per_local_batch), + ) + np.expand_dims(block_starts, axis=1) + block_indices = block_indices.flatten().clip(max=block_table.shape[1] - 1) + batch_indices = np.repeat( + np.arange(actual_batch_size, dtype=np.int32), + local_blocks * pages_per_local_batch, + ) + block_table_local = block_table[batch_indices, block_indices].view( + virtual_batches, -1 + ) + else: + block_table_local = block_table + return seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, block_table_local + + +# # Qwen2VL +# transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_VISION_ATTENTION_CLASSES[ +# "tgi" +# ] = transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_VISION_ATTENTION_CLASSES[ +# "eager" +# ] def tgi_flash_attention_forward( module, query_states: torch.Tensor, @@ -34,8 +171,14 @@ def tgi_flash_attention_forward( softmax_scale: Optional[float] = None, sliding_window: Optional[int] = None, softcap: Optional[float] = None, + use_sdpa: Optional[bool] = False, + local_seqlen: Optional[Seqlen] = None, + local_block_tables: Optional[torch.Tensor] = None, **kwargs, # This is needed to "absorb" other args passed by Transformers modeling ): + if hasattr(module, "use_rope") and module.use_rope: + seqlen = local_seqlen + block_tables = local_block_tables kv_cache = kv_cache[module.layer_idx] query_states = query_states.transpose(1, 2).squeeze(dim=0) @@ -50,18 +193,60 @@ def tgi_flash_attention_forward( sliding_window = -1 if sliding_window is None else sliding_window if cu_seqlen_prefill is not None: - attn_output = attention( - query=query_states, - key=key_states, - value=value_states, - kv_cache=kv_cache, - kv_scales=kv_scales, - seqlen=seqlen, - block_tables=block_tables, - softmax_scale=softmax_scale, - window_size_left=sliding_window, - softcap=softcap, - ) + if not use_sdpa: + attn_output = attention( + query=query_states, + key=key_states, + value=value_states, + kv_cache=kv_cache, + kv_scales=kv_scales, + seqlen=seqlen, + block_tables=block_tables, + softmax_scale=softmax_scale, + window_size_left=sliding_window, + softcap=softcap, + ) + else: + lengths = cu_seqlen_prefill[1:] - cu_seqlen_prefill[:-1] + max_length = max(lengths) + attention_mask = attention_mask[:, :, :, :max_length] + enable_gqa = query_states.shape[1] != key_states.shape[1] + # Split tensors using vectorized split + query_list = torch.split(query_states, lengths.tolist(), dim=0) + key_list = torch.split(key_states, lengths.tolist(), dim=0) + value_list = torch.split(value_states, lengths.tolist(), dim=0) + + padded_query = torch.nn.utils.rnn.pad_sequence(query_list, batch_first=True) + padded_key = torch.nn.utils.rnn.pad_sequence(key_list, batch_first=True) + padded_value = torch.nn.utils.rnn.pad_sequence(value_list, batch_first=True) + + padded_query = padded_query.transpose(1, 2).contiguous() + padded_key = padded_key.transpose(1, 2).contiguous() + padded_value = padded_value.transpose(1, 2).contiguous() + + # Compute attention + attn_output = F.scaled_dot_product_attention( + padded_query, + padded_key, + padded_value, + attn_mask=attention_mask, + scale=softmax_scale, + enable_gqa=enable_gqa, + ) + + attn_output = attn_output.transpose( + 1, 2 + ) # [batch_size, seq_len, num_heads, head_dim] + max_seq_len = padded_query.size(2) + seq_range = torch.arange(max_seq_len, device=padded_query.device).unsqueeze( + 0 + ) + lengths_tensor = torch.tensor( + lengths, device=padded_query.device + ).unsqueeze(1) + mask = seq_range < lengths_tensor # [batch, max_seq_len] + attn_output = attn_output[mask] # [total_seq_len, num_heads, head_dim] + else: attn_output = paged_attention( query_states, @@ -83,20 +268,16 @@ def tgi_flash_attention_forward( transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS["tgi"] = tgi_flash_attention_forward -# The base TP plan of these models has replicated q/k/v. This means that each process will see the full states, -# hence we should not divide the number of heads by the world size. This is a known waste of VRAM (the cache -# will be fully replicated on each process) and GPU communication (additional all-gather operations), however due -# to internal constraints it was not (yet?) possible to circumvent -REPLICATED_ATTENTION_MODELS = [ - "olmo2", - "phi3", -] + +# TODO: implement +# tgi_cross_attention_forward -class TransformersFlashCausalLM(FlashCausalLM): +class TransformersFlashVlmCausalLM(VlmCausalLM): def __init__( self, model_id: str, + model_class, revision: Optional[str] = None, quantize: Optional[str] = None, speculator: Optional[str] = None, @@ -104,10 +285,15 @@ class TransformersFlashCausalLM(FlashCausalLM): default_dtype=torch.float16, trust_remote_code: bool = False, tokenizer_class=AutoTokenizer, + processor_class=AutoProcessor, + processor_kwargs=None, kv_cache_dtype: Optional[torch.dtype] = None, + batch_class=VlmCausalLMBatch, ): + self.batch_class = batch_class self.quantize = quantize self.process_group, rank, world_size = initialize_torch_distributed() + self.dtype = dtype if speculator: raise RuntimeError("Speculator decoding is not enabled for AutoModel") @@ -131,18 +317,40 @@ class TransformersFlashCausalLM(FlashCausalLM): trust_remote_code=trust_remote_code, ) - model = AutoModelForCausalLM.from_pretrained( + if processor_kwargs is None: + processor_kwargs = {} + + self.processor = processor_class.from_pretrained( + model_id, + revision=revision, + trust_remote_code=trust_remote_code, + **processor_kwargs, + ) + + attn_implementation = { + "text_config": "tgi", + "vision_config": "sdpa", + } + + model = model_class.from_pretrained( model_id, revision=revision, torch_dtype=dtype, load_in_8bit=quantize == "bitsandbytes", trust_remote_code=trust_remote_code, - attn_implementation="tgi", + attn_implementation=attn_implementation, device_map=device if world_size == 1 else None, tp_plan="auto" if world_size > 1 else None, ) torch.distributed.barrier(group=self.process_group) + self.config = model.config + config = model.config + + # VLM models define the config we care about in their text_config + text_config = getattr(model.config, "text_config", None) + if text_config is not None: + config = text_config if tokenizer.pad_token_id is None: if model.config.pad_token_id is not None: @@ -156,20 +364,18 @@ class TransformersFlashCausalLM(FlashCausalLM): else: tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - self.num_layers = model.config.num_hidden_layers - self.num_heads = model.config.num_attention_heads - self.num_kv_heads = model.config.num_key_value_heads + self.num_layers = config.num_hidden_layers + self.num_heads = config.num_attention_heads + self.num_kv_heads = config.num_key_value_heads # Some models use GQA and different sizes for o_proj # and q_proj, that allows for that. - if hasattr(model.config, "head_dim"): - self.head_size = model.config.head_dim + if hasattr(config, "head_dim"): + self.head_size = config.head_dim else: - self.head_size = ( - model.config.hidden_size // model.config.num_attention_heads - ) + self.head_size = config.hidden_size // config.num_attention_heads # Skip it for models in the exception list - if model.config.model_type not in REPLICATED_ATTENTION_MODELS: + if config.model_type not in REPLICATED_ATTENTION_MODELS: self.num_heads = self.num_heads // self.process_group.size() self.num_kv_heads = ( self.num_kv_heads // self.process_group.size() @@ -227,26 +433,47 @@ class TransformersFlashCausalLM(FlashCausalLM): # We first copy the original model.forward because we still need it in the monkey patch self.model.original_forward = self.model.forward self.model.forward = self._model_forward + self.model.get_position_ids = self.get_position_ids torch.distributed.barrier(group=self.process_group) + def get_position_ids(self, input_ids, image_grid_thw, position_ids): + return position_ids + + def pre_process_inputs(self, **kwargs): + input_ids = kwargs["input_ids"] + position_ids = kwargs["position_ids"] + return { + "input_ids": input_ids.unsqueeze(0), + "position_ids": position_ids.unsqueeze(0), + } + + def post_process_outputs(self, logits, lm_head_indices): + return logits.squeeze(dim=0) + @classmethod def fallback( cls, model_id: str, + model_class, revision: Optional[str] = None, quantize: Optional[str] = None, speculator: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, + batch_class: Optional[type] = VlmCausalLMBatch, + processor_kwargs: Optional[dict] = None, ): return cls( model_id=model_id, + model_class=model_class, revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + batch_class=batch_class, + processor_kwargs=processor_kwargs, ) def _model_forward( @@ -262,14 +489,33 @@ class TransformersFlashCausalLM(FlashCausalLM): lm_head_indices: Optional[torch.Tensor], prefill_cache_indices=None, # not used, but passed to match original signature adapter_data=None, # not supported, but passed to match original signature + pixel_values: torch.FloatTensor = None, + image_grid_thw: Optional[torch.LongTensor] = None, + pixel_attention_mask=None, + image_sizes: Optional[torch.LongTensor] = None, ): # A value of `None` (i.e. no logit slicing) translates to `0` in Transformers logits_to_keep = lm_head_indices if lm_head_indices is not None else 0 + inputs = self.pre_process_inputs( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + seqlen=seqlen, + block_tables=block_tables, + ) + + if cu_seqlen_prefill is not None: + from loguru import logger + + logger.info( + f"input_ids: {input_ids.shape}, position_ids:{inputs.get('local_seqlen', None)}" + ) + # This is equivalent to `self.model.forward`, see the monkey patch in __init__ logits = self.model.original_forward( - input_ids=input_ids.unsqueeze(0), # expand dim to fit Transformers - position_ids=position_ids.unsqueeze(0), # expand dim to fit Transformers + input_ids=inputs["input_ids"], + position_ids=inputs["position_ids"], past_key_values=None, # we use self.kv_cache instead of transformers cache object use_cache=False, # we use self.kv_cache instead of transformers cache object logits_to_keep=logits_to_keep, @@ -282,6 +528,227 @@ class TransformersFlashCausalLM(FlashCausalLM): max_s=max_s, kv_head_mapping=self.kv_head_mapping, kv_scales=self.kv_scales, - ).logits.squeeze(dim=0) + pixel_values=pixel_values, + pixel_attention_mask=pixel_attention_mask, + image_sizes=image_sizes, + image_grid_thw=image_grid_thw, + attention_mask=inputs.get("attention_mask", None), + use_sdpa=inputs.get("use_sdpa", False), + cache_position=inputs.get("cache_position", None), + local_seqlen=inputs.get("local_seqlen", None), + local_block_tables=inputs.get("local_block_tables", None), + ).logits + + logits = self.post_process_outputs(logits, lm_head_indices) return logits, None + + +class TransformersQwen2VlmCausalLM(TransformersFlashVlmCausalLM): + def get_position_ids(self, input_ids: torch.Tensor, image_grid_thw: torch.Tensor): + if image_grid_thw is None: + return ( + torch.arange(input_ids.shape[0], device=input_ids.device) + .unsqueeze(1) + .repeat(1, 3) + ) + + spatial_merge_size = self.config.vision_config.spatial_merge_size + vision_start_token_id = self.config.vision_start_token_id + vision_end_token_id = self.config.vision_end_token_id + device = input_ids.device + dtype = input_ids.dtype + input_ids_len = input_ids.shape[0] + + vision_starts = torch.where(input_ids == vision_start_token_id)[0] + vision_ends = torch.where(input_ids == vision_end_token_id)[0] + vision_segments = torch.stack((vision_starts, vision_ends), dim=1) + prev_vision_end = torch.cat( + [torch.zeros(1, device=vision_ends.device, dtype=dtype), vision_ends[:-1]] + ) + text_lengths_between_vision = vision_segments[:, 0] - prev_vision_end + 1 + vision_widths_max = torch.cat( + [ + torch.zeros(1, device=image_grid_thw.device, dtype=dtype), + image_grid_thw[:-1, 2] // spatial_merge_size, + ] + ) + vision_segment_lengths = vision_widths_max + text_lengths_between_vision + vision_segment_lengths = vision_segment_lengths.cumsum(dim=0) + text_segment_lengths = vision_segment_lengths - text_lengths_between_vision + + # create position ids for each vision segment based on the image grid + llm_pos_ids_list = [] + for i, _ in enumerate(vision_segments): + t, h, w = ( + image_grid_thw[i][0], + image_grid_thw[i][1] // spatial_merge_size, + image_grid_thw[i][2] // spatial_merge_size, + ) + t_indices = torch.arange(t, device=device).repeat_interleave(h * w) + h_indices = torch.arange(h, device=device).repeat_interleave(w).repeat(t) + w_indices = torch.arange(w, device=device).repeat(t * h) + image_position_ids = torch.stack([t_indices, h_indices, w_indices], dim=0) + + # offset by the position of the last vision segment + im = image_position_ids + vision_segment_lengths[i] + llm_pos_ids_list.append(im) + + # create position ids for each text segment + text_ranges = [ + torch.arange(seq_len, device=device).view(1, -1).expand(3, -1) + + text_segment_lengths[i] + for i, seq_len in enumerate(text_lengths_between_vision) + ] + + full_llm_pos_ids_list = [ + item for sublist in zip(text_ranges, llm_pos_ids_list) for item in sublist + ] + # import ipdb + + # ipdb.set_trace() + max_s = full_llm_pos_ids_list[-1].max() + 1 + final_text_len = input_ids_len - vision_ends[-1] + if final_text_len > 0: + m = torch.arange(final_text_len, device=device).view(1, -1).expand(3, -1) + full_llm_pos_ids_list.append(m + max_s) + + position_ids = ( + torch.cat(full_llm_pos_ids_list, dim=1).reshape(3, -1).transpose(0, 1) + ) + return position_ids + + def post_process_outputs(self, logits, lm_head_indices): + return logits.squeeze(dim=0)[lm_head_indices].unsqueeze(0) + + def pre_process_inputs(self, **kwargs): + input_ids = kwargs["input_ids"] + position_ids = kwargs["position_ids"] + + input_ids = input_ids.unsqueeze(0) + position_ids = position_ids.transpose(0, 1).unsqueeze(1) + return {"input_ids": input_ids, "position_ids": position_ids} + + +class TransformersGemma3VlmCausalLM(TransformersFlashVlmCausalLM): + def get_attention_mask(self, input_ids, cu_seqlen_prefill): + device = input_ids.device + dtype = self.dtype + min_dtype = torch.finfo(dtype).min + + lengths = (cu_seqlen_prefill[1:] - cu_seqlen_prefill[:-1]).tolist() + batch_size = len(lengths) + + sequence_length = max(lengths) + target_length = sequence_length + # Create the padding mask from the computed lengths. + # pad_mask: [batch, sequence_length] where True indicates valid tokens. + seq_range = torch.arange(sequence_length, device=device).unsqueeze(0) + lengths_tensor = torch.tensor(lengths, device=device).unsqueeze(1) + pad_mask = seq_range < lengths_tensor # shape: [batch, sequence_length] + + # Build the base causal mask (for non-image tokens): + causal_mask = torch.tril( + torch.ones( + (sequence_length, sequence_length), dtype=torch.bool, device=device + ) + ) + base_mask = pad_mask.unsqueeze(2) & pad_mask.unsqueeze( + 1 + ) # [batch, sequence_length, sequence_length] + base_mask = base_mask & causal_mask.unsqueeze(0) # apply causal constraint + + image_token_mask = (input_ids == self.config.image_token_index).to( + input_ids.device + ) + + image_token_mask = torch.nn.utils.rnn.pad_sequence( + torch.split(image_token_mask, lengths), batch_first=True, padding_value=0 + ) + bidirectional_mask = image_token_mask.unsqueeze(2) & image_token_mask.unsqueeze( + 1 + ) + + # Combine the causal base mask and the bidirectional mask. + combined_mask = torch.logical_or( + base_mask.unsqueeze(1), bidirectional_mask.unsqueeze(1) + ).to(device) + # combined_mask now has shape [batch, 1, sequence_length, sequence_length] + + full_attention_mask = torch.zeros( + (batch_size, 1, sequence_length, target_length), + device=device, + dtype=torch.bool, + ) + full_attention_mask[:, :, :, :sequence_length] = combined_mask + + final_attention_mask = torch.where(full_attention_mask, 0, min_dtype).to(device) + + return final_attention_mask + + def pre_process_inputs(self, **kwargs): + input_ids = kwargs["input_ids"] + position_ids = kwargs["position_ids"] + cu_seqlen_prefill = kwargs["cu_seqlen_prefill"] + + inputs = { + "input_ids": input_ids.unsqueeze(0), + "position_ids": position_ids.unsqueeze(0), + } + + if cu_seqlen_prefill is not None: + attention_mask = self.get_attention_mask( + input_ids.squeeze(0), cu_seqlen_prefill + ) + inputs["attention_mask"] = attention_mask + inputs["use_sdpa"] = True + + return inputs + + +class TransformersLlama4VlmCausalLM(TransformersFlashVlmCausalLM): + def pre_process_inputs(self, **kwargs): + input_ids = kwargs["input_ids"] + position_ids = kwargs["position_ids"] + seqlen = kwargs["seqlen"] + block_tables = kwargs["block_tables"] + + inputs = super().pre_process_inputs(**kwargs) + inputs["cache_position"] = position_ids + inputs["attention_mask"] = torch.zeros((1, 1, 1, 1), device=input_ids.device) + # from loguru import logger + + # logger.info(f"input_ids: {input_ids.shape}, position_ids: {position_ids.shape}") + cu_seqlen_k = seqlen.cu_seqlen_k + cu_seqlen_q = seqlen.cu_seqlen_q + seq_lens_np = cu_seqlen_k[1:] - cu_seqlen_k[:-1] + ( + seqlens_q_local_np, + virt_q_cu_seqlens_np, + virt_k_seqlens_np, + virt_block_table, + ) = make_local_attention_virtual_batches( + self.model.config.text_config.attention_chunk_size, + cu_seqlen_q.cpu().numpy(), + seq_lens_np.cpu().numpy(), + block_tables, + BLOCK_SIZE, + ) + local_seqlen = Seqlen( + input_lengths=torch.from_numpy(virt_k_seqlens_np).to( + input_ids.device, non_blocking=True + ), + cache_lengths=torch.zeros(virt_k_seqlens_np.shape).to( + input_ids.device, non_blocking=True + ), + cu_seqlen_q=torch.from_numpy(virt_q_cu_seqlens_np).to( + input_ids.device, non_blocking=True + ), + max_q=int(seqlens_q_local_np.max()), + max_k=int(virt_k_seqlens_np.max()), + ) + + inputs["local_seqlen"] = local_seqlen + inputs["local_block_tables"] = virt_block_table + + return inputs From 3f343cdb6f7866f8d71ca1653f85ac8e6964cb33 Mon Sep 17 00:00:00 2001 From: Mohit Sharma Date: Thu, 10 Apr 2025 15:03:44 +0000 Subject: [PATCH 2/5] reverse flash causal change --- .../models/transformers_flash_causal_lm.py | 543 ++---------------- .../models/transformers_flash_vlm.py | 200 ++++++- 2 files changed, 232 insertions(+), 511 deletions(-) diff --git a/server/text_generation_server/models/transformers_flash_causal_lm.py b/server/text_generation_server/models/transformers_flash_causal_lm.py index 461bbc2d..77659dd0 100644 --- a/server/text_generation_server/models/transformers_flash_causal_lm.py +++ b/server/text_generation_server/models/transformers_flash_causal_lm.py @@ -3,157 +3,20 @@ from typing import List, Optional import torch from opentelemetry import trace -from transformers import AutoTokenizer, AutoProcessor +from transformers import AutoTokenizer, AutoModelForCausalLM import transformers.modeling_utils from text_generation_server.models.flash_causal_lm import FlashCausalLM -from text_generation_server.models.vlm_causal_lm import VlmCausalLM, VlmCausalLMBatch from text_generation_server.utils import initialize_torch_distributed from text_generation_server.layers.attention import paged_attention, attention, Seqlen from text_generation_server.layers.attention.kv_cache import KVScales, KVCache -from text_generation_server.models.globals import ATTENTION, BLOCK_SIZE -import torch.nn.functional as F -import numpy as np +from text_generation_server.models.globals import ATTENTION + tracer = trace.get_tracer(__name__) -# The base TP plan of these models has replicated q/k/v. This means that each process will see the full states, -# hence we should not divide the number of heads by the world size. This is a known waste of VRAM (the cache -# will be fully replicated on each process) and GPU communication (additional all-gather operations), however due -# to internal constraints it was not (yet?) possible to circumvent -REPLICATED_ATTENTION_MODELS = [ - "olmo2", - "phi3", -] - -def cdiv(a: int, b: int) -> int: - """Ceiling division.""" - return -(a // -b) - - -# Adapted from: https://github.com/vllm-project/vllm/blob/e1a2c699dda82199e88e433c144eae66f3b31878/vllm/v1/attention/backends/flash_attn.py -def make_local_attention_virtual_batches( - attn_chunk_size: int, - query_start_loc_np: np.ndarray, - seq_lens_np: np.ndarray, - block_table: torch.Tensor, - page_size: int = 0, -) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.Tensor]: - q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1] - actual_batch_size = seq_lens_np.shape[0] - # Handle if we are starting in the middle of a local attention block, - # we assume q_seqlens > 0 (for all elements), for each batch idx we compute - # the number of tokens that are not in the first local attention block and - # then we can simply use a cdiv for the rest. - # For example if we have: - # attn_chunk_size = 4 - # q_seqlens = [4, 10, 5] - # k_seqlens = [6, 17, 9] - # Then we would get: - # new_tokens_in_first_block = [2, 1, 4] - # local_blocks = [2, 4, 2] - q_tokens_in_first_block = np.minimum( - attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size), q_seqlens - ).astype(np.int32) - tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size) - local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block, attn_chunk_size) - - # Once we know the number of local blocks we can compute the request spans - # for each batch idx, we can figure out the number of "virtual" requests we - # have to make, - # For the above example we would get: - # seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1] - # - # First Get batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1]) - # (TODO: max a utility to share this code with _prepare_inputs) - # arange step 1. [2, 4, 2] -> [2, 6, 8] - cu_num_blocks = np.cumsum(local_blocks) - virtual_batches = cu_num_blocks[-1] - # arange step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6] - block_offsets = np.repeat(cu_num_blocks - local_blocks, local_blocks) - # arange step 3. [0, 1, 0, 1, 2, 3, 0, 1] - arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets - # also compute reverse arange (i.e. [1, 0, 3, 2, 1, 0, 1, 0]) - rarange = np.repeat(local_blocks, local_blocks) - arange - 1 - # Then we can compute the seqlens_q_local, handling the fact that the - # first and last blocks could be partial - seqlens_q_local = np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks) - # set the first block since this may be a partial block - seqlens_q_local[arange == 0] = q_tokens_in_first_block - # set the remaining blocks - seqlens_q_local[arange > 0] = np.minimum( - seqlens_q_local - attn_chunk_size * (arange - 1), attn_chunk_size - )[arange > 0] - - # convert from q_seqlens to cu_seqlens_q - cu_seqlens_q_local = np.pad(np.cumsum(seqlens_q_local), (1, 0)).astype(np.int32) - - # compute the seqlens_k_local, - # basically a full local attention block for all but the last block in each - # batch - # For our example this will be: - # seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1] - seqlens_k_local = np.full(cu_num_blocks[-1], attn_chunk_size, dtype=np.int32) - seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block - - if ATTENTION == "flashdecoding": - k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - ( - rarange * attn_chunk_size + np.repeat(tokens_in_last_block, local_blocks) - ) - - # For the example the local attention blocks start at: - # _b0_ _____b1_____ _b2_ - # k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8] - block_starts = k_seqstarts_absolute // page_size - assert attn_chunk_size % page_size == 0, ( - f"attn_chunk_size {attn_chunk_size} is not " - f"divisible by page_size {page_size}" - ) - pages_per_local_batch = attn_chunk_size // page_size - - # Create a block_table for the local attention blocks - # For out example if we have a block-table like (assuming page_size=2): - # block_table = [ - # [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9], < batch 0 - # [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], < batch 1 - # [20, 21, 22, 23, 24, 25, 26, 27, 28, 29], < batch 2 - # ] - # Then for the local batches we would want a block-table like - # block_table_local = [ - # [ 0, 1 ], < local-batch 0, (batch 0, starting from k[0]) - # [ 2, 3 ], < local-batch 1, (batch 0, starting from k[4]) - # [ 12, 13 ], < local-batch 2, (batch 1, starting from k[4]) - # [ 14, 15 ], < local-batch 3, (batch 1, starting from k[8]) - # [ 16, 17 ], < local-batch 4, (batch 1, starting from k[12]) - # [ 18, 19 ], < local-batch 5, (batch 1, starting from k[16]) - # [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4]) - # [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8]) - # ] - block_indices = np.broadcast_to( - np.arange(pages_per_local_batch, dtype=np.int32), - (virtual_batches, pages_per_local_batch), - ) + np.expand_dims(block_starts, axis=1) - block_indices = block_indices.flatten().clip(max=block_table.shape[1] - 1) - batch_indices = np.repeat( - np.arange(actual_batch_size, dtype=np.int32), - local_blocks * pages_per_local_batch, - ) - block_table_local = block_table[batch_indices, block_indices].view( - virtual_batches, -1 - ) - else: - block_table_local = block_table - return seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, block_table_local - - -# # Qwen2VL -# transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_VISION_ATTENTION_CLASSES[ -# "tgi" -# ] = transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_VISION_ATTENTION_CLASSES[ -# "eager" -# ] def tgi_flash_attention_forward( module, query_states: torch.Tensor, @@ -171,14 +34,8 @@ def tgi_flash_attention_forward( softmax_scale: Optional[float] = None, sliding_window: Optional[int] = None, softcap: Optional[float] = None, - use_sdpa: Optional[bool] = False, - local_seqlen: Optional[Seqlen] = None, - local_block_tables: Optional[torch.Tensor] = None, **kwargs, # This is needed to "absorb" other args passed by Transformers modeling ): - if hasattr(module, "use_rope") and module.use_rope: - seqlen = local_seqlen - block_tables = local_block_tables kv_cache = kv_cache[module.layer_idx] query_states = query_states.transpose(1, 2).squeeze(dim=0) @@ -193,60 +50,18 @@ def tgi_flash_attention_forward( sliding_window = -1 if sliding_window is None else sliding_window if cu_seqlen_prefill is not None: - if not use_sdpa: - attn_output = attention( - query=query_states, - key=key_states, - value=value_states, - kv_cache=kv_cache, - kv_scales=kv_scales, - seqlen=seqlen, - block_tables=block_tables, - softmax_scale=softmax_scale, - window_size_left=sliding_window, - softcap=softcap, - ) - else: - lengths = cu_seqlen_prefill[1:] - cu_seqlen_prefill[:-1] - max_length = max(lengths) - attention_mask = attention_mask[:, :, :, :max_length] - enable_gqa = query_states.shape[1] != key_states.shape[1] - # Split tensors using vectorized split - query_list = torch.split(query_states, lengths.tolist(), dim=0) - key_list = torch.split(key_states, lengths.tolist(), dim=0) - value_list = torch.split(value_states, lengths.tolist(), dim=0) - - padded_query = torch.nn.utils.rnn.pad_sequence(query_list, batch_first=True) - padded_key = torch.nn.utils.rnn.pad_sequence(key_list, batch_first=True) - padded_value = torch.nn.utils.rnn.pad_sequence(value_list, batch_first=True) - - padded_query = padded_query.transpose(1, 2).contiguous() - padded_key = padded_key.transpose(1, 2).contiguous() - padded_value = padded_value.transpose(1, 2).contiguous() - - # Compute attention - attn_output = F.scaled_dot_product_attention( - padded_query, - padded_key, - padded_value, - attn_mask=attention_mask, - scale=softmax_scale, - enable_gqa=enable_gqa, - ) - - attn_output = attn_output.transpose( - 1, 2 - ) # [batch_size, seq_len, num_heads, head_dim] - max_seq_len = padded_query.size(2) - seq_range = torch.arange(max_seq_len, device=padded_query.device).unsqueeze( - 0 - ) - lengths_tensor = torch.tensor( - lengths, device=padded_query.device - ).unsqueeze(1) - mask = seq_range < lengths_tensor # [batch, max_seq_len] - attn_output = attn_output[mask] # [total_seq_len, num_heads, head_dim] - + attn_output = attention( + query=query_states, + key=key_states, + value=value_states, + kv_cache=kv_cache, + kv_scales=kv_scales, + seqlen=seqlen, + block_tables=block_tables, + softmax_scale=softmax_scale, + window_size_left=sliding_window, + softcap=softcap, + ) else: attn_output = paged_attention( query_states, @@ -268,16 +83,20 @@ def tgi_flash_attention_forward( transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS["tgi"] = tgi_flash_attention_forward - -# TODO: implement -# tgi_cross_attention_forward +# The base TP plan of these models has replicated q/k/v. This means that each process will see the full states, +# hence we should not divide the number of heads by the world size. This is a known waste of VRAM (the cache +# will be fully replicated on each process) and GPU communication (additional all-gather operations), however due +# to internal constraints it was not (yet?) possible to circumvent +REPLICATED_ATTENTION_MODELS = [ + "olmo2", + "phi3", +] -class TransformersFlashVlmCausalLM(VlmCausalLM): +class TransformersFlashCausalLM(FlashCausalLM): def __init__( self, model_id: str, - model_class, revision: Optional[str] = None, quantize: Optional[str] = None, speculator: Optional[str] = None, @@ -285,15 +104,10 @@ class TransformersFlashVlmCausalLM(VlmCausalLM): default_dtype=torch.float16, trust_remote_code: bool = False, tokenizer_class=AutoTokenizer, - processor_class=AutoProcessor, - processor_kwargs=None, kv_cache_dtype: Optional[torch.dtype] = None, - batch_class=VlmCausalLMBatch, ): - self.batch_class = batch_class self.quantize = quantize self.process_group, rank, world_size = initialize_torch_distributed() - self.dtype = dtype if speculator: raise RuntimeError("Speculator decoding is not enabled for AutoModel") @@ -317,40 +131,18 @@ class TransformersFlashVlmCausalLM(VlmCausalLM): trust_remote_code=trust_remote_code, ) - if processor_kwargs is None: - processor_kwargs = {} - - self.processor = processor_class.from_pretrained( - model_id, - revision=revision, - trust_remote_code=trust_remote_code, - **processor_kwargs, - ) - - attn_implementation = { - "text_config": "tgi", - "vision_config": "sdpa", - } - - model = model_class.from_pretrained( + model = AutoModelForCausalLM.from_pretrained( model_id, revision=revision, torch_dtype=dtype, load_in_8bit=quantize == "bitsandbytes", trust_remote_code=trust_remote_code, - attn_implementation=attn_implementation, + attn_implementation="tgi", device_map=device if world_size == 1 else None, tp_plan="auto" if world_size > 1 else None, ) torch.distributed.barrier(group=self.process_group) - self.config = model.config - config = model.config - - # VLM models define the config we care about in their text_config - text_config = getattr(model.config, "text_config", None) - if text_config is not None: - config = text_config if tokenizer.pad_token_id is None: if model.config.pad_token_id is not None: @@ -364,18 +156,20 @@ class TransformersFlashVlmCausalLM(VlmCausalLM): else: tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - self.num_layers = config.num_hidden_layers - self.num_heads = config.num_attention_heads - self.num_kv_heads = config.num_key_value_heads + self.num_layers = model.config.num_hidden_layers + self.num_heads = model.config.num_attention_heads + self.num_kv_heads = model.config.num_key_value_heads # Some models use GQA and different sizes for o_proj # and q_proj, that allows for that. - if hasattr(config, "head_dim"): - self.head_size = config.head_dim + if hasattr(model.config, "head_dim"): + self.head_size = model.config.head_dim else: - self.head_size = config.hidden_size // config.num_attention_heads + self.head_size = ( + model.config.hidden_size // model.config.num_attention_heads + ) # Skip it for models in the exception list - if config.model_type not in REPLICATED_ATTENTION_MODELS: + if model.config.model_type not in REPLICATED_ATTENTION_MODELS: self.num_heads = self.num_heads // self.process_group.size() self.num_kv_heads = ( self.num_kv_heads // self.process_group.size() @@ -433,47 +227,26 @@ class TransformersFlashVlmCausalLM(VlmCausalLM): # We first copy the original model.forward because we still need it in the monkey patch self.model.original_forward = self.model.forward self.model.forward = self._model_forward - self.model.get_position_ids = self.get_position_ids torch.distributed.barrier(group=self.process_group) - def get_position_ids(self, input_ids, image_grid_thw, position_ids): - return position_ids - - def pre_process_inputs(self, **kwargs): - input_ids = kwargs["input_ids"] - position_ids = kwargs["position_ids"] - return { - "input_ids": input_ids.unsqueeze(0), - "position_ids": position_ids.unsqueeze(0), - } - - def post_process_outputs(self, logits, lm_head_indices): - return logits.squeeze(dim=0) - @classmethod def fallback( cls, model_id: str, - model_class, revision: Optional[str] = None, quantize: Optional[str] = None, speculator: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, - batch_class: Optional[type] = VlmCausalLMBatch, - processor_kwargs: Optional[dict] = None, ): return cls( model_id=model_id, - model_class=model_class, revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, - batch_class=batch_class, - processor_kwargs=processor_kwargs, ) def _model_forward( @@ -489,33 +262,14 @@ class TransformersFlashVlmCausalLM(VlmCausalLM): lm_head_indices: Optional[torch.Tensor], prefill_cache_indices=None, # not used, but passed to match original signature adapter_data=None, # not supported, but passed to match original signature - pixel_values: torch.FloatTensor = None, - image_grid_thw: Optional[torch.LongTensor] = None, - pixel_attention_mask=None, - image_sizes: Optional[torch.LongTensor] = None, ): # A value of `None` (i.e. no logit slicing) translates to `0` in Transformers logits_to_keep = lm_head_indices if lm_head_indices is not None else 0 - inputs = self.pre_process_inputs( - input_ids=input_ids, - position_ids=position_ids, - cu_seqlen_prefill=cu_seqlen_prefill, - seqlen=seqlen, - block_tables=block_tables, - ) - - if cu_seqlen_prefill is not None: - from loguru import logger - - logger.info( - f"input_ids: {input_ids.shape}, position_ids:{inputs.get('local_seqlen', None)}" - ) - # This is equivalent to `self.model.forward`, see the monkey patch in __init__ logits = self.model.original_forward( - input_ids=inputs["input_ids"], - position_ids=inputs["position_ids"], + input_ids=input_ids.unsqueeze(0), # expand dim to fit Transformers + position_ids=position_ids.unsqueeze(0), # expand dim to fit Transformers past_key_values=None, # we use self.kv_cache instead of transformers cache object use_cache=False, # we use self.kv_cache instead of transformers cache object logits_to_keep=logits_to_keep, @@ -528,227 +282,6 @@ class TransformersFlashVlmCausalLM(VlmCausalLM): max_s=max_s, kv_head_mapping=self.kv_head_mapping, kv_scales=self.kv_scales, - pixel_values=pixel_values, - pixel_attention_mask=pixel_attention_mask, - image_sizes=image_sizes, - image_grid_thw=image_grid_thw, - attention_mask=inputs.get("attention_mask", None), - use_sdpa=inputs.get("use_sdpa", False), - cache_position=inputs.get("cache_position", None), - local_seqlen=inputs.get("local_seqlen", None), - local_block_tables=inputs.get("local_block_tables", None), - ).logits - - logits = self.post_process_outputs(logits, lm_head_indices) + ).logits.squeeze(dim=0) return logits, None - - -class TransformersQwen2VlmCausalLM(TransformersFlashVlmCausalLM): - def get_position_ids(self, input_ids: torch.Tensor, image_grid_thw: torch.Tensor): - if image_grid_thw is None: - return ( - torch.arange(input_ids.shape[0], device=input_ids.device) - .unsqueeze(1) - .repeat(1, 3) - ) - - spatial_merge_size = self.config.vision_config.spatial_merge_size - vision_start_token_id = self.config.vision_start_token_id - vision_end_token_id = self.config.vision_end_token_id - device = input_ids.device - dtype = input_ids.dtype - input_ids_len = input_ids.shape[0] - - vision_starts = torch.where(input_ids == vision_start_token_id)[0] - vision_ends = torch.where(input_ids == vision_end_token_id)[0] - vision_segments = torch.stack((vision_starts, vision_ends), dim=1) - prev_vision_end = torch.cat( - [torch.zeros(1, device=vision_ends.device, dtype=dtype), vision_ends[:-1]] - ) - text_lengths_between_vision = vision_segments[:, 0] - prev_vision_end + 1 - vision_widths_max = torch.cat( - [ - torch.zeros(1, device=image_grid_thw.device, dtype=dtype), - image_grid_thw[:-1, 2] // spatial_merge_size, - ] - ) - vision_segment_lengths = vision_widths_max + text_lengths_between_vision - vision_segment_lengths = vision_segment_lengths.cumsum(dim=0) - text_segment_lengths = vision_segment_lengths - text_lengths_between_vision - - # create position ids for each vision segment based on the image grid - llm_pos_ids_list = [] - for i, _ in enumerate(vision_segments): - t, h, w = ( - image_grid_thw[i][0], - image_grid_thw[i][1] // spatial_merge_size, - image_grid_thw[i][2] // spatial_merge_size, - ) - t_indices = torch.arange(t, device=device).repeat_interleave(h * w) - h_indices = torch.arange(h, device=device).repeat_interleave(w).repeat(t) - w_indices = torch.arange(w, device=device).repeat(t * h) - image_position_ids = torch.stack([t_indices, h_indices, w_indices], dim=0) - - # offset by the position of the last vision segment - im = image_position_ids + vision_segment_lengths[i] - llm_pos_ids_list.append(im) - - # create position ids for each text segment - text_ranges = [ - torch.arange(seq_len, device=device).view(1, -1).expand(3, -1) - + text_segment_lengths[i] - for i, seq_len in enumerate(text_lengths_between_vision) - ] - - full_llm_pos_ids_list = [ - item for sublist in zip(text_ranges, llm_pos_ids_list) for item in sublist - ] - # import ipdb - - # ipdb.set_trace() - max_s = full_llm_pos_ids_list[-1].max() + 1 - final_text_len = input_ids_len - vision_ends[-1] - if final_text_len > 0: - m = torch.arange(final_text_len, device=device).view(1, -1).expand(3, -1) - full_llm_pos_ids_list.append(m + max_s) - - position_ids = ( - torch.cat(full_llm_pos_ids_list, dim=1).reshape(3, -1).transpose(0, 1) - ) - return position_ids - - def post_process_outputs(self, logits, lm_head_indices): - return logits.squeeze(dim=0)[lm_head_indices].unsqueeze(0) - - def pre_process_inputs(self, **kwargs): - input_ids = kwargs["input_ids"] - position_ids = kwargs["position_ids"] - - input_ids = input_ids.unsqueeze(0) - position_ids = position_ids.transpose(0, 1).unsqueeze(1) - return {"input_ids": input_ids, "position_ids": position_ids} - - -class TransformersGemma3VlmCausalLM(TransformersFlashVlmCausalLM): - def get_attention_mask(self, input_ids, cu_seqlen_prefill): - device = input_ids.device - dtype = self.dtype - min_dtype = torch.finfo(dtype).min - - lengths = (cu_seqlen_prefill[1:] - cu_seqlen_prefill[:-1]).tolist() - batch_size = len(lengths) - - sequence_length = max(lengths) - target_length = sequence_length - # Create the padding mask from the computed lengths. - # pad_mask: [batch, sequence_length] where True indicates valid tokens. - seq_range = torch.arange(sequence_length, device=device).unsqueeze(0) - lengths_tensor = torch.tensor(lengths, device=device).unsqueeze(1) - pad_mask = seq_range < lengths_tensor # shape: [batch, sequence_length] - - # Build the base causal mask (for non-image tokens): - causal_mask = torch.tril( - torch.ones( - (sequence_length, sequence_length), dtype=torch.bool, device=device - ) - ) - base_mask = pad_mask.unsqueeze(2) & pad_mask.unsqueeze( - 1 - ) # [batch, sequence_length, sequence_length] - base_mask = base_mask & causal_mask.unsqueeze(0) # apply causal constraint - - image_token_mask = (input_ids == self.config.image_token_index).to( - input_ids.device - ) - - image_token_mask = torch.nn.utils.rnn.pad_sequence( - torch.split(image_token_mask, lengths), batch_first=True, padding_value=0 - ) - bidirectional_mask = image_token_mask.unsqueeze(2) & image_token_mask.unsqueeze( - 1 - ) - - # Combine the causal base mask and the bidirectional mask. - combined_mask = torch.logical_or( - base_mask.unsqueeze(1), bidirectional_mask.unsqueeze(1) - ).to(device) - # combined_mask now has shape [batch, 1, sequence_length, sequence_length] - - full_attention_mask = torch.zeros( - (batch_size, 1, sequence_length, target_length), - device=device, - dtype=torch.bool, - ) - full_attention_mask[:, :, :, :sequence_length] = combined_mask - - final_attention_mask = torch.where(full_attention_mask, 0, min_dtype).to(device) - - return final_attention_mask - - def pre_process_inputs(self, **kwargs): - input_ids = kwargs["input_ids"] - position_ids = kwargs["position_ids"] - cu_seqlen_prefill = kwargs["cu_seqlen_prefill"] - - inputs = { - "input_ids": input_ids.unsqueeze(0), - "position_ids": position_ids.unsqueeze(0), - } - - if cu_seqlen_prefill is not None: - attention_mask = self.get_attention_mask( - input_ids.squeeze(0), cu_seqlen_prefill - ) - inputs["attention_mask"] = attention_mask - inputs["use_sdpa"] = True - - return inputs - - -class TransformersLlama4VlmCausalLM(TransformersFlashVlmCausalLM): - def pre_process_inputs(self, **kwargs): - input_ids = kwargs["input_ids"] - position_ids = kwargs["position_ids"] - seqlen = kwargs["seqlen"] - block_tables = kwargs["block_tables"] - - inputs = super().pre_process_inputs(**kwargs) - inputs["cache_position"] = position_ids - inputs["attention_mask"] = torch.zeros((1, 1, 1, 1), device=input_ids.device) - # from loguru import logger - - # logger.info(f"input_ids: {input_ids.shape}, position_ids: {position_ids.shape}") - cu_seqlen_k = seqlen.cu_seqlen_k - cu_seqlen_q = seqlen.cu_seqlen_q - seq_lens_np = cu_seqlen_k[1:] - cu_seqlen_k[:-1] - ( - seqlens_q_local_np, - virt_q_cu_seqlens_np, - virt_k_seqlens_np, - virt_block_table, - ) = make_local_attention_virtual_batches( - self.model.config.text_config.attention_chunk_size, - cu_seqlen_q.cpu().numpy(), - seq_lens_np.cpu().numpy(), - block_tables, - BLOCK_SIZE, - ) - local_seqlen = Seqlen( - input_lengths=torch.from_numpy(virt_k_seqlens_np).to( - input_ids.device, non_blocking=True - ), - cache_lengths=torch.zeros(virt_k_seqlens_np.shape).to( - input_ids.device, non_blocking=True - ), - cu_seqlen_q=torch.from_numpy(virt_q_cu_seqlens_np).to( - input_ids.device, non_blocking=True - ), - max_q=int(seqlens_q_local_np.max()), - max_k=int(virt_k_seqlens_np.max()), - ) - - inputs["local_seqlen"] = local_seqlen - inputs["local_block_tables"] = virt_block_table - - return inputs diff --git a/server/text_generation_server/models/transformers_flash_vlm.py b/server/text_generation_server/models/transformers_flash_vlm.py index a7beb68b..461bbc2d 100644 --- a/server/text_generation_server/models/transformers_flash_vlm.py +++ b/server/text_generation_server/models/transformers_flash_vlm.py @@ -12,8 +12,9 @@ from text_generation_server.utils import initialize_torch_distributed from text_generation_server.layers.attention import paged_attention, attention, Seqlen from text_generation_server.layers.attention.kv_cache import KVScales, KVCache -from text_generation_server.models.globals import ATTENTION +from text_generation_server.models.globals import ATTENTION, BLOCK_SIZE import torch.nn.functional as F +import numpy as np tracer = trace.get_tracer(__name__) @@ -27,6 +28,126 @@ REPLICATED_ATTENTION_MODELS = [ ] +def cdiv(a: int, b: int) -> int: + """Ceiling division.""" + return -(a // -b) + + +# Adapted from: https://github.com/vllm-project/vllm/blob/e1a2c699dda82199e88e433c144eae66f3b31878/vllm/v1/attention/backends/flash_attn.py +def make_local_attention_virtual_batches( + attn_chunk_size: int, + query_start_loc_np: np.ndarray, + seq_lens_np: np.ndarray, + block_table: torch.Tensor, + page_size: int = 0, +) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.Tensor]: + q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1] + actual_batch_size = seq_lens_np.shape[0] + # Handle if we are starting in the middle of a local attention block, + # we assume q_seqlens > 0 (for all elements), for each batch idx we compute + # the number of tokens that are not in the first local attention block and + # then we can simply use a cdiv for the rest. + # For example if we have: + # attn_chunk_size = 4 + # q_seqlens = [4, 10, 5] + # k_seqlens = [6, 17, 9] + # Then we would get: + # new_tokens_in_first_block = [2, 1, 4] + # local_blocks = [2, 4, 2] + q_tokens_in_first_block = np.minimum( + attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size), q_seqlens + ).astype(np.int32) + tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size) + local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block, attn_chunk_size) + + # Once we know the number of local blocks we can compute the request spans + # for each batch idx, we can figure out the number of "virtual" requests we + # have to make, + # For the above example we would get: + # seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1] + # + # First Get batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1]) + # (TODO: max a utility to share this code with _prepare_inputs) + # arange step 1. [2, 4, 2] -> [2, 6, 8] + cu_num_blocks = np.cumsum(local_blocks) + virtual_batches = cu_num_blocks[-1] + # arange step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6] + block_offsets = np.repeat(cu_num_blocks - local_blocks, local_blocks) + # arange step 3. [0, 1, 0, 1, 2, 3, 0, 1] + arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets + # also compute reverse arange (i.e. [1, 0, 3, 2, 1, 0, 1, 0]) + rarange = np.repeat(local_blocks, local_blocks) - arange - 1 + # Then we can compute the seqlens_q_local, handling the fact that the + # first and last blocks could be partial + seqlens_q_local = np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks) + # set the first block since this may be a partial block + seqlens_q_local[arange == 0] = q_tokens_in_first_block + # set the remaining blocks + seqlens_q_local[arange > 0] = np.minimum( + seqlens_q_local - attn_chunk_size * (arange - 1), attn_chunk_size + )[arange > 0] + + # convert from q_seqlens to cu_seqlens_q + cu_seqlens_q_local = np.pad(np.cumsum(seqlens_q_local), (1, 0)).astype(np.int32) + + # compute the seqlens_k_local, + # basically a full local attention block for all but the last block in each + # batch + # For our example this will be: + # seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1] + seqlens_k_local = np.full(cu_num_blocks[-1], attn_chunk_size, dtype=np.int32) + seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block + + if ATTENTION == "flashdecoding": + k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - ( + rarange * attn_chunk_size + np.repeat(tokens_in_last_block, local_blocks) + ) + + # For the example the local attention blocks start at: + # _b0_ _____b1_____ _b2_ + # k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8] + block_starts = k_seqstarts_absolute // page_size + assert attn_chunk_size % page_size == 0, ( + f"attn_chunk_size {attn_chunk_size} is not " + f"divisible by page_size {page_size}" + ) + pages_per_local_batch = attn_chunk_size // page_size + + # Create a block_table for the local attention blocks + # For out example if we have a block-table like (assuming page_size=2): + # block_table = [ + # [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9], < batch 0 + # [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], < batch 1 + # [20, 21, 22, 23, 24, 25, 26, 27, 28, 29], < batch 2 + # ] + # Then for the local batches we would want a block-table like + # block_table_local = [ + # [ 0, 1 ], < local-batch 0, (batch 0, starting from k[0]) + # [ 2, 3 ], < local-batch 1, (batch 0, starting from k[4]) + # [ 12, 13 ], < local-batch 2, (batch 1, starting from k[4]) + # [ 14, 15 ], < local-batch 3, (batch 1, starting from k[8]) + # [ 16, 17 ], < local-batch 4, (batch 1, starting from k[12]) + # [ 18, 19 ], < local-batch 5, (batch 1, starting from k[16]) + # [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4]) + # [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8]) + # ] + block_indices = np.broadcast_to( + np.arange(pages_per_local_batch, dtype=np.int32), + (virtual_batches, pages_per_local_batch), + ) + np.expand_dims(block_starts, axis=1) + block_indices = block_indices.flatten().clip(max=block_table.shape[1] - 1) + batch_indices = np.repeat( + np.arange(actual_batch_size, dtype=np.int32), + local_blocks * pages_per_local_batch, + ) + block_table_local = block_table[batch_indices, block_indices].view( + virtual_batches, -1 + ) + else: + block_table_local = block_table + return seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, block_table_local + + # # Qwen2VL # transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_VISION_ATTENTION_CLASSES[ # "tgi" @@ -51,8 +172,14 @@ def tgi_flash_attention_forward( sliding_window: Optional[int] = None, softcap: Optional[float] = None, use_sdpa: Optional[bool] = False, + local_seqlen: Optional[Seqlen] = None, + local_block_tables: Optional[torch.Tensor] = None, **kwargs, # This is needed to "absorb" other args passed by Transformers modeling ): + if hasattr(module, "use_rope") and module.use_rope: + seqlen = local_seqlen + block_tables = local_block_tables + kv_cache = kv_cache[module.layer_idx] query_states = query_states.transpose(1, 2).squeeze(dim=0) key_states = key_states.transpose(1, 2).squeeze(dim=0) @@ -313,7 +440,9 @@ class TransformersFlashVlmCausalLM(VlmCausalLM): def get_position_ids(self, input_ids, image_grid_thw, position_ids): return position_ids - def pre_process_inputs(self, input_ids, position_ids, cu_seqlen_prefill): + def pre_process_inputs(self, **kwargs): + input_ids = kwargs["input_ids"] + position_ids = kwargs["position_ids"] return { "input_ids": input_ids.unsqueeze(0), "position_ids": position_ids.unsqueeze(0), @@ -372,7 +501,17 @@ class TransformersFlashVlmCausalLM(VlmCausalLM): input_ids=input_ids, position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill, + seqlen=seqlen, + block_tables=block_tables, ) + + if cu_seqlen_prefill is not None: + from loguru import logger + + logger.info( + f"input_ids: {input_ids.shape}, position_ids:{inputs.get('local_seqlen', None)}" + ) + # This is equivalent to `self.model.forward`, see the monkey patch in __init__ logits = self.model.original_forward( input_ids=inputs["input_ids"], @@ -396,6 +535,8 @@ class TransformersFlashVlmCausalLM(VlmCausalLM): attention_mask=inputs.get("attention_mask", None), use_sdpa=inputs.get("use_sdpa", False), cache_position=inputs.get("cache_position", None), + local_seqlen=inputs.get("local_seqlen", None), + local_block_tables=inputs.get("local_block_tables", None), ).logits logits = self.post_process_outputs(logits, lm_head_indices) @@ -480,7 +621,10 @@ class TransformersQwen2VlmCausalLM(TransformersFlashVlmCausalLM): def post_process_outputs(self, logits, lm_head_indices): return logits.squeeze(dim=0)[lm_head_indices].unsqueeze(0) - def pre_process_inputs(self, input_ids, position_ids, cu_seqlen_prefill): + def pre_process_inputs(self, **kwargs): + input_ids = kwargs["input_ids"] + position_ids = kwargs["position_ids"] + input_ids = input_ids.unsqueeze(0) position_ids = position_ids.transpose(0, 1).unsqueeze(1) return {"input_ids": input_ids, "position_ids": position_ids} @@ -542,7 +686,11 @@ class TransformersGemma3VlmCausalLM(TransformersFlashVlmCausalLM): return final_attention_mask - def pre_process_inputs(self, input_ids, position_ids, cu_seqlen_prefill): + def pre_process_inputs(self, **kwargs): + input_ids = kwargs["input_ids"] + position_ids = kwargs["position_ids"] + cu_seqlen_prefill = kwargs["cu_seqlen_prefill"] + inputs = { "input_ids": input_ids.unsqueeze(0), "position_ids": position_ids.unsqueeze(0), @@ -559,8 +707,48 @@ class TransformersGemma3VlmCausalLM(TransformersFlashVlmCausalLM): class TransformersLlama4VlmCausalLM(TransformersFlashVlmCausalLM): - def pre_process_inputs(self, input_ids, position_ids, cu_seqlen_prefill): - inputs = super().pre_process_inputs(input_ids, position_ids, cu_seqlen_prefill) + def pre_process_inputs(self, **kwargs): + input_ids = kwargs["input_ids"] + position_ids = kwargs["position_ids"] + seqlen = kwargs["seqlen"] + block_tables = kwargs["block_tables"] + + inputs = super().pre_process_inputs(**kwargs) inputs["cache_position"] = position_ids inputs["attention_mask"] = torch.zeros((1, 1, 1, 1), device=input_ids.device) + # from loguru import logger + + # logger.info(f"input_ids: {input_ids.shape}, position_ids: {position_ids.shape}") + cu_seqlen_k = seqlen.cu_seqlen_k + cu_seqlen_q = seqlen.cu_seqlen_q + seq_lens_np = cu_seqlen_k[1:] - cu_seqlen_k[:-1] + ( + seqlens_q_local_np, + virt_q_cu_seqlens_np, + virt_k_seqlens_np, + virt_block_table, + ) = make_local_attention_virtual_batches( + self.model.config.text_config.attention_chunk_size, + cu_seqlen_q.cpu().numpy(), + seq_lens_np.cpu().numpy(), + block_tables, + BLOCK_SIZE, + ) + local_seqlen = Seqlen( + input_lengths=torch.from_numpy(virt_k_seqlens_np).to( + input_ids.device, non_blocking=True + ), + cache_lengths=torch.zeros(virt_k_seqlens_np.shape).to( + input_ids.device, non_blocking=True + ), + cu_seqlen_q=torch.from_numpy(virt_q_cu_seqlens_np).to( + input_ids.device, non_blocking=True + ), + max_q=int(seqlens_q_local_np.max()), + max_k=int(virt_k_seqlens_np.max()), + ) + + inputs["local_seqlen"] = local_seqlen + inputs["local_block_tables"] = virt_block_table + return inputs From d2f8caff2b834a7b3962f4d44fd6b19ebc26b3ec Mon Sep 17 00:00:00 2001 From: Mohit Sharma Date: Fri, 11 Apr 2025 15:05:28 +0000 Subject: [PATCH 3/5] support cuda graphs --- .../models/transformers_flash_vlm.py | 518 ++++++++++++++++-- 1 file changed, 474 insertions(+), 44 deletions(-) diff --git a/server/text_generation_server/models/transformers_flash_vlm.py b/server/text_generation_server/models/transformers_flash_vlm.py index 461bbc2d..b20eae62 100644 --- a/server/text_generation_server/models/transformers_flash_vlm.py +++ b/server/text_generation_server/models/transformers_flash_vlm.py @@ -1,5 +1,5 @@ import math -from typing import List, Optional +from typing import List, Optional, Tuple, Dict import torch from opentelemetry import trace @@ -12,7 +12,8 @@ from text_generation_server.utils import initialize_torch_distributed from text_generation_server.layers.attention import paged_attention, attention, Seqlen from text_generation_server.layers.attention.kv_cache import KVScales, KVCache -from text_generation_server.models.globals import ATTENTION, BLOCK_SIZE +from text_generation_server.models.globals import ATTENTION, BLOCK_SIZE, MEM_POOL +from text_generation_server.models.metadata_kernels import block_tables_to_ragged import torch.nn.functional as F import numpy as np @@ -172,13 +173,13 @@ def tgi_flash_attention_forward( sliding_window: Optional[int] = None, softcap: Optional[float] = None, use_sdpa: Optional[bool] = False, - local_seqlen: Optional[Seqlen] = None, - local_block_tables: Optional[torch.Tensor] = None, + seqlen_local: Optional[Seqlen] = None, + block_tables_local: Optional[torch.Tensor] = None, **kwargs, # This is needed to "absorb" other args passed by Transformers modeling ): if hasattr(module, "use_rope") and module.use_rope: - seqlen = local_seqlen - block_tables = local_block_tables + seqlen = seqlen_local + block_tables = block_tables_local kv_cache = kv_cache[module.layer_idx] query_states = query_states.transpose(1, 2).squeeze(dim=0) @@ -493,7 +494,10 @@ class TransformersFlashVlmCausalLM(VlmCausalLM): image_grid_thw: Optional[torch.LongTensor] = None, pixel_attention_mask=None, image_sizes: Optional[torch.LongTensor] = None, + seqlen_local: Optional[Seqlen] = None, + block_tables_local: Optional[torch.Tensor] = None, ): + # A value of `None` (i.e. no logit slicing) translates to `0` in Transformers logits_to_keep = lm_head_indices if lm_head_indices is not None else 0 @@ -505,13 +509,6 @@ class TransformersFlashVlmCausalLM(VlmCausalLM): block_tables=block_tables, ) - if cu_seqlen_prefill is not None: - from loguru import logger - - logger.info( - f"input_ids: {input_ids.shape}, position_ids:{inputs.get('local_seqlen', None)}" - ) - # This is equivalent to `self.model.forward`, see the monkey patch in __init__ logits = self.model.original_forward( input_ids=inputs["input_ids"], @@ -535,8 +532,8 @@ class TransformersFlashVlmCausalLM(VlmCausalLM): attention_mask=inputs.get("attention_mask", None), use_sdpa=inputs.get("use_sdpa", False), cache_position=inputs.get("cache_position", None), - local_seqlen=inputs.get("local_seqlen", None), - local_block_tables=inputs.get("local_block_tables", None), + seqlen_local=seqlen_local, + block_tables_local=block_tables_local, ).logits logits = self.post_process_outputs(logits, lm_head_indices) @@ -707,48 +704,481 @@ class TransformersGemma3VlmCausalLM(TransformersFlashVlmCausalLM): class TransformersLlama4VlmCausalLM(TransformersFlashVlmCausalLM): - def pre_process_inputs(self, **kwargs): - input_ids = kwargs["input_ids"] - position_ids = kwargs["position_ids"] - seqlen = kwargs["seqlen"] - block_tables = kwargs["block_tables"] + def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int): + max_bs = max(self.cuda_graphs.keys()) if self.cuda_graphs else None + input_lengths = [max_s] * bs + cache_lengths = [0] * bs + if max_bs is None: + input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device) + position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device) + config = getattr(self.model, "config", None) + rope_scaling = getattr(config, "rope_scaling", None) if config else None + if ( # mrope have position_ids per section, if so repeat n times + isinstance(rope_scaling, dict) and rope_scaling["rope_type"] == "mrope" + ): + n_sections = len(self.model.config.rope_scaling["mrope_section"]) + position_ids = position_ids.unsqueeze(1).repeat(1, n_sections) + slots = torch.arange(bs, dtype=torch.int64, device=self.device) + input_lengths_tensor = ( + torch.ones(bs, dtype=torch.int32, device=self.device) * max_s + ) + cache_lengths_tensor = torch.zeros( + bs, dtype=torch.int32, device=self.device + ) + block_tables = torch.arange( + max_bt, dtype=torch.int32, device=self.device + ).repeat(bs) + block_tables = block_tables.reshape((bs, max_bt)) + if ATTENTION == "flashinfer": + block_tables = block_tables_to_ragged( + block_tables=block_tables, + input_lengths=input_lengths, + cache_lengths=cache_lengths, + input_lengths_tensor=input_lengths_tensor, + cache_lengths_tensor=cache_lengths_tensor, + max_current_length=max_s, + ) - inputs = super().pre_process_inputs(**kwargs) - inputs["cache_position"] = position_ids - inputs["attention_mask"] = torch.zeros((1, 1, 1, 1), device=input_ids.device) - # from loguru import logger + cu_seqlen_q = torch.arange( + input_lengths_tensor.shape[0] + 1, + device=self.device, + dtype=torch.int32, + ) + + ( + input_lengths_tensor_local, + cache_lengths_tensor_local, + seqlens_q_local, + max_q, + max_k, + block_tables_local, + ) = self.get_chunked_attention_seqlen( + cu_seqlen_q, + input_lengths_tensor, + block_tables, + ) + self.max_k_local = max_k + else: + if bs > max_bs: + raise RuntimeError( + "Cuda graphs should be generated in decreasing order size to reduce VRAM usage" + ) + input_ids = self.cuda_graphs[max_bs]["input_ids"][:bs] + position_ids = self.cuda_graphs[max_bs]["position_ids"][:bs] + if ATTENTION == "flashinfer": + block_tables = self.cuda_graphs[max_bs]["block_tables"][: bs * max_bt] + else: + block_tables = self.cuda_graphs[max_bs]["block_tables"][:bs] + slots = self.cuda_graphs[max_bs]["slots"][:bs] + input_lengths_tensor = self.cuda_graphs[max_bs]["input_lengths"][:bs] + cache_lengths_tensor = self.cuda_graphs[max_bs]["cache_lengths"][:bs] + + input_lengths_tensor_local = self.cuda_graphs[max_bs][ + "input_lengths_local" + ][:bs] + cache_lengths_tensor_local = self.cuda_graphs[max_bs][ + "cache_lengths_local" + ][:bs] + seqlens_q_local = self.cuda_graphs[max_bs]["seqlens_q_local"][:bs] + + if ATTENTION == "flashinfer": + from text_generation_server.layers.attention.flashinfer import ( + create_decode_state_cuda_graphs, + ) + + block_tables_ptr = torch.zeros( + bs + 1, dtype=torch.int32, device=self.device + ) + last_page_len = torch.ones(bs, dtype=torch.int32, device=self.device) + state = create_decode_state_cuda_graphs( + device=input_ids.device, + block_tables=block_tables, + block_tables_ptr=block_tables_ptr, + last_page_len=last_page_len, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + ) + else: + state = None + + graph = torch.cuda.CUDAGraph() + self.cuda_graphs[bs] = { + "input_ids": input_ids, + "position_ids": position_ids, + "kv_cache": self.kv_cache, + "block_tables": block_tables, + "slots": slots, + "input_lengths": input_lengths_tensor, + "cache_lengths": cache_lengths_tensor, + "input_lengths_local": input_lengths_tensor_local, + "cache_lengths_local": cache_lengths_tensor_local, + "seqlens_q_local": seqlens_q_local, + "block_tables_local": block_tables_local, + "state": state, + "graph": graph, + } + + torch.cuda.synchronize() + # Run once outside to warmup + with self._forward_context( + block_tables=block_tables, + cu_seqlen_prefill=None, + input_lengths_tensor=input_lengths_tensor, + state=state, + cache_lengths_tensor=cache_lengths_tensor, + ): + seqlen = Seqlen( + input_lengths=input_lengths_tensor, + cache_lengths=cache_lengths_tensor, + cu_seqlen_q=None, + max_q=1, + max_k=max_s, + ) + # cu_seqlens_q_local = F.pad( + # torch.cumsum(seqlens_q_local, dim=0), (1, 0), value=0 + # ).to(torch.int32) + seqlen_local = Seqlen( + input_lengths=input_lengths_tensor_local, + cache_lengths=cache_lengths_tensor_local, + cu_seqlen_q=None, + max_q=1, + max_k=input_lengths_tensor_local.max(), + ) + self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=None, + kv_cache=self.kv_cache, + block_tables=block_tables, + slots=slots, + seqlen=seqlen, + max_s=max_s, + prefill_cache_indices=None, + lm_head_indices=None, + seqlen_local=seqlen_local, + block_tables_local=block_tables_local, + ) + del seqlen + del seqlen_local + + torch.cuda.synchronize() + + with torch.cuda.graph(graph, pool=MEM_POOL): + seqlen = Seqlen( + input_lengths=input_lengths_tensor, + cache_lengths=cache_lengths_tensor, + cu_seqlen_q=None, + max_q=1, + max_k=max_s, + ) + # cu_seqlens_q_local = F.pad( + # torch.cumsum(seqlens_q_local, dim=0), (1, 0), value=0 + # ).to(torch.int32) + seqlen_local = Seqlen( + input_lengths=input_lengths_tensor_local, + cache_lengths=cache_lengths_tensor_local, + cu_seqlen_q=None, + max_q=1, + max_k=input_lengths_tensor_local.max(), + ) + logits, speculative_logits = self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=None, + kv_cache=self.kv_cache, + block_tables=block_tables, + slots=slots, + seqlen=seqlen, + max_s=max_s, + prefill_cache_indices=None, + lm_head_indices=None, + seqlen_local=seqlen_local, + block_tables_local=block_tables_local, + ) + self.cuda_graphs[bs]["logits"] = logits + self.cuda_graphs[bs]["speculative_logits"] = speculative_logits + torch.cuda.synchronize() + + def get_chunked_attention_seqlen( + self, + cu_seqlen_q, + seq_lens_np, + block_tables, + ): + attention_chunk_size = self.model.config.text_config.attention_chunk_size + # seq_lens_np = cu_seqlen_k[1:] - cu_seqlen_k[:-1] - # logger.info(f"input_ids: {input_ids.shape}, position_ids: {position_ids.shape}") - cu_seqlen_k = seqlen.cu_seqlen_k - cu_seqlen_q = seqlen.cu_seqlen_q - seq_lens_np = cu_seqlen_k[1:] - cu_seqlen_k[:-1] ( seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, virt_block_table, ) = make_local_attention_virtual_batches( - self.model.config.text_config.attention_chunk_size, - cu_seqlen_q.cpu().numpy(), + attention_chunk_size, + ( + cu_seqlen_q.cpu().numpy() + if isinstance(cu_seqlen_q, torch.Tensor) + else cu_seqlen_q + ), seq_lens_np.cpu().numpy(), block_tables, BLOCK_SIZE, ) - local_seqlen = Seqlen( - input_lengths=torch.from_numpy(virt_k_seqlens_np).to( - input_ids.device, non_blocking=True - ), - cache_lengths=torch.zeros(virt_k_seqlens_np.shape).to( - input_ids.device, non_blocking=True - ), - cu_seqlen_q=torch.from_numpy(virt_q_cu_seqlens_np).to( - input_ids.device, non_blocking=True - ), - max_q=int(seqlens_q_local_np.max()), - max_k=int(virt_k_seqlens_np.max()), + + input_lengths = torch.from_numpy(virt_k_seqlens_np).to( + cu_seqlen_q.device, non_blocking=True + ) + cache_lengths = torch.zeros(virt_k_seqlens_np.shape).to( + cu_seqlen_q.device, non_blocking=True + ) + seqlens_q_local = torch.from_numpy(seqlens_q_local_np).to( + cu_seqlen_q.device, non_blocking=True ) - inputs["local_seqlen"] = local_seqlen - inputs["local_block_tables"] = virt_block_table + max_q = int(seqlens_q_local_np.max()) + max_k = int(virt_k_seqlens_np.max()) + + return ( + input_lengths, + cache_lengths, + seqlens_q_local, + max_q, + max_k, + virt_block_table, + ) + + def pre_process_inputs(self, **kwargs): + input_ids = kwargs["input_ids"] + position_ids = kwargs["position_ids"] + + inputs = super().pre_process_inputs(**kwargs) + inputs["cache_position"] = position_ids + inputs["attention_mask"] = torch.zeros((1, 1, 1, 1), device=input_ids.device) return inputs + + def forward( + self, + batch: VlmCausalLMBatch, + adapter_data: Optional[Dict[str, torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + # Model Forward + if batch.speculative_ids is not None: + input_ids = batch.input_ids + position_ids = batch.position_ids + cu_seqlen_prefill = batch.cu_seqlen_prefill + kv_cache = self.kv_cache + block_tables = batch.block_tables_tensor + slots = batch.slots[batch.slot_indices] + input_lengths = batch.input_lengths_tensor + max_s = batch.max_current_length + lm_head_indices = batch.prefill_head_indices + + speculative_ids = batch.speculative_ids + + B, speculative_length = speculative_ids.shape + new_length = speculative_length + 1 + new_input_ids = torch.cat( + [input_ids.unsqueeze(-1), speculative_ids], dim=1 + ).reshape(-1) + arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0) + arange_int = arange.to(dtype=torch.int32) + new_position_ids = ( + position_ids.unsqueeze(-1).expand(B, new_length) + arange + ).view(-1) + slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1) + input_lengths = ( + input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int + ).view(-1) + cache_lengths_tensor = ( + batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length) + ).reshape(-1) + + # Add Copy the block tables for all members + block_tables = ( + block_tables.unsqueeze(1) + .expand(B, new_length, -1) + .reshape(B * new_length, -1) + .contiguous() + ) + max_s = max_s + speculative_length + + input_ids = new_input_ids + position_ids = new_position_ids + else: + input_ids = batch.input_ids + position_ids = batch.position_ids + cu_seqlen_prefill = batch.cu_seqlen_prefill + kv_cache = self.kv_cache + block_tables = batch.block_tables_tensor + slots = batch.slots[batch.slot_indices] + input_lengths = batch.input_lengths_tensor + cache_lengths_tensor = batch.cache_lengths_tensor + max_s = batch.max_current_length + lm_head_indices = batch.prefill_head_indices + + if self.model.config.model_type in {"qwen2_vl", "qwen2_5_vl"}: + if position_ids.dim() == 1 and batch.prefilling: + position_ids = self.model.get_position_ids( + input_ids, batch.image_grid_thw + ) + batch.position_ids = position_ids + + # Try to find an associated cuda graph + bs = input_ids.shape[0] + sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs]) + if sorted_padded_bs: + # Get associated cuda graph + cuda_graph = self.cuda_graphs[sorted_padded_bs[0]] + else: + cuda_graph = None + + cu_seqlen_q = ( + cu_seqlen_prefill + if cu_seqlen_prefill is not None + else torch.arange( + input_lengths.shape[0] + 1, dtype=torch.int32, device=input_ids.device + ) + ) + ( + input_lengths_tensor_local, + cache_lengths_tensor_local, + seqlens_q_local, + max_q, + max_k, + block_tables_local, + ) = self.get_chunked_attention_seqlen( + cu_seqlen_q=cu_seqlen_q, + seq_lens_np=input_lengths + cache_lengths_tensor, + block_tables=block_tables, + ) + + if cu_seqlen_prefill is not None or cuda_graph is None: + if ATTENTION == "flashinfer": + block_tables = block_tables_to_ragged( + block_tables=block_tables, + input_lengths=batch.input_lengths, + cache_lengths=batch.cache_lengths, + input_lengths_tensor=batch.input_lengths_tensor, + cache_lengths_tensor=batch.cache_lengths_tensor, + max_current_length=batch.max_current_length, + ) + raise RuntimeError("Flashinfer for LLama4 is not supported yet") + with self._forward_context( + block_tables=block_tables, + cu_seqlen_prefill=cu_seqlen_prefill, + input_lengths_tensor=input_lengths, + cache_lengths_tensor=cache_lengths_tensor, + ): + seqlen = Seqlen( + input_lengths=input_lengths, + cache_lengths=cache_lengths_tensor, + cu_seqlen_q=cu_seqlen_prefill, + max_q=batch.max_input_length, + max_k=batch.max_current_length, + ) + + cu_seqlens_q_local = F.pad( + torch.cumsum(seqlens_q_local, dim=0), (1, 0), value=0 + ).to(torch.int32) + seqlen_local = Seqlen( + input_lengths=input_lengths_tensor_local, + cache_lengths=cache_lengths_tensor_local, + cu_seqlen_q=cu_seqlens_q_local, + max_q=max_q, + max_k=max_k, + ) + + logits, speculative_logits = self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + block_tables=block_tables, + slots=slots, + seqlen=seqlen, + max_s=max_s, + prefill_cache_indices=batch.prefill_cache_indices, + lm_head_indices=lm_head_indices, + pixel_values=batch.pixel_values, + pixel_attention_mask=batch.pixel_attention_mask, + image_sizes=batch.image_sizes, + image_grid_thw=batch.image_grid_thw, + seqlen_local=seqlen_local, + block_tables_local=block_tables_local, + ) + if batch.prefill_cache_indices is not None: + batch.prefill_cache_indices = None + if batch.pixel_values is not None: + batch.pixel_values = None + if batch.pixel_attention_mask is not None: + batch.pixel_attention_mask = None + if batch.image_sizes is not None: + batch.image_sizes = None + if batch.image_grid_thw is not None: + batch.image_grid_thw = None + return logits, speculative_logits + + # Copy inputs to the static inputs of the cuda graph + # Static inputs are potentially padded + cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids + cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids + if ATTENTION == "flashinfer": + block_tables = block_tables_to_ragged( + block_tables=block_tables, + input_lengths=batch.input_lengths, + cache_lengths=batch.cache_lengths, + input_lengths_tensor=batch.input_lengths_tensor, + cache_lengths_tensor=batch.cache_lengths_tensor, + max_current_length=batch.max_current_length, + ) + cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables + raise RuntimeError("Flashinfer for LLama4 is not supported yet") + else: + cuda_graph["block_tables"][ + : block_tables.shape[0], : block_tables.shape[1] + ] = block_tables + + cuda_graph["block_tables_local"][ + : block_tables_local.shape[0], : block_tables_local.shape[1] + ] = block_tables_local + + # XXX: This is working only because block 0 is reserved for the healthcheck + # so it doesn't matter if we override it with bogus values. + cuda_graph["slots"].fill_(0) + cuda_graph["slots"][: slots.shape[0]] = slots + cuda_graph["input_lengths"].zero_() + cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths + cuda_graph["cache_lengths"].zero_() + cuda_graph["cache_lengths"][ + : cache_lengths_tensor.shape[0] + ] = cache_lengths_tensor + cuda_graph["input_lengths_local"].zero_() + cuda_graph["input_lengths_local"][ + : input_lengths_tensor_local.shape[0] + ] = input_lengths_tensor_local + cuda_graph["cache_lengths_local"].zero_() + cuda_graph["cache_lengths_local"][ + : cache_lengths_tensor_local.shape[0] + ] = cache_lengths_tensor_local + cuda_graph["seqlens_q_local"].zero_() + cuda_graph["seqlens_q_local"][: seqlens_q_local.shape[0]] = seqlens_q_local + + with self._forward_context( + block_tables=cuda_graph["block_tables"], + cu_seqlen_prefill=None, + input_lengths_tensor=cuda_graph["input_lengths"], + cache_lengths_tensor=cuda_graph["cache_lengths"], + state=cuda_graph["state"], + ): + # Replay the graph + cuda_graph["graph"].replay() + + # Slice output to the correct shape + speculative_logits = ( + cuda_graph["speculative_logits"][:bs] + if cuda_graph["speculative_logits"] is not None + else None + ) + logits = cuda_graph["logits"][:bs] + return logits, speculative_logits From a7353c35e8af9564c218929ab35d0e341211d41d Mon Sep 17 00:00:00 2001 From: Mohit Sharma Date: Fri, 11 Apr 2025 15:10:19 +0000 Subject: [PATCH 4/5] fix bt --- server/text_generation_server/models/transformers_flash_vlm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/server/text_generation_server/models/transformers_flash_vlm.py b/server/text_generation_server/models/transformers_flash_vlm.py index b20eae62..c1852a4f 100644 --- a/server/text_generation_server/models/transformers_flash_vlm.py +++ b/server/text_generation_server/models/transformers_flash_vlm.py @@ -769,6 +769,7 @@ class TransformersLlama4VlmCausalLM(TransformersFlashVlmCausalLM): block_tables = self.cuda_graphs[max_bs]["block_tables"][: bs * max_bt] else: block_tables = self.cuda_graphs[max_bs]["block_tables"][:bs] + block_tables_local = self.cuda_graphs[max_bs]["block_tables_local"][:bs] slots = self.cuda_graphs[max_bs]["slots"][:bs] input_lengths_tensor = self.cuda_graphs[max_bs]["input_lengths"][:bs] cache_lengths_tensor = self.cuda_graphs[max_bs]["cache_lengths"][:bs] From 2a10a28d08a0793e4bde36aa01a2a043e676dacf Mon Sep 17 00:00:00 2001 From: Mohit Sharma Date: Fri, 11 Apr 2025 15:24:12 +0000 Subject: [PATCH 5/5] force attn to flashdecoding --- launcher/src/main.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/launcher/src/main.rs b/launcher/src/main.rs index acff8573..2e22c100 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -158,7 +158,7 @@ fn resolve_attention(config: &Option, lora_adapters: &Option) -> prefix_caching = Some("0".to_string()); } match config.model_type.as_deref() { - Some("falcon") | Some("deepseek_v2") => { + Some("falcon") | Some("deepseek_v2") | Some("llama4") => { // Required because gemma2 needs bfloat16 which is not supported by // flashinfer ? if attention.is_none() {