Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-22 15:32:08 +00:00)
Support flashinfer for Gemma3 prefill
Gemma3 uses bidirectional attention for images. Flashinfer supports custom masks. Hook up the mask with flashinfer, so that we do not have to use the slower SDPA implementation for prefills with images.
This commit is contained in: parent a9b26b221a, commit 6652d6e6e0
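
In mask terms, the idea is: Gemma3 prefill attention is causal for text tokens but bidirectional between image tokens, which the plain flash-attention path cannot express without a custom mask. Below is a simplified, self-contained sketch of that mask shape; it is illustrative only, since the actual get_attention_mask in the diff works on packed batches (via cu_seqlen_prefill) and is more involved.

# Illustrative sketch only: a prefill mask that is causal for text tokens but
# bidirectional between image tokens, which is the property flashinfer's
# custom mask lets us keep on the fast path.
import torch


def toy_gemma3_mask(image_token_mask: torch.Tensor) -> torch.Tensor:
    # image_token_mask: bool [seq_len], True where the token is an image token.
    # Returns bool [seq_len, seq_len]; True means "query i may attend to key j".
    seq_len = image_token_mask.shape[0]
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    # Image tokens may attend to each other in both directions.
    bidirectional_images = image_token_mask[:, None] & image_token_mask[None, :]
    return causal | bidirectional_images


# Example: tokens 2-4 are image tokens, the rest are text.
print(toy_gemma3_mask(torch.tensor([False, False, True, True, True, False])).int())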
@@ -45,6 +45,7 @@ def use_prefill_with_paged_kv_state(
     state: flashinfer.BatchPrefillWithPagedKVCacheWrapper,
     block_tables: torch.Tensor,
     cu_seqlens: torch.Tensor,
+    custom_mask: Optional[torch.Tensor],
     input_lengths: torch.Tensor,
     num_heads: int,
     num_kv_heads: int,
@@ -88,6 +89,7 @@ def use_prefill_with_paged_kv_state(
         paged_kv_indptr=indptr,
         paged_kv_indices=block_tables,
         paged_kv_last_page_len=last_page_len,
+        custom_mask=custom_mask,
         num_qo_heads=num_heads,
         num_kv_heads=num_kv_heads,
         head_dim=head_size,
@@ -45,6 +45,7 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
+from text_generation_server.models.globals import ATTENTION
 from text_generation_server.utils.weights import UnquantizedWeight
 from transformers.activations import ACT2FN
 from text_generation_server.layers.attention import (
@@ -248,7 +249,7 @@ class FlashGemma3Attention(torch.nn.Module):

         # Prefill
         if cu_seqlen_prefill is not None:
-            if attention_mask is None:
+            if attention_mask is None or ATTENTION == "flashinfer":
                 # flash attention
                 attn_output = attention(
                     query=query,
@@ -701,8 +702,16 @@ class Gemma3ForConditionalGeneration(nn.Module):
         )

     def get_attention_mask(
-        self, input_ids, max_s, cu_seqlen_prefill, dtype, image_token_mask
+        self,
+        input_ids: torch.Tensor,
+        cu_seqlen_prefill: Optional[torch.Tensor],
+        dtype: torch.dtype,
+        bool_mask: bool = False,
     ):
+        image_token_mask = (input_ids == self.config.image_token_index).to(
+            input_ids.device
+        )
+
         device = input_ids.device
         min_dtype = torch.finfo(dtype).min

@@ -748,9 +757,10 @@ class Gemma3ForConditionalGeneration(nn.Module):
         )
         full_attention_mask[:, :, :, :sequence_length] = combined_mask

-        final_attention_mask = torch.where(full_attention_mask, 0, min_dtype).to(device)
-
-        return final_attention_mask
+        if bool_mask:
+            return full_attention_mask
+        else:
+            return torch.where(full_attention_mask, 0, min_dtype).to(device)

     def forward(
         self,
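The new bool_mask flag exists because the two consumers want different mask encodings: the SDPA fallback needs an additive float mask (0 where attention is allowed, the dtype minimum where it is blocked), while flashinfer takes the boolean mask directly as custom_mask. A toy sketch of the conversion the bool_mask=False branch performs:

# Toy example of the additive-mask encoding used by the SDPA path
# (bool_mask=False); the bool_mask=True path returns `allowed` unchanged.
import torch

allowed = torch.tensor([[True, False], [True, True]])  # True = may attend
dtype = torch.bfloat16
min_dtype = torch.finfo(dtype).min

additive = torch.where(allowed, 0, min_dtype).to(dtype)
print(additive)  # 0.0 where allowed, a very large negative value where blocked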
@@ -793,10 +803,8 @@ class Gemma3ForConditionalGeneration(nn.Module):
             )
             attention_mask = self.get_attention_mask(
                 input_ids,
-                max_s,
                 cu_seqlen_prefill,
                 inputs_embeds.dtype,
-                image_token_mask,
             )
             # Use flash attention for text-only input
             # else:
@@ -2434,6 +2434,7 @@ class FlashCausalLM(Model):
         input_lengths_tensor: torch.Tensor,
         cache_lengths_tensor: torch.Tensor,
         state: Optional[Any] = None,
+        attention_mask: Optional[torch.Tensor] = None,
     ) -> ContextManager:
         if ATTENTION != "flashinfer":
             return nullcontext()
@@ -2450,6 +2451,7 @@ class FlashCausalLM(Model):
             ),
             block_tables=block_tables,
             cu_seqlens=cu_seqlen_prefill,
+            custom_mask=attention_mask,
             input_lengths=input_lengths_tensor + cache_lengths_tensor,
             num_heads=self.num_heads,
             num_kv_heads=self.num_kv_heads,
@@ -5,6 +5,7 @@ from io import BytesIO
 from opentelemetry import trace
 from typing import Iterable, Optional, Tuple, List, Type, Dict

+from torch.nn import attention
 from transformers import PreTrainedTokenizerBase
 from transformers.image_processing_utils import select_best_resolution
 from text_generation_server.pb import generate_pb2
@@ -485,6 +486,14 @@ class VlmCausalLM(FlashCausalLM):
             )
         batch.position_ids = position_ids

+        if self.model.config.model_type == "gemma3" and cu_seqlen_prefill is not None:
+            # Get the mask, needed for flashinfer.
+            attention_mask = self.model.get_attention_mask(
+                input_ids, cu_seqlen_prefill, self.dtype, bool_mask=True
+            ).reshape(-1)
+        else:
+            attention_mask = None
+
         # Try to find an associated cuda graph
         bs = input_ids.shape[0]
         sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
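The .reshape(-1) matters because flashinfer's custom mask is passed as a single flattened boolean tensor rather than a [batch, heads, q, k] tensor. The sketch below assumes the commonly documented layout of one [qo_len, kv_len] mask per request, concatenated; verify against the flashinfer version in use.

# Toy illustration of a flattened per-request mask layout (assumed layout;
# check the flashinfer documentation for the exact contract).
import torch

per_request_lengths = [(3, 3), (2, 5)]  # (qo_len, kv_len) for two toy requests
chunks = [
    torch.ones(qo_len, kv_len, dtype=torch.bool).flatten()
    for qo_len, kv_len in per_request_lengths
]
custom_mask = torch.cat(chunks)
print(custom_mask.shape)  # torch.Size([19]) == 3*3 + 2*5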
@@ -508,6 +517,7 @@ class VlmCausalLM(FlashCausalLM):
             cu_seqlen_prefill=cu_seqlen_prefill,
             input_lengths_tensor=input_lengths,
             cache_lengths_tensor=cache_lengths_tensor,
+            attention_mask=attention_mask,
         ):
             seqlen = Seqlen(
                 input_lengths=input_lengths,