diff --git a/server/text_generation_server/layers/attention/flashinfer.py b/server/text_generation_server/layers/attention/flashinfer.py
index 9479b606..f78475d5 100644
--- a/server/text_generation_server/layers/attention/flashinfer.py
+++ b/server/text_generation_server/layers/attention/flashinfer.py
@@ -45,6 +45,7 @@ def use_prefill_with_paged_kv_state(
     state: flashinfer.BatchPrefillWithPagedKVCacheWrapper,
     block_tables: torch.Tensor,
     cu_seqlens: torch.Tensor,
+    custom_mask: Optional[torch.Tensor],
     input_lengths: torch.Tensor,
     num_heads: int,
     num_kv_heads: int,
@@ -88,6 +89,7 @@ def use_prefill_with_paged_kv_state(
             paged_kv_indptr=indptr,
             paged_kv_indices=block_tables,
             paged_kv_last_page_len=last_page_len,
+            custom_mask=custom_mask,
             num_qo_heads=num_heads,
             num_kv_heads=num_kv_heads,
             head_dim=head_size,
diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py
index 70fe9a3d..58afd643 100644
--- a/server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py
@@ -45,6 +45,7 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
+from text_generation_server.models.globals import ATTENTION
 from text_generation_server.utils.weights import UnquantizedWeight
 from transformers.activations import ACT2FN
 from text_generation_server.layers.attention import (
@@ -248,7 +249,7 @@ class FlashGemma3Attention(torch.nn.Module):
 
         # Prefill
         if cu_seqlen_prefill is not None:
-            if attention_mask is None:
+            if attention_mask is None or ATTENTION == "flashinfer":
                 # flash attention
                 attn_output = attention(
                     query=query,
@@ -701,8 +702,16 @@ class Gemma3ForConditionalGeneration(nn.Module):
         )
 
     def get_attention_mask(
-        self, input_ids, max_s, cu_seqlen_prefill, dtype, image_token_mask
+        self,
+        input_ids: torch.Tensor,
+        cu_seqlen_prefill: Optional[torch.Tensor],
+        dtype: torch.dtype,
+        bool_mask: bool = False,
     ):
+        image_token_mask = (input_ids == self.config.image_token_index).to(
+            input_ids.device
+        )
+
         device = input_ids.device
         min_dtype = torch.finfo(dtype).min
 
@@ -748,9 +757,10 @@ class Gemma3ForConditionalGeneration(nn.Module):
         )
         full_attention_mask[:, :, :, :sequence_length] = combined_mask
 
-        final_attention_mask = torch.where(full_attention_mask, 0, min_dtype).to(device)
-
-        return final_attention_mask
+        if bool_mask:
+            return full_attention_mask
+        else:
+            return torch.where(full_attention_mask, 0, min_dtype).to(device)
 
     def forward(
         self,
@@ -793,10 +803,8 @@ class Gemma3ForConditionalGeneration(nn.Module):
                 )
                 attention_mask = self.get_attention_mask(
                     input_ids,
-                    max_s,
                     cu_seqlen_prefill,
                     inputs_embeds.dtype,
-                    image_token_mask,
                 )
         # Use flash attention for text-only input
         # else:
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index c7c5a374..a28ef381 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -2434,6 +2434,7 @@ class FlashCausalLM(Model):
         input_lengths_tensor: torch.Tensor,
         cache_lengths_tensor: torch.Tensor,
         state: Optional[Any] = None,
+        attention_mask: Optional[torch.Tensor] = None,
     ) -> ContextManager:
         if ATTENTION != "flashinfer":
             return nullcontext()
@@ -2450,6 +2451,7 @@ class FlashCausalLM(Model):
                 ),
                 block_tables=block_tables,
                 cu_seqlens=cu_seqlen_prefill,
+                custom_mask=attention_mask,
                 input_lengths=input_lengths_tensor + cache_lengths_tensor,
                 num_heads=self.num_heads,
                 num_kv_heads=self.num_kv_heads,
diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py
index 5f8eb906..97e954ed 100644
--- a/server/text_generation_server/models/vlm_causal_lm.py
+++ b/server/text_generation_server/models/vlm_causal_lm.py
@@ -5,6 +5,7 @@ from io import BytesIO
 
 from opentelemetry import trace
 from typing import Iterable, Optional, Tuple, List, Type, Dict
+from torch.nn import attention
 from transformers import PreTrainedTokenizerBase
 from transformers.image_processing_utils import select_best_resolution
 from text_generation_server.pb import generate_pb2
@@ -485,6 +486,14 @@ class VlmCausalLM(FlashCausalLM):
         )
         batch.position_ids = position_ids
 
+        if self.model.config.model_type == "gemma3" and cu_seqlen_prefill is not None:
+            # Get the mask, needed for flashinfer.
+            attention_mask = self.model.get_attention_mask(
+                input_ids, cu_seqlen_prefill, self.dtype, bool_mask=True
+            ).reshape(-1)
+        else:
+            attention_mask = None
+
         # Try to find an associated cuda graph
         bs = input_ids.shape[0]
         sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
@@ -508,6 +517,7 @@ class VlmCausalLM(FlashCausalLM):
                 cu_seqlen_prefill=cu_seqlen_prefill,
                 input_lengths_tensor=input_lengths,
                 cache_lengths_tensor=cache_lengths_tensor,
+                attention_mask=attention_mask,
             ):
                 seqlen = Seqlen(
                     input_lengths=input_lengths,
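For context on the `custom_mask` that this patch threads from `get_attention_mask(..., bool_mask=True).reshape(-1)` down to flashinfer, here is a minimal sketch of what such a flattened boolean prefill mask looks like. It assumes a single prefill request and a simplified Gemma3-style rule (causal attention for text tokens, bidirectional attention among image tokens); the `build_custom_mask` helper is illustrative only and is not part of this PR or of flashinfer's API.

```python
# Illustration only: a flattened boolean prefill mask of the kind passed as
# `custom_mask` above. Flashinfer's prefill wrapper takes the per-request
# [qo_len, kv_len] boolean masks flattened and concatenated; with a single
# request and no cached KV, that is simply a [seq_len * seq_len] tensor.
import torch


def build_custom_mask(image_token_mask: torch.Tensor) -> torch.Tensor:
    """image_token_mask: bool tensor of shape [seq_len], True at image tokens.

    Returns a flattened [seq_len * seq_len] bool mask: causal for text tokens,
    with image tokens additionally allowed to attend to each other in both
    directions (a simplification of the Gemma3 masking logic)."""
    seq_len = image_token_mask.shape[0]
    # Standard causal (lower-triangular) mask.
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    # Image tokens attend to all other image tokens regardless of position.
    bidirectional = image_token_mask[:, None] & image_token_mask[None, :]
    mask = causal | bidirectional
    # Flatten, mirroring the `.reshape(-1)` applied in vlm_causal_lm.py.
    return mask.reshape(-1)


if __name__ == "__main__":
    # Token layout: 2 text tokens, 3 image tokens, 2 text tokens.
    image_tokens = torch.tensor([0, 0, 1, 1, 1, 0, 0], dtype=torch.bool)
    flat_mask = build_custom_mask(image_tokens)
    print(flat_mask.shape)  # torch.Size([49]) for a 7-token request
```

With the mask handed to flashinfer this way, `FlashGemma3Attention` can keep the regular `attention(...)` prefill path even when an image mask is present, which is why the condition becomes `attention_mask is None or ATTENTION == "flashinfer"`.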