Support flashinfer for Gemma3 prefill (#3167)

* launcher: ensure correct detection of Gemma 3 head size

* Support flashinfer for Gemma3 prefill

Gemma3 uses bidirectional attention for images, and flashinfer supports
custom masks. Hook the mask up to flashinfer so that we no longer have to
fall back to the slower SDPA implementation for prefills that include
images.
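
As a rough sketch of the mask being wired up (illustrative only: the helper below is not the TGI implementation, and it simplifies to a single sequence with all image tokens treated as one block): attention stays causal everywhere, except that image tokens may also attend to each other bidirectionally. Flashinfer then consumes the flattened boolean mask as its custom mask.

```python
import torch

def gemma3_style_prefill_mask(input_ids: torch.Tensor, image_token_id: int) -> torch.Tensor:
    """Boolean (seq, seq) mask; True means query i may attend to key j."""
    seq_len = input_ids.shape[0]
    # Causal mask: token i attends to tokens 0..i.
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    # Image tokens additionally attend to each other in both directions.
    is_image = input_ids == image_token_id
    bidirectional = is_image.unsqueeze(1) & is_image.unsqueeze(0)
    return causal | bidirectional

# Tokens 1 and 2 are image tokens: position 1 can now attend to
# position 2, which a purely causal mask would forbid.
ids = torch.tensor([10, 99, 99, 11])
print(gemma3_style_prefill_mask(ids, image_token_id=99))
```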

* Update Gemma3 test outputs

* Fix unused import

Daniël de Kok, 2025-04-17 18:07:41 +02:00 (committed by GitHub)
parent 4645678ff0
commit 84ab88d843
10 changed files with 141 additions and 109 deletions

View File

@@ -8,61 +8,61 @@
"tokens": [
{
"id": 1331,
"logprob": -0.34960938,
"logprob": -0.31835938,
"special": false,
"text": " people"
},
{
"id": 8390,
"logprob": -0.14746094,
"logprob": -0.1484375,
"special": false,
"text": " died"
},
{
"id": 528,
"logprob": -1.2265625,
"logprob": -1.1171875,
"special": false,
"text": " in"
},
{
"id": 506,
"logprob": -0.47070312,
"logprob": -0.45898438,
"special": false,
"text": " the"
},
{
"id": 3640,
"logprob": -0.5859375,
"logprob": -0.55859375,
"special": false,
"text": " United"
},
{
"id": 4184,
"logprob": -0.0027770996,
"logprob": -0.0026397705,
"special": false,
"text": " States"
},
{
"id": 236761,
"logprob": -0.34765625,
"logprob": -0.38085938,
"special": false,
"text": "."
},
{
"id": 108,
"logprob": -0.0859375,
"logprob": -0.07421875,
"special": false,
"text": "\n\n"
},
{
"id": 818,
"logprob": -1.1640625,
"logprob": -1.0859375,
"special": false,
"text": "The"
},
{
"id": 6816,
"logprob": -1.890625,
"logprob": -1.75,
"special": false,
"text": " generally"
},
@@ -74,7 +74,7 @@
},
{
"id": 10967,
"logprob": -0.90625,
"logprob": -0.9609375,
"special": false,
"text": " estimate"
},
@@ -86,43 +86,43 @@
},
{
"id": 600,
"logprob": -0.65234375,
"logprob": -0.703125,
"special": false,
"text": " that"
},
{
"id": 236743,
"logprob": -1.2109375,
"logprob": -1.171875,
"special": false,
"text": " "
},
{
"id": 236825,
"logprob": -0.00088119507,
"logprob": -0.0009918213,
"special": false,
"text": "6"
},
{
"id": 236832,
"logprob": -6.580353e-05,
"logprob": -6.389618e-05,
"special": false,
"text": "7"
},
{
"id": 236810,
"logprob": -5.2690506e-05,
"logprob": -4.7445297e-05,
"special": false,
"text": "5"
},
{
"id": 236764,
"logprob": -0.0001745224,
"logprob": -0.00017929077,
"special": false,
"text": ","
},
{
"id": 236771,
"logprob": -1.180172e-05,
"logprob": -1.4901161e-05,
"special": false,
"text": "0"
},
@@ -140,7 +140,7 @@
},
{
"id": 1331,
"logprob": -0.44921875,
"logprob": -0.45898438,
"special": false,
"text": " people"
},
@@ -158,49 +158,49 @@
},
{
"id": 506,
"logprob": -0.00034713745,
"logprob": -0.00032615662,
"special": false,
"text": " the"
},
{
"id": 3640,
"logprob": -0.028564453,
"logprob": -0.029785156,
"special": false,
"text": " United"
},
{
"id": 4184,
"logprob": -0.00012207031,
"logprob": -0.00012302399,
"special": false,
"text": " States"
},
{
"id": 236761,
"logprob": -1.15625,
"logprob": -1.1796875,
"special": false,
"text": "."
},
{
"id": 3153,
"logprob": -0.103027344,
"logprob": -0.09667969,
"special": false,
"text": " However"
},
{
"id": 236764,
"logprob": -0.009155273,
"logprob": -0.009094238,
"special": false,
"text": ","
},
{
"id": 1070,
"logprob": -0.92578125,
"logprob": -0.91015625,
"special": false,
"text": " some"
},
{
"id": 61806,
"logprob": -0.91796875,
"logprob": -0.859375,
"special": false,
"text": " historians"
},
@@ -218,79 +218,79 @@
},
{
"id": 5396,
"logprob": -0.8046875,
"logprob": -0.765625,
"special": false,
"text": " actual"
},
{
"id": 1548,
"logprob": -0.04321289,
"logprob": -0.048339844,
"special": false,
"text": " number"
},
{
"id": 1451,
"logprob": -0.66015625,
"logprob": -0.65625,
"special": false,
"text": " could"
},
{
"id": 577,
"logprob": -0.091308594,
"logprob": -0.09082031,
"special": false,
"text": " be"
},
{
"id": 618,
"logprob": -0.57421875,
"logprob": -0.625,
"special": false,
"text": " as"
},
{
"id": 1494,
"logprob": -0.00036239624,
"logprob": -0.00037193298,
"special": false,
"text": " high"
},
{
"id": 618,
"logprob": -0.0001335144,
"logprob": -0.0001296997,
"special": false,
"text": " as"
},
{
"id": 236743,
"logprob": -0.0009689331,
"logprob": -0.00093460083,
"special": false,
"text": " "
},
{
"id": 236770,
"logprob": -0.26367188,
"logprob": -0.21289062,
"special": false,
"text": "1"
},
{
"id": 236771,
"logprob": -0.17773438,
"logprob": -0.16796875,
"special": false,
"text": "0"
},
{
"id": 3625,
"logprob": -0.012084961,
"logprob": -0.0126953125,
"special": false,
"text": " million"
},
{
"id": 236761,
"logprob": -0.21289062,
"logprob": -0.22460938,
"special": false,
"text": "."
},
{
"id": 108,
"logprob": -0.37304688,
"logprob": -0.3984375,
"special": false,
"text": "\n\n"
},
@@ -302,13 +302,13 @@
},
{
"id": 1006,
"logprob": -1.3203125,
"logprob": -1.359375,
"special": false,
"text": " am"
},
{
"id": 3182,
"logprob": -1.078125,
"logprob": -1.0859375,
"special": false,
"text": " looking"
},
@@ -320,85 +320,85 @@
},
{
"id": 919,
"logprob": -1.25,
"logprob": -1.2578125,
"special": false,
"text": " more"
},
{
"id": 1938,
"logprob": -1.2421875,
"logprob": -1.3046875,
"special": false,
"text": " information"
},
{
"id": 580,
"logprob": -0.7734375,
"logprob": -0.7421875,
"special": false,
"text": " on"
},
{
"id": 672,
"logprob": -0.73046875,
"logprob": -0.78125,
"special": false,
"text": " this"
},
{
"id": 59725,
"logprob": -0.75,
"logprob": -0.7109375,
"special": false,
"text": " discrepancy"
},
{
"id": 532,
"logprob": -0.83984375,
"logprob": -0.8046875,
"special": false,
"text": " and"
},
{
"id": 506,
"logprob": -0.7109375,
"logprob": -0.71484375,
"special": false,
"text": " the"
},
{
"id": 5872,
"logprob": -1.2734375,
"logprob": -1.1640625,
"special": false,
"text": " factors"
},
{
"id": 600,
"logprob": -0.22851562,
"logprob": -0.20410156,
"special": false,
"text": " that"
},
{
"id": 19263,
"logprob": -1.1640625,
"logprob": -1.1484375,
"special": false,
"text": " contributed"
},
{
"id": 531,
"logprob": -0.0010757446,
"logprob": -0.000957489,
"special": false,
"text": " to"
},
{
"id": 506,
"logprob": -0.18945312,
"logprob": -0.19921875,
"special": false,
"text": " the"
},
{
"id": 5777,
"logprob": -1.2734375,
"logprob": -1.171875,
"special": false,
"text": " wide"
},
{
"id": 2644,
"logprob": -0.01940918,
"logprob": -0.020141602,
"special": false,
"text": " range"
},
@@ -410,31 +410,31 @@
},
{
"id": 14287,
"logprob": -0.032470703,
"logprob": -0.03564453,
"special": false,
"text": " estimates"
},
{
"id": 236761,
"logprob": -0.010375977,
"logprob": -0.010620117,
"special": false,
"text": "."
},
{
"id": 108,
"logprob": -0.06591797,
"logprob": -0.060302734,
"special": false,
"text": "\n\n"
},
{
"id": 8291,
"logprob": -0.8046875,
"logprob": -0.7421875,
"special": false,
"text": "Here"
},
{
"id": 236789,
"logprob": -0.23828125,
"logprob": -0.24023438,
"special": false,
"text": "'"
},
@@ -446,55 +446,55 @@
},
{
"id": 496,
"logprob": -0.17480469,
"logprob": -0.16992188,
"special": false,
"text": " a"
},
{
"id": 25890,
"logprob": -0.087402344,
"logprob": -0.06933594,
"special": false,
"text": " breakdown"
},
{
"id": 529,
"logprob": -0.0021209717,
"logprob": -0.002243042,
"special": false,
"text": " of"
},
{
"id": 506,
"logprob": -0.19140625,
"logprob": -0.18554688,
"special": false,
"text": " the"
},
{
"id": 5872,
"logprob": -1.0078125,
"logprob": -0.9921875,
"special": false,
"text": " factors"
},
{
"id": 20894,
"logprob": -0.26367188,
"logprob": -0.25976562,
"special": false,
"text": " contributing"
},
{
"id": 531,
"logprob": -9.250641e-05,
"logprob": -8.440018e-05,
"special": false,
"text": " to"
},
{
"id": 506,
"logprob": -0.008666992,
"logprob": -0.009765625,
"special": false,
"text": " the"
},
{
"id": 5777,
"logprob": -0.6171875,
"logprob": -0.67578125,
"special": false,
"text": " wide"
},
@@ -506,31 +506,31 @@
},
{
"id": 529,
"logprob": -0.016723633,
"logprob": -0.014831543,
"special": false,
"text": " of"
},
{
"id": 14287,
"logprob": -0.011352539,
"logprob": -0.012329102,
"special": false,
"text": " estimates"
},
{
"id": 573,
"logprob": -0.30664062,
"logprob": -0.3125,
"special": false,
"text": " for"
},
{
"id": 506,
"logprob": -0.21386719,
"logprob": -0.21484375,
"special": false,
"text": " the"
},
{
"id": 236743,
"logprob": -0.35351562,
"logprob": -0.43359375,
"special": false,
"text": " "
},
@@ -560,43 +560,43 @@
},
{
"id": 7745,
"logprob": -0.70703125,
"logprob": -0.703125,
"special": false,
"text": " flu"
},
{
"id": 10248,
"logprob": -0.015258789,
"logprob": -0.013427734,
"special": false,
"text": " pandemic"
},
{
"id": 4355,
"logprob": -0.83203125,
"logprob": -0.6953125,
"special": false,
"text": " death"
},
{
"id": 25363,
"logprob": -7.43866e-05,
"logprob": -6.771088e-05,
"special": false,
"text": " toll"
},
{
"id": 528,
"logprob": -0.08496094,
"logprob": -0.076171875,
"special": false,
"text": " in"
},
{
"id": 506,
"logprob": -6.67572e-06,
"logprob": -7.2717667e-06,
"special": false,
"text": " the"
},
{
"id": 3640,
"logprob": -0.0059509277,
"logprob": -0.0052490234,
"special": false,
"text": " United"
},

View File

@@ -1,11 +1,11 @@
{
"choices": [
{
"finish_reason": "stop",
"finish_reason": "length",
"index": 0,
"logprobs": null,
"message": {
"content": "Okay, let's analyze the image. \n\nThe image is entirely white, with a very subtle, faint outline of a stylized, cartoonish figure. It appears to be a simplified depiction of a person, likely a child, with a wide-eyed expression and a small, rounded body. \n\nIt's almost like a minimalist, iconic representation. \n\nDo you want me to try and describe it in more detail or perhaps speculate about the context of the image?",
"content": "Okay, let's analyze the image. \n\nThe image is entirely white, with a very subtle, faint outline of a stylized, cartoonish figure. It appears to be a simplified depiction of a person, likely a child, with a wide-eyed expression and a small, rounded body. \n\nIt's almost like a minimalist, iconic representation. \n\nDo you want me to try and describe it in more detail, or perhaps suggest what this image might represent (e.g",
"name": null,
"role": "assistant",
"tool_calls": null
@@ -13,14 +13,14 @@
"usage": null
}
],
"created": 1741965892,
"created": 1744396706,
"id": "",
"model": "google/gemma-3-4b-it",
"object": "chat.completion",
"system_fingerprint": "3.2.1-dev0-native",
"system_fingerprint": "3.2.3-dev0-native",
"usage": {
"completion_tokens": 98,
"completion_tokens": 100,
"prompt_tokens": 277,
"total_tokens": 375
"total_tokens": 377
}
}

View File

@@ -5,7 +5,7 @@
"index": 0,
"logprobs": null,
"message": {
"content": "Okay, let's analyze the image. \n\nThe transparent image reveals a stylized depiction of **a human head**. It's a minimalist, geometric representation, showing the basic shapes of the skull, eye sockets, and head outline. \n\nDo you want me to describe any specific element of the image in more detail?",
"content": "Okay, let's analyze the image. \n\nThe transparent image reveals a stylized depiction of **a human head**. It's a minimalist, geometric representation, showing the basic shapes of the skull, eye sockets, and head outline. \n\nIf you'd like, you can give me more details about the image or ask me to focus on a specific aspect of it.",
"name": null,
"role": "assistant",
"tool_calls": null
@@ -13,14 +13,14 @@
"usage": null
}
],
"created": 1741966313,
"created": 1744396703,
"id": "",
"model": "google/gemma-3-4b-it",
"object": "chat.completion",
"system_fingerprint": "3.2.1-dev0-native",
"system_fingerprint": "3.2.3-dev0-native",
"usage": {
"completion_tokens": 67,
"completion_tokens": 78,
"prompt_tokens": 277,
"total_tokens": 344
"total_tokens": 355
}
}

View File

@@ -13,11 +13,11 @@
"usage": null
}
],
"created": 1741964480,
"created": 1744396699,
"id": "",
"model": "google/gemma-3-4b-it",
"object": "chat.completion",
"system_fingerprint": "3.2.1-dev0-native",
"system_fingerprint": "3.2.3-dev0-native",
"usage": {
"completion_tokens": 74,
"prompt_tokens": 275,

View File

@@ -13,11 +13,11 @@
"usage": null
}
],
"created": 1741964477,
"created": 1744396697,
"id": "",
"model": "google/gemma-3-4b-it",
"object": "chat.completion",
"system_fingerprint": "3.2.1-dev0-native",
"system_fingerprint": "3.2.3-dev0-native",
"usage": {
"completion_tokens": 75,
"prompt_tokens": 279,

View File

@@ -260,11 +260,22 @@ struct Config {
 impl Config {
     fn get_head_dim(&self) -> Option<usize> {
-        self.head_dim.or_else(|| {
-            self.text_config
-                .as_ref()
-                .and_then(|text_config| text_config.head_dim)
-        })
+        if let Some(head_dim) = self.head_dim {
+            return Some(head_dim);
+        }
+
+        let text_config = self.text_config.as_ref()?;
+        if let Some(head_size) = text_config.head_dim {
+            return Some(head_size);
+        }
+
+        match self.model_type.as_deref() {
+            // We special-case gemma3 here, since we need flashinfer for
+            // handling bidirectional masks. And flashinfer can only be
+            // used when the head size is known.
+            Some("gemma3") => Some(256),
+            _ => None,
+        }
     }

     fn flop(&self) -> Option<u64> {

View File

@@ -45,6 +45,7 @@ def use_prefill_with_paged_kv_state(
     state: flashinfer.BatchPrefillWithPagedKVCacheWrapper,
     block_tables: torch.Tensor,
     cu_seqlens: torch.Tensor,
+    custom_mask: Optional[torch.Tensor],
     input_lengths: torch.Tensor,
     num_heads: int,
     num_kv_heads: int,
@@ -88,6 +89,7 @@
         paged_kv_indptr=indptr,
         paged_kv_indices=block_tables,
         paged_kv_last_page_len=last_page_len,
+        custom_mask=custom_mask,
         num_qo_heads=num_heads,
         num_kv_heads=num_kv_heads,
         head_dim=head_size,
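
For reference, flashinfer expects `custom_mask` as a single flattened boolean tensor: each request's (qo_len, kv_len) mask is flattened row-major and the results are concatenated across the batch. A minimal sketch of that packing (the helper name is ours, not part of this diff):

```python
import torch

def pack_custom_masks(masks: list[torch.Tensor]) -> torch.Tensor:
    # Each mask is boolean with shape (qo_len, kv_len); True = may attend.
    # The packed length is sum(qo_len_i * kv_len_i) over the batch.
    return torch.cat([m.reshape(-1) for m in masks])

# Two requests with query/key lengths 2 and 3 -> packed length 4 + 9 = 13.
m1 = torch.tril(torch.ones(2, 2, dtype=torch.bool))
m2 = torch.tril(torch.ones(3, 3, dtype=torch.bool))
print(pack_custom_masks([m1, m2]).shape)  # torch.Size([13])
```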

View File

@@ -45,6 +45,7 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
+from text_generation_server.models.globals import ATTENTION
 from text_generation_server.utils.weights import UnquantizedWeight
 from transformers.activations import ACT2FN
 from text_generation_server.layers.attention import (
@@ -248,7 +249,7 @@ class FlashGemma3Attention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
-            if attention_mask is None:
+            if attention_mask is None or ATTENTION == "flashinfer":
                 # flash attention
                 attn_output = attention(
                     query=query,
@@ -701,8 +702,16 @@ class Gemma3ForConditionalGeneration(nn.Module):
         )

     def get_attention_mask(
-        self, input_ids, max_s, cu_seqlen_prefill, dtype, image_token_mask
+        self,
+        input_ids: torch.Tensor,
+        cu_seqlen_prefill: Optional[torch.Tensor],
+        dtype: torch.dtype,
+        bool_mask: bool = False,
     ):
+        image_token_mask = (input_ids == self.config.image_token_index).to(
+            input_ids.device
+        )
+
         device = input_ids.device
         min_dtype = torch.finfo(dtype).min
@@ -748,9 +757,10 @@
         )
         full_attention_mask[:, :, :, :sequence_length] = combined_mask

-        final_attention_mask = torch.where(full_attention_mask, 0, min_dtype).to(device)
-
-        return final_attention_mask
+        if bool_mask:
+            return full_attention_mask
+        else:
+            return torch.where(full_attention_mask, 0, min_dtype).to(device)

     def forward(
         self,
@@ -793,10 +803,8 @@
             )

             attention_mask = self.get_attention_mask(
                 input_ids,
-                max_s,
                 cu_seqlen_prefill,
                 inputs_embeds.dtype,
-                image_token_mask,
             )
             # Use flash attention for text-only input
             # else:
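
The new `bool_mask` flag exists because the two attention paths want different encodings of the same information: flashinfer takes the boolean mask directly, while the SDPA fallback needs an additive float mask. A small sketch of the conversion that the final `torch.where` above performs (the helper name is ours):

```python
import torch

def additive_from_bool(bool_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # True -> 0.0 (position may be attended to); False -> the most
    # negative representable value, which drives the softmax weight to 0.
    return torch.where(bool_mask, 0.0, torch.finfo(dtype).min).to(dtype)

mask = torch.tensor([[True, False], [True, True]])
print(additive_from_bool(mask, torch.float16))
```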

View File

@@ -2434,6 +2434,7 @@ class FlashCausalLM(Model):
         input_lengths_tensor: torch.Tensor,
         cache_lengths_tensor: torch.Tensor,
         state: Optional[Any] = None,
+        attention_mask: Optional[torch.Tensor] = None,
     ) -> ContextManager:
         if ATTENTION != "flashinfer":
             return nullcontext()
@@ -2450,6 +2451,7 @@
             ),
             block_tables=block_tables,
             cu_seqlens=cu_seqlen_prefill,
+            custom_mask=attention_mask,
             input_lengths=input_lengths_tensor + cache_lengths_tensor,
             num_heads=self.num_heads,
             num_kv_heads=self.num_kv_heads,

View File

@@ -485,6 +485,14 @@ class VlmCausalLM(FlashCausalLM):
             )
             batch.position_ids = position_ids

+        if self.model.config.model_type == "gemma3" and cu_seqlen_prefill is not None:
+            # Get the mask, needed for flashinfer.
+            attention_mask = self.model.get_attention_mask(
+                input_ids, cu_seqlen_prefill, self.dtype, bool_mask=True
+            ).reshape(-1)
+        else:
+            attention_mask = None
+
         # Try to find an associated cuda graph
         bs = input_ids.shape[0]
         sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
@@ -508,6 +516,7 @@
             cu_seqlen_prefill=cu_seqlen_prefill,
             input_lengths_tensor=input_lengths,
             cache_lengths_tensor=cache_lengths_tensor,
+            attention_mask=attention_mask,
         ):
             seqlen = Seqlen(
                 input_lengths=input_lengths,