Support flashinfer for Gemma3 prefill

Gemma3 uses bidirectional attention for image tokens. Flashinfer supports
custom attention masks, so hook the mask up to flashinfer and avoid falling
back to the slower SDPA implementation for prefills that contain images.
Daniël de Kok 2025-04-11 15:58:57 +00:00
parent a9b26b221a
commit 6652d6e6e0
4 changed files with 29 additions and 7 deletions
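
For intuition, here is a rough standalone sketch of the mask in question
(hypothetical helper name and a simplified single-image case; the real code is
get_attention_mask(..., bool_mask=True).reshape(-1) in the diff below, which
also handles batching and separate image blocks):

    import torch

    def build_prefill_mask(image_token_mask: torch.Tensor) -> torch.Tensor:
        # image_token_mask: bool tensor of shape (seq_len,), True at image-token positions.
        seq_len = image_token_mask.shape[0]
        # Causal part: query i may attend to keys j <= i.
        causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
        # Bidirectional part: image tokens may also attend to each other "forwards".
        # Simplification: all image tokens are treated as one block here.
        bidirectional = image_token_mask[:, None] & image_token_mask[None, :]
        allowed = causal | bidirectional
        # Flashinfer takes the mask as a flattened 1-D boolean tensor.
        return allowed.reshape(-1)

    # Example: a 6-token prefill where positions 1-3 are image tokens.
    mask = build_prefill_mask(torch.tensor([False, True, True, True, False, False]))
    print(mask.reshape(6, 6))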

View File

@@ -45,6 +45,7 @@ def use_prefill_with_paged_kv_state(
     state: flashinfer.BatchPrefillWithPagedKVCacheWrapper,
     block_tables: torch.Tensor,
     cu_seqlens: torch.Tensor,
+    custom_mask: Optional[torch.Tensor],
     input_lengths: torch.Tensor,
     num_heads: int,
     num_kv_heads: int,
@@ -88,6 +89,7 @@ def use_prefill_with_paged_kv_state(
             paged_kv_indptr=indptr,
             paged_kv_indices=block_tables,
             paged_kv_last_page_len=last_page_len,
+            custom_mask=custom_mask,
             num_qo_heads=num_heads,
             num_kv_heads=num_kv_heads,
             head_dim=head_size,
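
The new custom_mask argument is handed to the flashinfer
BatchPrefillWithPagedKVCacheWrapper when the prefill state is set up. As far as
I understand flashinfer's custom-mask path, the mask is expected in a packed
layout: the per-request (qo_len x kv_len) boolean masks are flattened and
concatenated into one 1-D tensor, roughly like this hypothetical helper:

    import torch
    from typing import List

    def pack_custom_masks(per_request_masks: List[torch.Tensor]) -> torch.Tensor:
        # Each entry has shape (qo_len_i, kv_len_i); the packed result has
        # numel() == sum(qo_len_i * kv_len_i).
        return torch.cat([m.reshape(-1) for m in per_request_masks])

When no mask is supplied (custom_mask=None), the wrapper behaves as before, so
text-only prefills and other models are unaffected.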

View File

@@ -45,6 +45,7 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
+from text_generation_server.models.globals import ATTENTION
 from text_generation_server.utils.weights import UnquantizedWeight
 from transformers.activations import ACT2FN
 from text_generation_server.layers.attention import (
@@ -248,7 +249,7 @@ class FlashGemma3Attention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
-            if attention_mask is None:
+            if attention_mask is None or ATTENTION == "flashinfer":
                 # flash attention
                 attn_output = attention(
                     query=query,
@@ -701,8 +702,16 @@ class Gemma3ForConditionalGeneration(nn.Module):
         )

     def get_attention_mask(
-        self, input_ids, max_s, cu_seqlen_prefill, dtype, image_token_mask
+        self,
+        input_ids: torch.Tensor,
+        cu_seqlen_prefill: Optional[torch.Tensor],
+        dtype: torch.dtype,
+        bool_mask: bool = False,
     ):
+        image_token_mask = (input_ids == self.config.image_token_index).to(
+            input_ids.device
+        )
+
         device = input_ids.device
         min_dtype = torch.finfo(dtype).min
@@ -748,9 +757,10 @@
         )
         full_attention_mask[:, :, :, :sequence_length] = combined_mask

-        final_attention_mask = torch.where(full_attention_mask, 0, min_dtype).to(device)
-        return final_attention_mask
+        if bool_mask:
+            return full_attention_mask
+        else:
+            return torch.where(full_attention_mask, 0, min_dtype).to(device)

     def forward(
         self,
@@ -793,10 +803,8 @@
             )
             attention_mask = self.get_attention_mask(
                 input_ids,
-                max_s,
                 cu_seqlen_prefill,
                 inputs_embeds.dtype,
-                image_token_mask,
             )
             # Use flash attention for text-only input
             # else:
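
With the new bool_mask flag, get_attention_mask serves both back ends from the
same computation: flashinfer gets the boolean "is this position visible" mask,
while the SDPA fallback keeps receiving the additive float mask produced by the
torch.where above. In isolation, the conversion between the two looks like:

    import torch

    allowed = torch.tensor([[True, False], [True, True]])  # boolean mask (flashinfer)
    min_dtype = torch.finfo(torch.bfloat16).min
    additive = torch.where(allowed, 0, min_dtype).to(torch.bfloat16)  # float mask (SDPA)
    # 0 where attention is allowed, a very large negative value where it is
    # masked, so it can simply be added to the attention scores before softmax.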

View File

@@ -2434,6 +2434,7 @@ class FlashCausalLM(Model):
         input_lengths_tensor: torch.Tensor,
         cache_lengths_tensor: torch.Tensor,
         state: Optional[Any] = None,
+        attention_mask: Optional[torch.Tensor] = None,
     ) -> ContextManager:
         if ATTENTION != "flashinfer":
             return nullcontext()
@@ -2450,6 +2451,7 @@ class FlashCausalLM(Model):
             ),
             block_tables=block_tables,
             cu_seqlens=cu_seqlen_prefill,
+            custom_mask=attention_mask,
             input_lengths=input_lengths_tensor + cache_lengths_tensor,
             num_heads=self.num_heads,
             num_kv_heads=self.num_kv_heads,

View File

@@ -5,6 +5,7 @@ from io import BytesIO
 from opentelemetry import trace
 from typing import Iterable, Optional, Tuple, List, Type, Dict

+from torch.nn import attention
 from transformers import PreTrainedTokenizerBase
 from transformers.image_processing_utils import select_best_resolution
 from text_generation_server.pb import generate_pb2
@@ -485,6 +486,14 @@ class VlmCausalLM(FlashCausalLM):
             )
             batch.position_ids = position_ids

+        if self.model.config.model_type == "gemma3" and cu_seqlen_prefill is not None:
+            # Get the mask, needed for flashinfer.
+            attention_mask = self.model.get_attention_mask(
+                input_ids, cu_seqlen_prefill, self.dtype, bool_mask=True
+            ).reshape(-1)
+        else:
+            attention_mask = None
+
         # Try to find an associated cuda graph
         bs = input_ids.shape[0]
         sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
@@ -508,6 +517,7 @@
             cu_seqlen_prefill=cu_seqlen_prefill,
             input_lengths_tensor=input_lengths,
             cache_lengths_tensor=cache_lengths_tensor,
+            attention_mask=attention_mask,
         ):
             seqlen = Seqlen(
                 input_lengths=input_lengths,