From 6652d6e6e0b864f5f55f8f7e053462955396ef77 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?=
Date: Fri, 11 Apr 2025 15:58:57 +0000
Subject: [PATCH] Support flashinfer for Gemma3 prefill

Gemma3 uses bidirectional attention for image tokens. Flashinfer
supports custom attention masks. Hook the Gemma3 mask up to flashinfer
so that we do not have to fall back to the slower SDPA implementation
for prefills that contain images.
---
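For context, here is a minimal sketch of the kind of mask this enables,
assuming a single image span per sequence. The helper name and the
single-group simplification are illustrative only; the real
get_attention_mask in this patch handles multiple images and batched
sequences. flashinfer consumes the boolean mask flattened to 1D:

    import torch

    def bidirectional_image_mask(image_token_mask: torch.Tensor) -> torch.Tensor:
        # Start from the standard causal (lower-triangular) mask ...
        seq_len = image_token_mask.numel()
        causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
        # ... and additionally let image tokens attend to each other in both
        # directions (simplified here: all image tokens form a single group).
        bidirectional = image_token_mask[:, None] & image_token_mask[None, :]
        return causal | bidirectional

    # Tokens 1-3 are image tokens; flashinfer takes the mask flattened to 1D,
    # which is why the new call site does .reshape(-1) on the bool mask.
    mask = bidirectional_image_mask(torch.tensor([False, True, True, True, False]))
    custom_mask = mask.reshape(-1)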
 .../layers/attention/flashinfer.py            |  2 ++
 .../custom_modeling/flash_gemma3_modeling.py  | 22 +++++++++++++------
 .../models/flash_causal_lm.py                 |  2 ++
 .../models/vlm_causal_lm.py                   | 10 +++++++++
 4 files changed, 29 insertions(+), 7 deletions(-)

diff --git a/server/text_generation_server/layers/attention/flashinfer.py b/server/text_generation_server/layers/attention/flashinfer.py
index 9479b606..f78475d5 100644
--- a/server/text_generation_server/layers/attention/flashinfer.py
+++ b/server/text_generation_server/layers/attention/flashinfer.py
@@ -45,6 +45,7 @@ def use_prefill_with_paged_kv_state(
     state: flashinfer.BatchPrefillWithPagedKVCacheWrapper,
     block_tables: torch.Tensor,
     cu_seqlens: torch.Tensor,
+    custom_mask: Optional[torch.Tensor],
     input_lengths: torch.Tensor,
     num_heads: int,
     num_kv_heads: int,
@@ -88,6 +89,7 @@ def use_prefill_with_paged_kv_state(
         paged_kv_indptr=indptr,
         paged_kv_indices=block_tables,
         paged_kv_last_page_len=last_page_len,
+        custom_mask=custom_mask,
         num_qo_heads=num_heads,
         num_kv_heads=num_kv_heads,
         head_dim=head_size,
diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py
index 70fe9a3d..58afd643 100644
--- a/server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py
@@ -45,6 +45,7 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
+from text_generation_server.models.globals import ATTENTION
 from text_generation_server.utils.weights import UnquantizedWeight
 from transformers.activations import ACT2FN
 from text_generation_server.layers.attention import (
@@ -248,7 +249,7 @@ class FlashGemma3Attention(torch.nn.Module):
 
         # Prefill
         if cu_seqlen_prefill is not None:
-            if attention_mask is None:
+            if attention_mask is None or ATTENTION == "flashinfer":
                 # flash attention
                 attn_output = attention(
                     query=query,
@@ -701,8 +702,16 @@ class Gemma3ForConditionalGeneration(nn.Module):
         )
 
     def get_attention_mask(
-        self, input_ids, max_s, cu_seqlen_prefill, dtype, image_token_mask
+        self,
+        input_ids: torch.Tensor,
+        cu_seqlen_prefill: Optional[torch.Tensor],
+        dtype: torch.dtype,
+        bool_mask: bool = False,
     ):
+        image_token_mask = (input_ids == self.config.image_token_index).to(
+            input_ids.device
+        )
+
         device = input_ids.device
         min_dtype = torch.finfo(dtype).min
 
@@ -748,9 +757,10 @@ class Gemma3ForConditionalGeneration(nn.Module):
         )
 
         full_attention_mask[:, :, :, :sequence_length] = combined_mask
-        final_attention_mask = torch.where(full_attention_mask, 0, min_dtype).to(device)
-
-        return final_attention_mask
+        if bool_mask:
+            return full_attention_mask
+        else:
+            return torch.where(full_attention_mask, 0, min_dtype).to(device)
 
     def forward(
         self,
@@ -793,10 +803,8 @@ class Gemma3ForConditionalGeneration(nn.Module):
             )
             attention_mask = self.get_attention_mask(
                 input_ids,
-                max_s,
                 cu_seqlen_prefill,
                 inputs_embeds.dtype,
-                image_token_mask,
             )
         # Use flash attention for text-only input
         # else:
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index c7c5a374..a28ef381 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -2434,6 +2434,7 @@ class FlashCausalLM(Model):
         input_lengths_tensor: torch.Tensor,
         cache_lengths_tensor: torch.Tensor,
         state: Optional[Any] = None,
+        attention_mask: Optional[torch.Tensor] = None,
     ) -> ContextManager:
         if ATTENTION != "flashinfer":
             return nullcontext()
@@ -2450,6 +2451,7 @@ class FlashCausalLM(Model):
             ),
             block_tables=block_tables,
             cu_seqlens=cu_seqlen_prefill,
+            custom_mask=attention_mask,
             input_lengths=input_lengths_tensor + cache_lengths_tensor,
             num_heads=self.num_heads,
             num_kv_heads=self.num_kv_heads,
diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py
index 5f8eb906..97e954ed 100644
--- a/server/text_generation_server/models/vlm_causal_lm.py
+++ b/server/text_generation_server/models/vlm_causal_lm.py
@@ -5,6 +5,7 @@ from io import BytesIO
 
 from opentelemetry import trace
 from typing import Iterable, Optional, Tuple, List, Type, Dict
+from torch.nn import attention
 from transformers import PreTrainedTokenizerBase
 from transformers.image_processing_utils import select_best_resolution
 from text_generation_server.pb import generate_pb2
@@ -485,6 +486,14 @@ class VlmCausalLM(FlashCausalLM):
             )
             batch.position_ids = position_ids
 
+        if self.model.config.model_type == "gemma3" and cu_seqlen_prefill is not None:
+            # Get the mask, needed for flashinfer.
+            attention_mask = self.model.get_attention_mask(
+                input_ids, cu_seqlen_prefill, self.dtype, bool_mask=True
+            ).reshape(-1)
+        else:
+            attention_mask = None
+
         # Try to find an associated cuda graph
         bs = input_ids.shape[0]
         sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
@@ -508,6 +517,7 @@ class VlmCausalLM(FlashCausalLM):
             cu_seqlen_prefill=cu_seqlen_prefill,
             input_lengths_tensor=input_lengths,
             cache_lengths_tensor=cache_lengths_tensor,
+            attention_mask=attention_mask,
         ):
             seqlen = Seqlen(
                 input_lengths=input_lengths,