mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-19 13:52:07 +00:00
Support flashinfer for Gemma3 prefill (#3167)
* launcher: ensure correct detection of Gemma 3 head size
* Support flashinfer for Gemma3 prefill

  Gemma3 uses bidirectional attention for images. Flashinfer supports custom masks. Hook up the mask with flashinfer, so that we do not have to use the slower SDPA implementation for prefills with images.
* Update Gemma3 test outputs
* Fixed unused import
This commit is contained in:
parent
4645678ff0
commit
84ab88d843
@ -8,61 +8,61 @@
|
||||
"tokens": [
|
||||
{
|
||||
"id": 1331,
|
||||
"logprob": -0.34960938,
|
||||
"logprob": -0.31835938,
|
||||
"special": false,
|
||||
"text": " people"
|
||||
},
|
||||
{
|
||||
"id": 8390,
|
||||
"logprob": -0.14746094,
|
||||
"logprob": -0.1484375,
|
||||
"special": false,
|
||||
"text": " died"
|
||||
},
|
||||
{
|
||||
"id": 528,
|
||||
"logprob": -1.2265625,
|
||||
"logprob": -1.1171875,
|
||||
"special": false,
|
||||
"text": " in"
|
||||
},
|
||||
{
|
||||
"id": 506,
|
||||
"logprob": -0.47070312,
|
||||
"logprob": -0.45898438,
|
||||
"special": false,
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 3640,
|
||||
"logprob": -0.5859375,
|
||||
"logprob": -0.55859375,
|
||||
"special": false,
|
||||
"text": " United"
|
||||
},
|
||||
{
|
||||
"id": 4184,
|
||||
"logprob": -0.0027770996,
|
||||
"logprob": -0.0026397705,
|
||||
"special": false,
|
||||
"text": " States"
|
||||
},
|
||||
{
|
||||
"id": 236761,
|
||||
"logprob": -0.34765625,
|
||||
"logprob": -0.38085938,
|
||||
"special": false,
|
||||
"text": "."
|
||||
},
|
||||
{
|
||||
"id": 108,
|
||||
"logprob": -0.0859375,
|
||||
"logprob": -0.07421875,
|
||||
"special": false,
|
||||
"text": "\n\n"
|
||||
},
|
||||
{
|
||||
"id": 818,
|
||||
"logprob": -1.1640625,
|
||||
"logprob": -1.0859375,
|
||||
"special": false,
|
||||
"text": "The"
|
||||
},
|
||||
{
|
||||
"id": 6816,
|
||||
"logprob": -1.890625,
|
||||
"logprob": -1.75,
|
||||
"special": false,
|
||||
"text": " generally"
|
||||
},
|
||||
@ -74,7 +74,7 @@
|
||||
},
|
||||
{
|
||||
"id": 10967,
|
||||
"logprob": -0.90625,
|
||||
"logprob": -0.9609375,
|
||||
"special": false,
|
||||
"text": " estimate"
|
||||
},
|
||||
@ -86,43 +86,43 @@
|
||||
},
|
||||
{
|
||||
"id": 600,
|
||||
"logprob": -0.65234375,
|
||||
"logprob": -0.703125,
|
||||
"special": false,
|
||||
"text": " that"
|
||||
},
|
||||
{
|
||||
"id": 236743,
|
||||
"logprob": -1.2109375,
|
||||
"logprob": -1.171875,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 236825,
|
||||
"logprob": -0.00088119507,
|
||||
"logprob": -0.0009918213,
|
||||
"special": false,
|
||||
"text": "6"
|
||||
},
|
||||
{
|
||||
"id": 236832,
|
||||
"logprob": -6.580353e-05,
|
||||
"logprob": -6.389618e-05,
|
||||
"special": false,
|
||||
"text": "7"
|
||||
},
|
||||
{
|
||||
"id": 236810,
|
||||
"logprob": -5.2690506e-05,
|
||||
"logprob": -4.7445297e-05,
|
||||
"special": false,
|
||||
"text": "5"
|
||||
},
|
||||
{
|
||||
"id": 236764,
|
||||
"logprob": -0.0001745224,
|
||||
"logprob": -0.00017929077,
|
||||
"special": false,
|
||||
"text": ","
|
||||
},
|
||||
{
|
||||
"id": 236771,
|
||||
"logprob": -1.180172e-05,
|
||||
"logprob": -1.4901161e-05,
|
||||
"special": false,
|
||||
"text": "0"
|
||||
},
|
||||
@ -140,7 +140,7 @@
|
||||
},
|
||||
{
|
||||
"id": 1331,
|
||||
"logprob": -0.44921875,
|
||||
"logprob": -0.45898438,
|
||||
"special": false,
|
||||
"text": " people"
|
||||
},
|
||||
@ -158,49 +158,49 @@
|
||||
},
|
||||
{
|
||||
"id": 506,
|
||||
"logprob": -0.00034713745,
|
||||
"logprob": -0.00032615662,
|
||||
"special": false,
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 3640,
|
||||
"logprob": -0.028564453,
|
||||
"logprob": -0.029785156,
|
||||
"special": false,
|
||||
"text": " United"
|
||||
},
|
||||
{
|
||||
"id": 4184,
|
||||
"logprob": -0.00012207031,
|
||||
"logprob": -0.00012302399,
|
||||
"special": false,
|
||||
"text": " States"
|
||||
},
|
||||
{
|
||||
"id": 236761,
|
||||
"logprob": -1.15625,
|
||||
"logprob": -1.1796875,
|
||||
"special": false,
|
||||
"text": "."
|
||||
},
|
||||
{
|
||||
"id": 3153,
|
||||
"logprob": -0.103027344,
|
||||
"logprob": -0.09667969,
|
||||
"special": false,
|
||||
"text": " However"
|
||||
},
|
||||
{
|
||||
"id": 236764,
|
||||
"logprob": -0.009155273,
|
||||
"logprob": -0.009094238,
|
||||
"special": false,
|
||||
"text": ","
|
||||
},
|
||||
{
|
||||
"id": 1070,
|
||||
"logprob": -0.92578125,
|
||||
"logprob": -0.91015625,
|
||||
"special": false,
|
||||
"text": " some"
|
||||
},
|
||||
{
|
||||
"id": 61806,
|
||||
"logprob": -0.91796875,
|
||||
"logprob": -0.859375,
|
||||
"special": false,
|
||||
"text": " historians"
|
||||
},
|
||||
@ -218,79 +218,79 @@
|
||||
},
|
||||
{
|
||||
"id": 5396,
|
||||
"logprob": -0.8046875,
|
||||
"logprob": -0.765625,
|
||||
"special": false,
|
||||
"text": " actual"
|
||||
},
|
||||
{
|
||||
"id": 1548,
|
||||
"logprob": -0.04321289,
|
||||
"logprob": -0.048339844,
|
||||
"special": false,
|
||||
"text": " number"
|
||||
},
|
||||
{
|
||||
"id": 1451,
|
||||
"logprob": -0.66015625,
|
||||
"logprob": -0.65625,
|
||||
"special": false,
|
||||
"text": " could"
|
||||
},
|
||||
{
|
||||
"id": 577,
|
||||
"logprob": -0.091308594,
|
||||
"logprob": -0.09082031,
|
||||
"special": false,
|
||||
"text": " be"
|
||||
},
|
||||
{
|
||||
"id": 618,
|
||||
"logprob": -0.57421875,
|
||||
"logprob": -0.625,
|
||||
"special": false,
|
||||
"text": " as"
|
||||
},
|
||||
{
|
||||
"id": 1494,
|
||||
"logprob": -0.00036239624,
|
||||
"logprob": -0.00037193298,
|
||||
"special": false,
|
||||
"text": " high"
|
||||
},
|
||||
{
|
||||
"id": 618,
|
||||
"logprob": -0.0001335144,
|
||||
"logprob": -0.0001296997,
|
||||
"special": false,
|
||||
"text": " as"
|
||||
},
|
||||
{
|
||||
"id": 236743,
|
||||
"logprob": -0.0009689331,
|
||||
"logprob": -0.00093460083,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 236770,
|
||||
"logprob": -0.26367188,
|
||||
"logprob": -0.21289062,
|
||||
"special": false,
|
||||
"text": "1"
|
||||
},
|
||||
{
|
||||
"id": 236771,
|
||||
"logprob": -0.17773438,
|
||||
"logprob": -0.16796875,
|
||||
"special": false,
|
||||
"text": "0"
|
||||
},
|
||||
{
|
||||
"id": 3625,
|
||||
"logprob": -0.012084961,
|
||||
"logprob": -0.0126953125,
|
||||
"special": false,
|
||||
"text": " million"
|
||||
},
|
||||
{
|
||||
"id": 236761,
|
||||
"logprob": -0.21289062,
|
||||
"logprob": -0.22460938,
|
||||
"special": false,
|
||||
"text": "."
|
||||
},
|
||||
{
|
||||
"id": 108,
|
||||
"logprob": -0.37304688,
|
||||
"logprob": -0.3984375,
|
||||
"special": false,
|
||||
"text": "\n\n"
|
||||
},
|
||||
@ -302,13 +302,13 @@
|
||||
},
|
||||
{
|
||||
"id": 1006,
|
||||
"logprob": -1.3203125,
|
||||
"logprob": -1.359375,
|
||||
"special": false,
|
||||
"text": " am"
|
||||
},
|
||||
{
|
||||
"id": 3182,
|
||||
"logprob": -1.078125,
|
||||
"logprob": -1.0859375,
|
||||
"special": false,
|
||||
"text": " looking"
|
||||
},
|
||||
@ -320,85 +320,85 @@
|
||||
},
|
||||
{
|
||||
"id": 919,
|
||||
"logprob": -1.25,
|
||||
"logprob": -1.2578125,
|
||||
"special": false,
|
||||
"text": " more"
|
||||
},
|
||||
{
|
||||
"id": 1938,
|
||||
"logprob": -1.2421875,
|
||||
"logprob": -1.3046875,
|
||||
"special": false,
|
||||
"text": " information"
|
||||
},
|
||||
{
|
||||
"id": 580,
|
||||
"logprob": -0.7734375,
|
||||
"logprob": -0.7421875,
|
||||
"special": false,
|
||||
"text": " on"
|
||||
},
|
||||
{
|
||||
"id": 672,
|
||||
"logprob": -0.73046875,
|
||||
"logprob": -0.78125,
|
||||
"special": false,
|
||||
"text": " this"
|
||||
},
|
||||
{
|
||||
"id": 59725,
|
||||
"logprob": -0.75,
|
||||
"logprob": -0.7109375,
|
||||
"special": false,
|
||||
"text": " discrepancy"
|
||||
},
|
||||
{
|
||||
"id": 532,
|
||||
"logprob": -0.83984375,
|
||||
"logprob": -0.8046875,
|
||||
"special": false,
|
||||
"text": " and"
|
||||
},
|
||||
{
|
||||
"id": 506,
|
||||
"logprob": -0.7109375,
|
||||
"logprob": -0.71484375,
|
||||
"special": false,
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 5872,
|
||||
"logprob": -1.2734375,
|
||||
"logprob": -1.1640625,
|
||||
"special": false,
|
||||
"text": " factors"
|
||||
},
|
||||
{
|
||||
"id": 600,
|
||||
"logprob": -0.22851562,
|
||||
"logprob": -0.20410156,
|
||||
"special": false,
|
||||
"text": " that"
|
||||
},
|
||||
{
|
||||
"id": 19263,
|
||||
"logprob": -1.1640625,
|
||||
"logprob": -1.1484375,
|
||||
"special": false,
|
||||
"text": " contributed"
|
||||
},
|
||||
{
|
||||
"id": 531,
|
||||
"logprob": -0.0010757446,
|
||||
"logprob": -0.000957489,
|
||||
"special": false,
|
||||
"text": " to"
|
||||
},
|
||||
{
|
||||
"id": 506,
|
||||
"logprob": -0.18945312,
|
||||
"logprob": -0.19921875,
|
||||
"special": false,
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 5777,
|
||||
"logprob": -1.2734375,
|
||||
"logprob": -1.171875,
|
||||
"special": false,
|
||||
"text": " wide"
|
||||
},
|
||||
{
|
||||
"id": 2644,
|
||||
"logprob": -0.01940918,
|
||||
"logprob": -0.020141602,
|
||||
"special": false,
|
||||
"text": " range"
|
||||
},
|
||||
@ -410,31 +410,31 @@
|
||||
},
|
||||
{
|
||||
"id": 14287,
|
||||
"logprob": -0.032470703,
|
||||
"logprob": -0.03564453,
|
||||
"special": false,
|
||||
"text": " estimates"
|
||||
},
|
||||
{
|
||||
"id": 236761,
|
||||
"logprob": -0.010375977,
|
||||
"logprob": -0.010620117,
|
||||
"special": false,
|
||||
"text": "."
|
||||
},
|
||||
{
|
||||
"id": 108,
|
||||
"logprob": -0.06591797,
|
||||
"logprob": -0.060302734,
|
||||
"special": false,
|
||||
"text": "\n\n"
|
||||
},
|
||||
{
|
||||
"id": 8291,
|
||||
"logprob": -0.8046875,
|
||||
"logprob": -0.7421875,
|
||||
"special": false,
|
||||
"text": "Here"
|
||||
},
|
||||
{
|
||||
"id": 236789,
|
||||
"logprob": -0.23828125,
|
||||
"logprob": -0.24023438,
|
||||
"special": false,
|
||||
"text": "'"
|
||||
},
|
||||
@ -446,55 +446,55 @@
|
||||
},
|
||||
{
|
||||
"id": 496,
|
||||
"logprob": -0.17480469,
|
||||
"logprob": -0.16992188,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 25890,
|
||||
"logprob": -0.087402344,
|
||||
"logprob": -0.06933594,
|
||||
"special": false,
|
||||
"text": " breakdown"
|
||||
},
|
||||
{
|
||||
"id": 529,
|
||||
"logprob": -0.0021209717,
|
||||
"logprob": -0.002243042,
|
||||
"special": false,
|
||||
"text": " of"
|
||||
},
|
||||
{
|
||||
"id": 506,
|
||||
"logprob": -0.19140625,
|
||||
"logprob": -0.18554688,
|
||||
"special": false,
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 5872,
|
||||
"logprob": -1.0078125,
|
||||
"logprob": -0.9921875,
|
||||
"special": false,
|
||||
"text": " factors"
|
||||
},
|
||||
{
|
||||
"id": 20894,
|
||||
"logprob": -0.26367188,
|
||||
"logprob": -0.25976562,
|
||||
"special": false,
|
||||
"text": " contributing"
|
||||
},
|
||||
{
|
||||
"id": 531,
|
||||
"logprob": -9.250641e-05,
|
||||
"logprob": -8.440018e-05,
|
||||
"special": false,
|
||||
"text": " to"
|
||||
},
|
||||
{
|
||||
"id": 506,
|
||||
"logprob": -0.008666992,
|
||||
"logprob": -0.009765625,
|
||||
"special": false,
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 5777,
|
||||
"logprob": -0.6171875,
|
||||
"logprob": -0.67578125,
|
||||
"special": false,
|
||||
"text": " wide"
|
||||
},
|
||||
@ -506,31 +506,31 @@
|
||||
},
|
||||
{
|
||||
"id": 529,
|
||||
"logprob": -0.016723633,
|
||||
"logprob": -0.014831543,
|
||||
"special": false,
|
||||
"text": " of"
|
||||
},
|
||||
{
|
||||
"id": 14287,
|
||||
"logprob": -0.011352539,
|
||||
"logprob": -0.012329102,
|
||||
"special": false,
|
||||
"text": " estimates"
|
||||
},
|
||||
{
|
||||
"id": 573,
|
||||
"logprob": -0.30664062,
|
||||
"logprob": -0.3125,
|
||||
"special": false,
|
||||
"text": " for"
|
||||
},
|
||||
{
|
||||
"id": 506,
|
||||
"logprob": -0.21386719,
|
||||
"logprob": -0.21484375,
|
||||
"special": false,
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 236743,
|
||||
"logprob": -0.35351562,
|
||||
"logprob": -0.43359375,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
@ -560,43 +560,43 @@
|
||||
},
|
||||
{
|
||||
"id": 7745,
|
||||
"logprob": -0.70703125,
|
||||
"logprob": -0.703125,
|
||||
"special": false,
|
||||
"text": " flu"
|
||||
},
|
||||
{
|
||||
"id": 10248,
|
||||
"logprob": -0.015258789,
|
||||
"logprob": -0.013427734,
|
||||
"special": false,
|
||||
"text": " pandemic"
|
||||
},
|
||||
{
|
||||
"id": 4355,
|
||||
"logprob": -0.83203125,
|
||||
"logprob": -0.6953125,
|
||||
"special": false,
|
||||
"text": " death"
|
||||
},
|
||||
{
|
||||
"id": 25363,
|
||||
"logprob": -7.43866e-05,
|
||||
"logprob": -6.771088e-05,
|
||||
"special": false,
|
||||
"text": " toll"
|
||||
},
|
||||
{
|
||||
"id": 528,
|
||||
"logprob": -0.08496094,
|
||||
"logprob": -0.076171875,
|
||||
"special": false,
|
||||
"text": " in"
|
||||
},
|
||||
{
|
||||
"id": 506,
|
||||
"logprob": -6.67572e-06,
|
||||
"logprob": -7.2717667e-06,
|
||||
"special": false,
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 3640,
|
||||
"logprob": -0.0059509277,
|
||||
"logprob": -0.0052490234,
|
||||
"special": false,
|
||||
"text": " United"
|
||||
},
|
||||
|
@ -1,11 +1,11 @@
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "stop",
|
||||
"finish_reason": "length",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"message": {
|
||||
"content": "Okay, let's analyze the image. \n\nThe image is entirely white, with a very subtle, faint outline of a stylized, cartoonish figure. It appears to be a simplified depiction of a person, likely a child, with a wide-eyed expression and a small, rounded body. \n\nIt's almost like a minimalist, iconic representation. \n\nDo you want me to try and describe it in more detail or perhaps speculate about the context of the image?",
|
||||
"content": "Okay, let's analyze the image. \n\nThe image is entirely white, with a very subtle, faint outline of a stylized, cartoonish figure. It appears to be a simplified depiction of a person, likely a child, with a wide-eyed expression and a small, rounded body. \n\nIt's almost like a minimalist, iconic representation. \n\nDo you want me to try and describe it in more detail, or perhaps suggest what this image might represent (e.g",
|
||||
"name": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": null
|
||||
@ -13,14 +13,14 @@
|
||||
"usage": null
|
||||
}
|
||||
],
|
||||
"created": 1741965892,
|
||||
"created": 1744396706,
|
||||
"id": "",
|
||||
"model": "google/gemma-3-4b-it",
|
||||
"object": "chat.completion",
|
||||
"system_fingerprint": "3.2.1-dev0-native",
|
||||
"system_fingerprint": "3.2.3-dev0-native",
|
||||
"usage": {
|
||||
"completion_tokens": 98,
|
||||
"completion_tokens": 100,
|
||||
"prompt_tokens": 277,
|
||||
"total_tokens": 375
|
||||
"total_tokens": 377
|
||||
}
|
||||
}
|
||||
|
@ -5,7 +5,7 @@
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"message": {
|
||||
"content": "Okay, let's analyze the image. \n\nThe transparent image reveals a stylized depiction of **a human head**. It's a minimalist, geometric representation, showing the basic shapes of the skull, eye sockets, and head outline. \n\nDo you want me to describe any specific element of the image in more detail?",
|
||||
"content": "Okay, let's analyze the image. \n\nThe transparent image reveals a stylized depiction of **a human head**. It's a minimalist, geometric representation, showing the basic shapes of the skull, eye sockets, and head outline. \n\nIf you'd like, you can give me more details about the image or ask me to focus on a specific aspect of it.",
|
||||
"name": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": null
|
||||
@ -13,14 +13,14 @@
|
||||
"usage": null
|
||||
}
|
||||
],
|
||||
"created": 1741966313,
|
||||
"created": 1744396703,
|
||||
"id": "",
|
||||
"model": "google/gemma-3-4b-it",
|
||||
"object": "chat.completion",
|
||||
"system_fingerprint": "3.2.1-dev0-native",
|
||||
"system_fingerprint": "3.2.3-dev0-native",
|
||||
"usage": {
|
||||
"completion_tokens": 67,
|
||||
"completion_tokens": 78,
|
||||
"prompt_tokens": 277,
|
||||
"total_tokens": 344
|
||||
"total_tokens": 355
|
||||
}
|
||||
}
|
||||
|
@ -13,11 +13,11 @@
|
||||
"usage": null
|
||||
}
|
||||
],
|
||||
"created": 1741964480,
|
||||
"created": 1744396699,
|
||||
"id": "",
|
||||
"model": "google/gemma-3-4b-it",
|
||||
"object": "chat.completion",
|
||||
"system_fingerprint": "3.2.1-dev0-native",
|
||||
"system_fingerprint": "3.2.3-dev0-native",
|
||||
"usage": {
|
||||
"completion_tokens": 74,
|
||||
"prompt_tokens": 275,
|
||||
|
@ -13,11 +13,11 @@
|
||||
"usage": null
|
||||
}
|
||||
],
|
||||
"created": 1741964477,
|
||||
"created": 1744396697,
|
||||
"id": "",
|
||||
"model": "google/gemma-3-4b-it",
|
||||
"object": "chat.completion",
|
||||
"system_fingerprint": "3.2.1-dev0-native",
|
||||
"system_fingerprint": "3.2.3-dev0-native",
|
||||
"usage": {
|
||||
"completion_tokens": 75,
|
||||
"prompt_tokens": 279,
|
||||
|
@ -260,11 +260,22 @@ struct Config {
|
||||
|
||||
impl Config {
|
||||
fn get_head_dim(&self) -> Option<usize> {
|
||||
self.head_dim.or_else(|| {
|
||||
self.text_config
|
||||
.as_ref()
|
||||
.and_then(|text_config| text_config.head_dim)
|
||||
})
|
||||
if let Some(head_dim) = self.head_dim {
|
||||
return Some(head_dim);
|
||||
}
|
||||
|
||||
let text_config = self.text_config.as_ref()?;
|
||||
if let Some(head_size) = text_config.head_dim {
|
||||
return Some(head_size);
|
||||
}
|
||||
|
||||
match self.model_type.as_deref() {
|
||||
// We special-case gemma3 here, since we need flashinfer for
|
||||
// handling bidirectional masks. And flashinfer can only be
|
||||
// used when the head size is known.
|
||||
Some("gemma3") => Some(256),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn flop(&self) -> Option<u64> {
|
||||
|
@ -45,6 +45,7 @@ def use_prefill_with_paged_kv_state(
|
||||
state: flashinfer.BatchPrefillWithPagedKVCacheWrapper,
|
||||
block_tables: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
custom_mask: Optional[torch.Tensor],
|
||||
input_lengths: torch.Tensor,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
@ -88,6 +89,7 @@ def use_prefill_with_paged_kv_state(
|
||||
paged_kv_indptr=indptr,
|
||||
paged_kv_indices=block_tables,
|
||||
paged_kv_last_page_len=last_page_len,
|
||||
custom_mask=custom_mask,
|
||||
num_qo_heads=num_heads,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_dim=head_size,
|
||||
|
@ -45,6 +45,7 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
||||
from text_generation_server.layers.layernorm import (
|
||||
FastRMSNorm,
|
||||
)
|
||||
from text_generation_server.models.globals import ATTENTION
|
||||
from text_generation_server.utils.weights import UnquantizedWeight
|
||||
from transformers.activations import ACT2FN
|
||||
from text_generation_server.layers.attention import (
|
||||
@ -248,7 +249,7 @@ class FlashGemma3Attention(torch.nn.Module):
|
||||
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
if attention_mask is None:
|
||||
if attention_mask is None or ATTENTION == "flashinfer":
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query=query,
|
||||
@ -701,8 +702,16 @@ class Gemma3ForConditionalGeneration(nn.Module):
|
||||
)
|
||||
|
||||
def get_attention_mask(
|
||||
self, input_ids, max_s, cu_seqlen_prefill, dtype, image_token_mask
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
dtype: torch.dtype,
|
||||
bool_mask: bool = False,
|
||||
):
|
||||
image_token_mask = (input_ids == self.config.image_token_index).to(
|
||||
input_ids.device
|
||||
)
|
||||
|
||||
device = input_ids.device
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
|
||||
@ -748,9 +757,10 @@ class Gemma3ForConditionalGeneration(nn.Module):
|
||||
)
|
||||
full_attention_mask[:, :, :, :sequence_length] = combined_mask
|
||||
|
||||
final_attention_mask = torch.where(full_attention_mask, 0, min_dtype).to(device)
|
||||
|
||||
return final_attention_mask
|
||||
if bool_mask:
|
||||
return full_attention_mask
|
||||
else:
|
||||
return torch.where(full_attention_mask, 0, min_dtype).to(device)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -793,10 +803,8 @@ class Gemma3ForConditionalGeneration(nn.Module):
|
||||
)
|
||||
attention_mask = self.get_attention_mask(
|
||||
input_ids,
|
||||
max_s,
|
||||
cu_seqlen_prefill,
|
||||
inputs_embeds.dtype,
|
||||
image_token_mask,
|
||||
)
|
||||
# Use flash attention for text-only input
|
||||
# else:
|
||||
|
@ -2434,6 +2434,7 @@ class FlashCausalLM(Model):
|
||||
input_lengths_tensor: torch.Tensor,
|
||||
cache_lengths_tensor: torch.Tensor,
|
||||
state: Optional[Any] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
) -> ContextManager:
|
||||
if ATTENTION != "flashinfer":
|
||||
return nullcontext()
|
||||
@ -2450,6 +2451,7 @@ class FlashCausalLM(Model):
|
||||
),
|
||||
block_tables=block_tables,
|
||||
cu_seqlens=cu_seqlen_prefill,
|
||||
custom_mask=attention_mask,
|
||||
input_lengths=input_lengths_tensor + cache_lengths_tensor,
|
||||
num_heads=self.num_heads,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
|
@ -485,6 +485,14 @@ class VlmCausalLM(FlashCausalLM):
|
||||
)
|
||||
batch.position_ids = position_ids
|
||||
|
||||
if self.model.config.model_type == "gemma3" and cu_seqlen_prefill is not None:
|
||||
# Get the mask, needed for flashinfer.
|
||||
attention_mask = self.model.get_attention_mask(
|
||||
input_ids, cu_seqlen_prefill, self.dtype, bool_mask=True
|
||||
).reshape(-1)
|
||||
else:
|
||||
attention_mask = None
|
||||
|
||||
# Try to find an associated cuda graph
|
||||
bs = input_ids.shape[0]
|
||||
sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
|
||||
@ -508,6 +516,7 @@ class VlmCausalLM(FlashCausalLM):
|
||||
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||
input_lengths_tensor=input_lengths,
|
||||
cache_lengths_tensor=cache_lengths_tensor,
|
||||
attention_mask=attention_mask,
|
||||
):
|
||||
seqlen = Seqlen(
|
||||
input_lengths=input_lengths,
|
||||
|
Loading…
Reference in New Issue
Block a user