diff --git a/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3.json b/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3.json index 859544c89..be8b3882f 100644 --- a/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3.json +++ b/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3.json @@ -8,61 +8,61 @@ "tokens": [ { "id": 1331, - "logprob": -0.34960938, + "logprob": -0.31835938, "special": false, "text": " people" }, { "id": 8390, - "logprob": -0.14746094, + "logprob": -0.1484375, "special": false, "text": " died" }, { "id": 528, - "logprob": -1.2265625, + "logprob": -1.1171875, "special": false, "text": " in" }, { "id": 506, - "logprob": -0.47070312, + "logprob": -0.45898438, "special": false, "text": " the" }, { "id": 3640, - "logprob": -0.5859375, + "logprob": -0.55859375, "special": false, "text": " United" }, { "id": 4184, - "logprob": -0.0027770996, + "logprob": -0.0026397705, "special": false, "text": " States" }, { "id": 236761, - "logprob": -0.34765625, + "logprob": -0.38085938, "special": false, "text": "." }, { "id": 108, - "logprob": -0.0859375, + "logprob": -0.07421875, "special": false, "text": "\n\n" }, { "id": 818, - "logprob": -1.1640625, + "logprob": -1.0859375, "special": false, "text": "The" }, { "id": 6816, - "logprob": -1.890625, + "logprob": -1.75, "special": false, "text": " generally" }, @@ -74,7 +74,7 @@ }, { "id": 10967, - "logprob": -0.90625, + "logprob": -0.9609375, "special": false, "text": " estimate" }, @@ -86,43 +86,43 @@ }, { "id": 600, - "logprob": -0.65234375, + "logprob": -0.703125, "special": false, "text": " that" }, { "id": 236743, - "logprob": -1.2109375, + "logprob": -1.171875, "special": false, "text": " " }, { "id": 236825, - "logprob": -0.00088119507, + "logprob": -0.0009918213, "special": false, "text": "6" }, { "id": 236832, - "logprob": -6.580353e-05, + "logprob": -6.389618e-05, "special": false, "text": "7" }, { "id": 236810, - "logprob": -5.2690506e-05, + "logprob": -4.7445297e-05, "special": false, "text": "5" }, { "id": 236764, - "logprob": -0.0001745224, + "logprob": -0.00017929077, "special": false, "text": "," }, { "id": 236771, - "logprob": -1.180172e-05, + "logprob": -1.4901161e-05, "special": false, "text": "0" }, @@ -140,7 +140,7 @@ }, { "id": 1331, - "logprob": -0.44921875, + "logprob": -0.45898438, "special": false, "text": " people" }, @@ -158,49 +158,49 @@ }, { "id": 506, - "logprob": -0.00034713745, + "logprob": -0.00032615662, "special": false, "text": " the" }, { "id": 3640, - "logprob": -0.028564453, + "logprob": -0.029785156, "special": false, "text": " United" }, { "id": 4184, - "logprob": -0.00012207031, + "logprob": -0.00012302399, "special": false, "text": " States" }, { "id": 236761, - "logprob": -1.15625, + "logprob": -1.1796875, "special": false, "text": "." }, { "id": 3153, - "logprob": -0.103027344, + "logprob": -0.09667969, "special": false, "text": " However" }, { "id": 236764, - "logprob": -0.009155273, + "logprob": -0.009094238, "special": false, "text": "," }, { "id": 1070, - "logprob": -0.92578125, + "logprob": -0.91015625, "special": false, "text": " some" }, { "id": 61806, - "logprob": -0.91796875, + "logprob": -0.859375, "special": false, "text": " historians" }, @@ -218,79 +218,79 @@ }, { "id": 5396, - "logprob": -0.8046875, + "logprob": -0.765625, "special": false, "text": " actual" }, { "id": 1548, - "logprob": -0.04321289, + "logprob": -0.048339844, "special": false, "text": " number" }, { "id": 1451, - "logprob": -0.66015625, + "logprob": -0.65625, "special": false, "text": " could" }, { "id": 577, - "logprob": -0.091308594, + "logprob": -0.09082031, "special": false, "text": " be" }, { "id": 618, - "logprob": -0.57421875, + "logprob": -0.625, "special": false, "text": " as" }, { "id": 1494, - "logprob": -0.00036239624, + "logprob": -0.00037193298, "special": false, "text": " high" }, { "id": 618, - "logprob": -0.0001335144, + "logprob": -0.0001296997, "special": false, "text": " as" }, { "id": 236743, - "logprob": -0.0009689331, + "logprob": -0.00093460083, "special": false, "text": " " }, { "id": 236770, - "logprob": -0.26367188, + "logprob": -0.21289062, "special": false, "text": "1" }, { "id": 236771, - "logprob": -0.17773438, + "logprob": -0.16796875, "special": false, "text": "0" }, { "id": 3625, - "logprob": -0.012084961, + "logprob": -0.0126953125, "special": false, "text": " million" }, { "id": 236761, - "logprob": -0.21289062, + "logprob": -0.22460938, "special": false, "text": "." }, { "id": 108, - "logprob": -0.37304688, + "logprob": -0.3984375, "special": false, "text": "\n\n" }, @@ -302,13 +302,13 @@ }, { "id": 1006, - "logprob": -1.3203125, + "logprob": -1.359375, "special": false, "text": " am" }, { "id": 3182, - "logprob": -1.078125, + "logprob": -1.0859375, "special": false, "text": " looking" }, @@ -320,85 +320,85 @@ }, { "id": 919, - "logprob": -1.25, + "logprob": -1.2578125, "special": false, "text": " more" }, { "id": 1938, - "logprob": -1.2421875, + "logprob": -1.3046875, "special": false, "text": " information" }, { "id": 580, - "logprob": -0.7734375, + "logprob": -0.7421875, "special": false, "text": " on" }, { "id": 672, - "logprob": -0.73046875, + "logprob": -0.78125, "special": false, "text": " this" }, { "id": 59725, - "logprob": -0.75, + "logprob": -0.7109375, "special": false, "text": " discrepancy" }, { "id": 532, - "logprob": -0.83984375, + "logprob": -0.8046875, "special": false, "text": " and" }, { "id": 506, - "logprob": -0.7109375, + "logprob": -0.71484375, "special": false, "text": " the" }, { "id": 5872, - "logprob": -1.2734375, + "logprob": -1.1640625, "special": false, "text": " factors" }, { "id": 600, - "logprob": -0.22851562, + "logprob": -0.20410156, "special": false, "text": " that" }, { "id": 19263, - "logprob": -1.1640625, + "logprob": -1.1484375, "special": false, "text": " contributed" }, { "id": 531, - "logprob": -0.0010757446, + "logprob": -0.000957489, "special": false, "text": " to" }, { "id": 506, - "logprob": -0.18945312, + "logprob": -0.19921875, "special": false, "text": " the" }, { "id": 5777, - "logprob": -1.2734375, + "logprob": -1.171875, "special": false, "text": " wide" }, { "id": 2644, - "logprob": -0.01940918, + "logprob": -0.020141602, "special": false, "text": " range" }, @@ -410,31 +410,31 @@ }, { "id": 14287, - "logprob": -0.032470703, + "logprob": -0.03564453, "special": false, "text": " estimates" }, { "id": 236761, - "logprob": -0.010375977, + "logprob": -0.010620117, "special": false, "text": "." }, { "id": 108, - "logprob": -0.06591797, + "logprob": -0.060302734, "special": false, "text": "\n\n" }, { "id": 8291, - "logprob": -0.8046875, + "logprob": -0.7421875, "special": false, "text": "Here" }, { "id": 236789, - "logprob": -0.23828125, + "logprob": -0.24023438, "special": false, "text": "'" }, @@ -446,55 +446,55 @@ }, { "id": 496, - "logprob": -0.17480469, + "logprob": -0.16992188, "special": false, "text": " a" }, { "id": 25890, - "logprob": -0.087402344, + "logprob": -0.06933594, "special": false, "text": " breakdown" }, { "id": 529, - "logprob": -0.0021209717, + "logprob": -0.002243042, "special": false, "text": " of" }, { "id": 506, - "logprob": -0.19140625, + "logprob": -0.18554688, "special": false, "text": " the" }, { "id": 5872, - "logprob": -1.0078125, + "logprob": -0.9921875, "special": false, "text": " factors" }, { "id": 20894, - "logprob": -0.26367188, + "logprob": -0.25976562, "special": false, "text": " contributing" }, { "id": 531, - "logprob": -9.250641e-05, + "logprob": -8.440018e-05, "special": false, "text": " to" }, { "id": 506, - "logprob": -0.008666992, + "logprob": -0.009765625, "special": false, "text": " the" }, { "id": 5777, - "logprob": -0.6171875, + "logprob": -0.67578125, "special": false, "text": " wide" }, @@ -506,31 +506,31 @@ }, { "id": 529, - "logprob": -0.016723633, + "logprob": -0.014831543, "special": false, "text": " of" }, { "id": 14287, - "logprob": -0.011352539, + "logprob": -0.012329102, "special": false, "text": " estimates" }, { "id": 573, - "logprob": -0.30664062, + "logprob": -0.3125, "special": false, "text": " for" }, { "id": 506, - "logprob": -0.21386719, + "logprob": -0.21484375, "special": false, "text": " the" }, { "id": 236743, - "logprob": -0.35351562, + "logprob": -0.43359375, "special": false, "text": " " }, @@ -560,43 +560,43 @@ }, { "id": 7745, - "logprob": -0.70703125, + "logprob": -0.703125, "special": false, "text": " flu" }, { "id": 10248, - "logprob": -0.015258789, + "logprob": -0.013427734, "special": false, "text": " pandemic" }, { "id": 4355, - "logprob": -0.83203125, + "logprob": -0.6953125, "special": false, "text": " death" }, { "id": 25363, - "logprob": -7.43866e-05, + "logprob": -6.771088e-05, "special": false, "text": " toll" }, { "id": 528, - "logprob": -0.08496094, + "logprob": -0.076171875, "special": false, "text": " in" }, { "id": 506, - "logprob": -6.67572e-06, + "logprob": -7.2717667e-06, "special": false, "text": " the" }, { "id": 3640, - "logprob": -0.0059509277, + "logprob": -0.0052490234, "special": false, "text": " United" }, diff --git a/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_base64_rgb_png.json b/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_base64_rgb_png.json index afbfba30a..cd1a598e6 100644 --- a/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_base64_rgb_png.json +++ b/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_base64_rgb_png.json @@ -1,11 +1,11 @@ { "choices": [ { - "finish_reason": "stop", + "finish_reason": "length", "index": 0, "logprobs": null, "message": { - "content": "Okay, let's analyze the image. \n\nThe image is entirely white, with a very subtle, faint outline of a stylized, cartoonish figure. It appears to be a simplified depiction of a person, likely a child, with a wide-eyed expression and a small, rounded body. \n\nIt's almost like a minimalist, iconic representation. \n\nDo you want me to try and describe it in more detail or perhaps speculate about the context of the image?", + "content": "Okay, let's analyze the image. \n\nThe image is entirely white, with a very subtle, faint outline of a stylized, cartoonish figure. It appears to be a simplified depiction of a person, likely a child, with a wide-eyed expression and a small, rounded body. \n\nIt's almost like a minimalist, iconic representation. \n\nDo you want me to try and describe it in more detail, or perhaps suggest what this image might represent (e.g", "name": null, "role": "assistant", "tool_calls": null @@ -13,14 +13,14 @@ "usage": null } ], - "created": 1741965892, + "created": 1744396706, "id": "", "model": "google/gemma-3-4b-it", "object": "chat.completion", - "system_fingerprint": "3.2.1-dev0-native", + "system_fingerprint": "3.2.3-dev0-native", "usage": { - "completion_tokens": 98, + "completion_tokens": 100, "prompt_tokens": 277, - "total_tokens": 375 + "total_tokens": 377 } } diff --git a/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_base64_rgba.json b/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_base64_rgba.json index 1b97d2615..a1d3ae782 100644 --- a/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_base64_rgba.json +++ b/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_base64_rgba.json @@ -5,7 +5,7 @@ "index": 0, "logprobs": null, "message": { - "content": "Okay, let's analyze the image. \n\nThe transparent image reveals a stylized depiction of **a human head**. It's a minimalist, geometric representation, showing the basic shapes of the skull, eye sockets, and head outline. \n\nDo you want me to describe any specific element of the image in more detail?", + "content": "Okay, let's analyze the image. \n\nThe transparent image reveals a stylized depiction of **a human head**. It's a minimalist, geometric representation, showing the basic shapes of the skull, eye sockets, and head outline. \n\nIf you'd like, you can give me more details about the image or ask me to focus on a specific aspect of it.", "name": null, "role": "assistant", "tool_calls": null @@ -13,14 +13,14 @@ "usage": null } ], - "created": 1741966313, + "created": 1744396703, "id": "", "model": "google/gemma-3-4b-it", "object": "chat.completion", - "system_fingerprint": "3.2.1-dev0-native", + "system_fingerprint": "3.2.3-dev0-native", "usage": { - "completion_tokens": 67, + "completion_tokens": 78, "prompt_tokens": 277, - "total_tokens": 344 + "total_tokens": 355 } } diff --git a/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_cow.json b/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_cow.json index cd786b3ce..a839d7aac 100644 --- a/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_cow.json +++ b/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_cow.json @@ -13,11 +13,11 @@ "usage": null } ], - "created": 1741964480, + "created": 1744396699, "id": "", "model": "google/gemma-3-4b-it", "object": "chat.completion", - "system_fingerprint": "3.2.1-dev0-native", + "system_fingerprint": "3.2.3-dev0-native", "usage": { "completion_tokens": 74, "prompt_tokens": 275, diff --git a/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_cow_dog.json b/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_cow_dog.json index 5ed2c4507..c7215c930 100644 --- a/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_cow_dog.json +++ b/integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_cow_dog.json @@ -13,11 +13,11 @@ "usage": null } ], - "created": 1741964477, + "created": 1744396697, "id": "", "model": "google/gemma-3-4b-it", "object": "chat.completion", - "system_fingerprint": "3.2.1-dev0-native", + "system_fingerprint": "3.2.3-dev0-native", "usage": { "completion_tokens": 75, "prompt_tokens": 279, diff --git a/launcher/src/main.rs b/launcher/src/main.rs index c169a78ce..2fbb9c12a 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -260,11 +260,22 @@ struct Config { impl Config { fn get_head_dim(&self) -> Option { - self.head_dim.or_else(|| { - self.text_config - .as_ref() - .and_then(|text_config| text_config.head_dim) - }) + if let Some(head_dim) = self.head_dim { + return Some(head_dim); + } + + let text_config = self.text_config.as_ref()?; + if let Some(head_size) = text_config.head_dim { + return Some(head_size); + } + + match self.model_type.as_deref() { + // We special-case gemma3 here, since we need flashinfer for + // handling bidirectional masks. And flashinfer can only be + // used when the head size is known. + Some("gemma3") => Some(256), + _ => None, + } } fn flop(&self) -> Option { diff --git a/server/text_generation_server/layers/attention/flashinfer.py b/server/text_generation_server/layers/attention/flashinfer.py index 9479b6067..f78475d51 100644 --- a/server/text_generation_server/layers/attention/flashinfer.py +++ b/server/text_generation_server/layers/attention/flashinfer.py @@ -45,6 +45,7 @@ def use_prefill_with_paged_kv_state( state: flashinfer.BatchPrefillWithPagedKVCacheWrapper, block_tables: torch.Tensor, cu_seqlens: torch.Tensor, + custom_mask: Optional[torch.Tensor], input_lengths: torch.Tensor, num_heads: int, num_kv_heads: int, @@ -88,6 +89,7 @@ def use_prefill_with_paged_kv_state( paged_kv_indptr=indptr, paged_kv_indices=block_tables, paged_kv_last_page_len=last_page_len, + custom_mask=custom_mask, num_qo_heads=num_heads, num_kv_heads=num_kv_heads, head_dim=head_size, diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py index 70fe9a3db..58afd6430 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py @@ -45,6 +45,7 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.layernorm import ( FastRMSNorm, ) +from text_generation_server.models.globals import ATTENTION from text_generation_server.utils.weights import UnquantizedWeight from transformers.activations import ACT2FN from text_generation_server.layers.attention import ( @@ -248,7 +249,7 @@ class FlashGemma3Attention(torch.nn.Module): # Prefill if cu_seqlen_prefill is not None: - if attention_mask is None: + if attention_mask is None or ATTENTION == "flashinfer": # flash attention attn_output = attention( query=query, @@ -701,8 +702,16 @@ class Gemma3ForConditionalGeneration(nn.Module): ) def get_attention_mask( - self, input_ids, max_s, cu_seqlen_prefill, dtype, image_token_mask + self, + input_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + dtype: torch.dtype, + bool_mask: bool = False, ): + image_token_mask = (input_ids == self.config.image_token_index).to( + input_ids.device + ) + device = input_ids.device min_dtype = torch.finfo(dtype).min @@ -748,9 +757,10 @@ class Gemma3ForConditionalGeneration(nn.Module): ) full_attention_mask[:, :, :, :sequence_length] = combined_mask - final_attention_mask = torch.where(full_attention_mask, 0, min_dtype).to(device) - - return final_attention_mask + if bool_mask: + return full_attention_mask + else: + return torch.where(full_attention_mask, 0, min_dtype).to(device) def forward( self, @@ -793,10 +803,8 @@ class Gemma3ForConditionalGeneration(nn.Module): ) attention_mask = self.get_attention_mask( input_ids, - max_s, cu_seqlen_prefill, inputs_embeds.dtype, - image_token_mask, ) # Use flash attention for text-only input # else: diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index c7c5a374b..a28ef3810 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -2434,6 +2434,7 @@ class FlashCausalLM(Model): input_lengths_tensor: torch.Tensor, cache_lengths_tensor: torch.Tensor, state: Optional[Any] = None, + attention_mask: Optional[torch.Tensor] = None, ) -> ContextManager: if ATTENTION != "flashinfer": return nullcontext() @@ -2450,6 +2451,7 @@ class FlashCausalLM(Model): ), block_tables=block_tables, cu_seqlens=cu_seqlen_prefill, + custom_mask=attention_mask, input_lengths=input_lengths_tensor + cache_lengths_tensor, num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 5f8eb9060..2b1e01dfa 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -485,6 +485,14 @@ class VlmCausalLM(FlashCausalLM): ) batch.position_ids = position_ids + if self.model.config.model_type == "gemma3" and cu_seqlen_prefill is not None: + # Get the mask, needed for flashinfer. + attention_mask = self.model.get_attention_mask( + input_ids, cu_seqlen_prefill, self.dtype, bool_mask=True + ).reshape(-1) + else: + attention_mask = None + # Try to find an associated cuda graph bs = input_ids.shape[0] sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs]) @@ -508,6 +516,7 @@ class VlmCausalLM(FlashCausalLM): cu_seqlen_prefill=cu_seqlen_prefill, input_lengths_tensor=input_lengths, cache_lengths_tensor=cache_lengths_tensor, + attention_mask=attention_mask, ): seqlen = Seqlen( input_lengths=input_lengths,