mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-19 22:02:06 +00:00
Add gemma3 model (#3099)
This commit is contained in:
parent
f74c36fe0d
commit
ed46c2c414
@ -14,6 +14,8 @@ Text Generation Inference enables serving optimized models. The following sectio
|
|||||||
- [Gemma](https://huggingface.co/google/gemma-7b)
|
- [Gemma](https://huggingface.co/google/gemma-7b)
|
||||||
- [PaliGemma](https://huggingface.co/google/paligemma-3b-pt-224)
|
- [PaliGemma](https://huggingface.co/google/paligemma-3b-pt-224)
|
||||||
- [Gemma2](https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315)
|
- [Gemma2](https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315)
|
||||||
|
- [Gemma3](https://huggingface.co/collections/google/gemma-3)
|
||||||
|
- [Gemma3 Text](https://huggingface.co/collections/google/gemma-3)
|
||||||
- [Cohere](https://huggingface.co/CohereForAI/c4ai-command-r-plus)
|
- [Cohere](https://huggingface.co/CohereForAI/c4ai-command-r-plus)
|
||||||
- [Dbrx](https://huggingface.co/databricks/dbrx-instruct)
|
- [Dbrx](https://huggingface.co/databricks/dbrx-instruct)
|
||||||
- [Mamba](https://huggingface.co/state-spaces/mamba-2.8b-slimpj)
|
- [Mamba](https://huggingface.co/state-spaces/mamba-2.8b-slimpj)
|
||||||
|
@ -0,0 +1,133 @@
|
|||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 20,
|
||||||
|
"prefill": [],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 236764,
|
||||||
|
"logprob": -0.44726562,
|
||||||
|
"special": false,
|
||||||
|
"text": ","
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 236743,
|
||||||
|
"logprob": -0.011413574,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 236812,
|
||||||
|
"logprob": -0.09814453,
|
||||||
|
"special": false,
|
||||||
|
"text": "4"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 236764,
|
||||||
|
"logprob": -0.044189453,
|
||||||
|
"special": false,
|
||||||
|
"text": ","
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 236743,
|
||||||
|
"logprob": -0.15625,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 236810,
|
||||||
|
"logprob": -0.010864258,
|
||||||
|
"special": false,
|
||||||
|
"text": "5"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 236764,
|
||||||
|
"logprob": -0.040039062,
|
||||||
|
"special": false,
|
||||||
|
"text": ","
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 236743,
|
||||||
|
"logprob": -0.26757812,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 236825,
|
||||||
|
"logprob": -0.0047302246,
|
||||||
|
"special": false,
|
||||||
|
"text": "6"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 236764,
|
||||||
|
"logprob": -0.026123047,
|
||||||
|
"special": false,
|
||||||
|
"text": ","
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 236743,
|
||||||
|
"logprob": -0.265625,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 236832,
|
||||||
|
"logprob": -0.014160156,
|
||||||
|
"special": false,
|
||||||
|
"text": "7"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 236764,
|
||||||
|
"logprob": -0.013977051,
|
||||||
|
"special": false,
|
||||||
|
"text": ","
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 236743,
|
||||||
|
"logprob": -0.103515625,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 236828,
|
||||||
|
"logprob": -0.008178711,
|
||||||
|
"special": false,
|
||||||
|
"text": "8"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 236764,
|
||||||
|
"logprob": -0.030151367,
|
||||||
|
"special": false,
|
||||||
|
"text": ","
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 236743,
|
||||||
|
"logprob": -0.39453125,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 236819,
|
||||||
|
"logprob": -0.008728027,
|
||||||
|
"special": false,
|
||||||
|
"text": "9"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 236764,
|
||||||
|
"logprob": -0.020629883,
|
||||||
|
"special": false,
|
||||||
|
"text": ","
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 236743,
|
||||||
|
"logprob": -0.08154297,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": ", 4, 5, 6, 7, 8, 9, "
|
||||||
|
}
|
@ -0,0 +1,613 @@
|
|||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 100,
|
||||||
|
"prefill": [],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 1331,
|
||||||
|
"logprob": -0.32421875,
|
||||||
|
"special": false,
|
||||||
|
"text": " people"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 8390,
|
||||||
|
"logprob": -0.15332031,
|
||||||
|
"special": false,
|
||||||
|
"text": " died"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 528,
|
||||||
|
"logprob": -1.140625,
|
||||||
|
"special": false,
|
||||||
|
"text": " in"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 506,
|
||||||
|
"logprob": -0.42578125,
|
||||||
|
"special": false,
|
||||||
|
"text": " the"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3640,
|
||||||
|
"logprob": -0.64453125,
|
||||||
|
"special": false,
|
||||||
|
"text": " United"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 4184,
|
||||||
|
"logprob": -0.0027770996,
|
||||||
|
"special": false,
|
||||||
|
"text": " States"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 236761,
|
||||||
|
"logprob": -0.37890625,
|
||||||
|
"special": false,
|
||||||
|
"text": "."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 108,
|
||||||
|
"logprob": -0.08300781,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 818,
|
||||||
|
"logprob": -1.1796875,
|
||||||
|
"special": false,
|
||||||
|
"text": "The"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 6816,
|
||||||
|
"logprob": -1.765625,
|
||||||
|
"special": false,
|
||||||
|
"text": " generally"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 10951,
|
||||||
|
"logprob": -0.14550781,
|
||||||
|
"special": false,
|
||||||
|
"text": " accepted"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 10967,
|
||||||
|
"logprob": -0.90625,
|
||||||
|
"special": false,
|
||||||
|
"text": " estimate"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 563,
|
||||||
|
"logprob": -0.49414062,
|
||||||
|
"special": false,
|
||||||
|
"text": " is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 600,
|
||||||
|
"logprob": -0.65625,
|
||||||
|
"special": false,
|
||||||
|
"text": " that"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 236743,
|
||||||
|
"logprob": -1.1796875,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 236825,
|
||||||
|
"logprob": -0.0009918213,
|
||||||
|
"special": false,
|
||||||
|
"text": "6"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 236832,
|
||||||
|
"logprob": -6.532669e-05,
|
||||||
|
"special": false,
|
||||||
|
"text": "7"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 236810,
|
||||||
|
"logprob": -4.863739e-05,
|
||||||
|
"special": false,
|
||||||
|
"text": "5"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 236764,
|
||||||
|
"logprob": -0.00017929077,
|
||||||
|
"special": false,
|
||||||
|
"text": ","
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 236771,
|
||||||
|
"logprob": -1.2397766e-05,
|
||||||
|
"special": false,
|
||||||
|
"text": "0"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 236771,
|
||||||
|
"logprob": -2.1457672e-06,
|
||||||
|
"special": false,
|
||||||
|
"text": "0"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 236771,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": "0"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1331,
|
||||||
|
"logprob": -0.50390625,
|
||||||
|
"special": false,
|
||||||
|
"text": " people"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 8390,
|
||||||
|
"logprob": -0.011474609,
|
||||||
|
"special": false,
|
||||||
|
"text": " died"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 528,
|
||||||
|
"logprob": -0.08496094,
|
||||||
|
"special": false,
|
||||||
|
"text": " in"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 506,
|
||||||
|
"logprob": -0.0003299713,
|
||||||
|
"special": false,
|
||||||
|
"text": " the"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3640,
|
||||||
|
"logprob": -0.028442383,
|
||||||
|
"special": false,
|
||||||
|
"text": " United"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 4184,
|
||||||
|
"logprob": -0.00011014938,
|
||||||
|
"special": false,
|
||||||
|
"text": " States"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 236761,
|
||||||
|
"logprob": -1.1796875,
|
||||||
|
"special": false,
|
||||||
|
"text": "."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3153,
|
||||||
|
"logprob": -0.104003906,
|
||||||
|
"special": false,
|
||||||
|
"text": " However"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 236764,
|
||||||
|
"logprob": -0.009094238,
|
||||||
|
"special": false,
|
||||||
|
"text": ","
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1070,
|
||||||
|
"logprob": -0.88671875,
|
||||||
|
"special": false,
|
||||||
|
"text": " some"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 61806,
|
||||||
|
"logprob": -0.84765625,
|
||||||
|
"special": false,
|
||||||
|
"text": " historians"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 4646,
|
||||||
|
"logprob": -1.34375,
|
||||||
|
"special": false,
|
||||||
|
"text": " believe"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 506,
|
||||||
|
"logprob": -0.59375,
|
||||||
|
"special": false,
|
||||||
|
"text": " the"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 5396,
|
||||||
|
"logprob": -0.8046875,
|
||||||
|
"special": false,
|
||||||
|
"text": " actual"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1548,
|
||||||
|
"logprob": -0.04321289,
|
||||||
|
"special": false,
|
||||||
|
"text": " number"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1451,
|
||||||
|
"logprob": -0.60546875,
|
||||||
|
"special": false,
|
||||||
|
"text": " could"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 577,
|
||||||
|
"logprob": -0.091308594,
|
||||||
|
"special": false,
|
||||||
|
"text": " be"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 618,
|
||||||
|
"logprob": -0.61328125,
|
||||||
|
"special": false,
|
||||||
|
"text": " as"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1494,
|
||||||
|
"logprob": -0.00033569336,
|
||||||
|
"special": false,
|
||||||
|
"text": " high"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 618,
|
||||||
|
"logprob": -0.0001411438,
|
||||||
|
"special": false,
|
||||||
|
"text": " as"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 236743,
|
||||||
|
"logprob": -0.001045227,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 236770,
|
||||||
|
"logprob": -0.21289062,
|
||||||
|
"special": false,
|
||||||
|
"text": "1"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 236771,
|
||||||
|
"logprob": -0.13378906,
|
||||||
|
"special": false,
|
||||||
|
"text": "0"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3625,
|
||||||
|
"logprob": -0.0087890625,
|
||||||
|
"special": false,
|
||||||
|
"text": " million"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 236761,
|
||||||
|
"logprob": -0.2109375,
|
||||||
|
"special": false,
|
||||||
|
"text": "."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 108,
|
||||||
|
"logprob": -0.39453125,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 236777,
|
||||||
|
"logprob": -1.1328125,
|
||||||
|
"special": false,
|
||||||
|
"text": "I"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1006,
|
||||||
|
"logprob": -1.4140625,
|
||||||
|
"special": false,
|
||||||
|
"text": " am"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3182,
|
||||||
|
"logprob": -1.15625,
|
||||||
|
"special": false,
|
||||||
|
"text": " looking"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 573,
|
||||||
|
"logprob": -0.035888672,
|
||||||
|
"special": false,
|
||||||
|
"text": " for"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 919,
|
||||||
|
"logprob": -1.2734375,
|
||||||
|
"special": false,
|
||||||
|
"text": " more"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1938,
|
||||||
|
"logprob": -1.2265625,
|
||||||
|
"special": false,
|
||||||
|
"text": " information"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 580,
|
||||||
|
"logprob": -0.7734375,
|
||||||
|
"special": false,
|
||||||
|
"text": " on"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 672,
|
||||||
|
"logprob": -0.77734375,
|
||||||
|
"special": false,
|
||||||
|
"text": " this"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 59725,
|
||||||
|
"logprob": -0.70703125,
|
||||||
|
"special": false,
|
||||||
|
"text": " discrepancy"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 532,
|
||||||
|
"logprob": -0.8515625,
|
||||||
|
"special": false,
|
||||||
|
"text": " and"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 506,
|
||||||
|
"logprob": -0.65625,
|
||||||
|
"special": false,
|
||||||
|
"text": " the"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 5872,
|
||||||
|
"logprob": -1.15625,
|
||||||
|
"special": false,
|
||||||
|
"text": " factors"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 600,
|
||||||
|
"logprob": -0.2265625,
|
||||||
|
"special": false,
|
||||||
|
"text": " that"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 19263,
|
||||||
|
"logprob": -1.125,
|
||||||
|
"special": false,
|
||||||
|
"text": " contributed"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 531,
|
||||||
|
"logprob": -0.001083374,
|
||||||
|
"special": false,
|
||||||
|
"text": " to"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 506,
|
||||||
|
"logprob": -0.2109375,
|
||||||
|
"special": false,
|
||||||
|
"text": " the"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 5777,
|
||||||
|
"logprob": -1.21875,
|
||||||
|
"special": false,
|
||||||
|
"text": " wide"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2644,
|
||||||
|
"logprob": -0.018310547,
|
||||||
|
"special": false,
|
||||||
|
"text": " range"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 529,
|
||||||
|
"logprob": -0.12988281,
|
||||||
|
"special": false,
|
||||||
|
"text": " of"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 14287,
|
||||||
|
"logprob": -0.03564453,
|
||||||
|
"special": false,
|
||||||
|
"text": " estimates"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 236761,
|
||||||
|
"logprob": -0.010314941,
|
||||||
|
"special": false,
|
||||||
|
"text": "."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 108,
|
||||||
|
"logprob": -0.060546875,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 8291,
|
||||||
|
"logprob": -0.734375,
|
||||||
|
"special": false,
|
||||||
|
"text": "Here"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 236789,
|
||||||
|
"logprob": -0.26367188,
|
||||||
|
"special": false,
|
||||||
|
"text": "'"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 236751,
|
||||||
|
"logprob": -1.1920929e-06,
|
||||||
|
"special": false,
|
||||||
|
"text": "s"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 496,
|
||||||
|
"logprob": -0.15527344,
|
||||||
|
"special": false,
|
||||||
|
"text": " a"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 25890,
|
||||||
|
"logprob": -0.08886719,
|
||||||
|
"special": false,
|
||||||
|
"text": " breakdown"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 529,
|
||||||
|
"logprob": -0.0020446777,
|
||||||
|
"special": false,
|
||||||
|
"text": " of"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 506,
|
||||||
|
"logprob": -0.17871094,
|
||||||
|
"special": false,
|
||||||
|
"text": " the"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 5872,
|
||||||
|
"logprob": -0.90234375,
|
||||||
|
"special": false,
|
||||||
|
"text": " factors"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 20894,
|
||||||
|
"logprob": -0.25976562,
|
||||||
|
"special": false,
|
||||||
|
"text": " contributing"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 531,
|
||||||
|
"logprob": -8.34465e-05,
|
||||||
|
"special": false,
|
||||||
|
"text": " to"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 506,
|
||||||
|
"logprob": -0.008544922,
|
||||||
|
"special": false,
|
||||||
|
"text": " the"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 5777,
|
||||||
|
"logprob": -0.62109375,
|
||||||
|
"special": false,
|
||||||
|
"text": " wide"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2644,
|
||||||
|
"logprob": -0.0023345947,
|
||||||
|
"special": false,
|
||||||
|
"text": " range"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 529,
|
||||||
|
"logprob": -0.016723633,
|
||||||
|
"special": false,
|
||||||
|
"text": " of"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 14287,
|
||||||
|
"logprob": -0.011291504,
|
||||||
|
"special": false,
|
||||||
|
"text": " estimates"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 573,
|
||||||
|
"logprob": -0.29101562,
|
||||||
|
"special": false,
|
||||||
|
"text": " for"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 506,
|
||||||
|
"logprob": -0.21484375,
|
||||||
|
"special": false,
|
||||||
|
"text": " the"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 236743,
|
||||||
|
"logprob": -0.2890625,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 236770,
|
||||||
|
"logprob": -3.5762787e-07,
|
||||||
|
"special": false,
|
||||||
|
"text": "1"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 236819,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": "9"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 236770,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": "1"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 236828,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": "8"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 7745,
|
||||||
|
"logprob": -0.70703125,
|
||||||
|
"special": false,
|
||||||
|
"text": " flu"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 10248,
|
||||||
|
"logprob": -0.01953125,
|
||||||
|
"special": false,
|
||||||
|
"text": " pandemic"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 4355,
|
||||||
|
"logprob": -0.78515625,
|
||||||
|
"special": false,
|
||||||
|
"text": " death"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 25363,
|
||||||
|
"logprob": -6.771088e-05,
|
||||||
|
"special": false,
|
||||||
|
"text": " toll"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 528,
|
||||||
|
"logprob": -0.08496094,
|
||||||
|
"special": false,
|
||||||
|
"text": " in"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 506,
|
||||||
|
"logprob": -7.033348e-06,
|
||||||
|
"special": false,
|
||||||
|
"text": " the"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3640,
|
||||||
|
"logprob": -0.0067443848,
|
||||||
|
"special": false,
|
||||||
|
"text": " United"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 4184,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": " States"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": " people died in the United States.\n\nThe generally accepted estimate is that 675,000 people died in the United States. However, some historians believe the actual number could be as high as 10 million.\n\nI am looking for more information on this discrepancy and the factors that contributed to the wide range of estimates.\n\nHere's a breakdown of the factors contributing to the wide range of estimates for the 1918 flu pandemic death toll in the United States"
|
||||||
|
}
|
@ -0,0 +1,26 @@
|
|||||||
|
{
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"finish_reason": "stop",
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null,
|
||||||
|
"message": {
|
||||||
|
"content": "Here's a description of what's shown in the image:\n\nThe image depicts a brown cow standing on a sandy beach. The beach has turquoise water and a distant island visible in the background. The sky is bright blue with some white clouds. \n\nIt's a humorous and unexpected sight of a cow enjoying a tropical beach!",
|
||||||
|
"name": null,
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": null
|
||||||
|
},
|
||||||
|
"usage": null
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1741703756,
|
||||||
|
"id": "",
|
||||||
|
"model": "gg-hf-g/gemma-3-4b-it",
|
||||||
|
"object": "chat.completion",
|
||||||
|
"system_fingerprint": "3.1.2-dev0-native",
|
||||||
|
"usage": {
|
||||||
|
"completion_tokens": 70,
|
||||||
|
"prompt_tokens": 277,
|
||||||
|
"total_tokens": 347
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,26 @@
|
|||||||
|
{
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"finish_reason": "stop",
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null,
|
||||||
|
"message": {
|
||||||
|
"content": "Based on the image, the animal is a cow, not a dog! \n\nIt appears to be a **Brazilian cattle breed** known as a **Gir Cow**. They are recognized for their reddish-brown color and distinctive markings.",
|
||||||
|
"name": null,
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": null
|
||||||
|
},
|
||||||
|
"usage": null
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1741703753,
|
||||||
|
"id": "",
|
||||||
|
"model": "gg-hf-g/gemma-3-4b-it",
|
||||||
|
"object": "chat.completion",
|
||||||
|
"system_fingerprint": "3.1.2-dev0-native",
|
||||||
|
"usage": {
|
||||||
|
"completion_tokens": 48,
|
||||||
|
"prompt_tokens": 281,
|
||||||
|
"total_tokens": 329
|
||||||
|
}
|
||||||
|
}
|
90
integration-tests/models/test_flash_gemma3.py
Normal file
90
integration-tests/models/test_flash_gemma3.py
Normal file
@ -0,0 +1,90 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def flash_gemma3_handle(launcher):
|
||||||
|
with launcher("gg-hf-g/gemma-3-4b-it", num_shard=2) as handle:
|
||||||
|
yield handle
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
async def flash_gemma3(flash_gemma3_handle):
|
||||||
|
await flash_gemma3_handle.health(300)
|
||||||
|
return flash_gemma3_handle.client
|
||||||
|
|
||||||
|
|
||||||
|
async def test_flash_gemma3(flash_gemma3, response_snapshot):
|
||||||
|
response = await flash_gemma3.generate(
|
||||||
|
"Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many",
|
||||||
|
seed=42,
|
||||||
|
max_new_tokens=100,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
response.generated_text
|
||||||
|
== " people died in the United States.\n\nThe generally accepted estimate is that 675,000 people died in the United States. However, some historians believe the actual number could be as high as 10 million.\n\nI am looking for more information on this discrepancy and the factors that contributed to the wide range of estimates.\n\nHere's a breakdown of the factors contributing to the wide range of estimates for the 1918 flu pandemic death toll in the United States"
|
||||||
|
)
|
||||||
|
assert response.details.generated_tokens == 100
|
||||||
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
async def test_flash_gemma3_image_cow_dog(flash_gemma3, response_snapshot):
|
||||||
|
image_url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png"
|
||||||
|
response = await flash_gemma3.chat(
|
||||||
|
seed=42,
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "image_url", "image_url": {"url": image_url}},
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "What is the breed of the dog in the image?",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
max_tokens=100,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
response.choices[0].message.content
|
||||||
|
== "Based on the image, the animal is a cow, not a dog! \n\nIt appears to be a **Brazilian cattle breed** known as a **Gir Cow**. They are recognized for their reddish-brown color and distinctive markings."
|
||||||
|
)
|
||||||
|
assert response.usage["completion_tokens"] == 48
|
||||||
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
async def test_flash_gemma3_image_cow(flash_gemma3, response_snapshot):
|
||||||
|
image_url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png"
|
||||||
|
response = await flash_gemma3.chat(
|
||||||
|
seed=42,
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "image_url", "image_url": {"url": image_url}},
|
||||||
|
{"type": "text", "text": "What is shown in this image?"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
max_tokens=100,
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
response.choices[0].message.content
|
||||||
|
== "Here's a description of what's shown in the image:\n\nThe image depicts a brown cow standing on a sandy beach. The beach has turquoise water and a distant island visible in the background. The sky is bright blue with some white clouds. \n\nIt's a humorous and unexpected sight of a cow enjoying a tropical beach!"
|
||||||
|
)
|
||||||
|
assert response.usage["completion_tokens"] == 70
|
||||||
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
async def test_exceed_window(flash_gemma3, response_snapshot):
|
||||||
|
response = await flash_gemma3.generate(
|
||||||
|
"This is a nice place. " * 800 + "Now count: 1, 2, 3",
|
||||||
|
seed=42,
|
||||||
|
max_new_tokens=20,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.generated_text == ", 4, 5, 6, 7, 8, 9, "
|
||||||
|
assert response.details.generated_tokens == 20
|
||||||
|
assert response == response_snapshot
|
@ -2064,6 +2064,7 @@ fn main() -> Result<(), LauncherError> {
|
|||||||
let default_optimal = match config {
|
let default_optimal = match config {
|
||||||
Some(ref config) => match config.model_type.as_deref() {
|
Some(ref config) => match config.model_type.as_deref() {
|
||||||
Some("qwen2_vl") | Some("qwen2_5_vl") => 10_000,
|
Some("qwen2_vl") | Some("qwen2_5_vl") => 10_000,
|
||||||
|
Some("gemma3") => 8000,
|
||||||
_ => 4096,
|
_ => 4096,
|
||||||
},
|
},
|
||||||
None => 4096,
|
None => 4096,
|
||||||
|
@ -216,6 +216,19 @@ impl Qwen2_5Vl {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "snake_case")]
|
||||||
|
pub struct Gemma3VisionConfig {
|
||||||
|
pub(crate) image_size: usize,
|
||||||
|
pub(crate) patch_size: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "snake_case")]
|
||||||
|
pub struct Gemma3 {
|
||||||
|
vision_config: Gemma3VisionConfig,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
#[serde(tag = "model_type")]
|
#[serde(tag = "model_type")]
|
||||||
#[serde(rename_all = "snake_case")]
|
#[serde(rename_all = "snake_case")]
|
||||||
@ -249,6 +262,7 @@ pub enum Config {
|
|||||||
Paligemma(Paligemma),
|
Paligemma(Paligemma),
|
||||||
Gemma,
|
Gemma,
|
||||||
Gemma2,
|
Gemma2,
|
||||||
|
Gemma3(Gemma3),
|
||||||
Cohere,
|
Cohere,
|
||||||
Drbx,
|
Drbx,
|
||||||
Falcon,
|
Falcon,
|
||||||
|
@ -33,7 +33,16 @@ impl ChatTemplate {
|
|||||||
let mut env = Box::new(Environment::new());
|
let mut env = Box::new(Environment::new());
|
||||||
// enable things like .strip() or .capitalize()
|
// enable things like .strip() or .capitalize()
|
||||||
env.set_unknown_method_callback(pycompat::unknown_method_callback);
|
env.set_unknown_method_callback(pycompat::unknown_method_callback);
|
||||||
let template_str = template.into_boxed_str();
|
|
||||||
|
// TODO: replace with better solution
|
||||||
|
// hack to adjust gemma3 template for debug
|
||||||
|
// replace 'messages[0]['content'][0]['text']' with 'messages[0]['content']'
|
||||||
|
let mutated_template = template.replace(
|
||||||
|
"messages[0]['content'][0]['text']",
|
||||||
|
"messages[0]['content']",
|
||||||
|
);
|
||||||
|
|
||||||
|
let template_str = mutated_template.into_boxed_str();
|
||||||
env.add_function("raise_exception", raise_exception);
|
env.add_function("raise_exception", raise_exception);
|
||||||
env.add_function("strftime_now", strftime_now);
|
env.add_function("strftime_now", strftime_now);
|
||||||
tracing::debug!("Loading template: {}", template_str);
|
tracing::debug!("Loading template: {}", template_str);
|
||||||
@ -123,8 +132,8 @@ mod tests {
|
|||||||
use crate::infer::chat_template::{raise_exception, strftime_now};
|
use crate::infer::chat_template::{raise_exception, strftime_now};
|
||||||
use crate::infer::ChatTemplate;
|
use crate::infer::ChatTemplate;
|
||||||
use crate::{
|
use crate::{
|
||||||
ChatTemplateInputs, Message, MessageBody, MessageContent, TextMessage,
|
ChatTemplateInputs, Message, MessageBody, MessageChunk, MessageContent, TextMessage,
|
||||||
TokenizerConfigToken, Tool,
|
TokenizerConfigToken, Tool, Url,
|
||||||
};
|
};
|
||||||
use chrono::Local;
|
use chrono::Local;
|
||||||
use minijinja::Environment;
|
use minijinja::Environment;
|
||||||
@ -1230,4 +1239,98 @@ TOOL CALL ID: 0
|
|||||||
let expected = "<s><|start_header_id|>system<|end_header_id|>\n\nEnvironment: ipython\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYoure a helpful assistant! Answer the users question best you can.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGiven the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n\nRespond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.Do not use variables.\n\n{\n \"function\": {\n \"arguments\": \"{\\\"type\\\":\\\"object\\\",\\\"properties\\\":{\\\"location\\\":{\\\"type\\\":\\\"string\\\",\\\"description\\\":\\\"The city and state, e.g. San Francisco, CA\\\"},\\\"format\\\":{\\\"type\\\":\\\"string\\\",\\\"enum\\\":[\\\"celsius\\\",\\\"fahrenheit\\\"],\\\"description\\\":\\\"The temperature unit to use. Infer this from the users location.\\\"}},\\\"required\\\":[\\\"location\\\",\\\"format\\\"]}\",\n \"description\": \"Get the current weather\",\n \"name\": \"get_current_weather\"\n },\n \"type\": \"function\"\n}\n\nWhat is the weather like in Brooklyn, New York?\n---\nThis default prompt will be used<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n".to_string();
|
let expected = "<s><|start_header_id|>system<|end_header_id|>\n\nEnvironment: ipython\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYoure a helpful assistant! Answer the users question best you can.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGiven the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n\nRespond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.Do not use variables.\n\n{\n \"function\": {\n \"arguments\": \"{\\\"type\\\":\\\"object\\\",\\\"properties\\\":{\\\"location\\\":{\\\"type\\\":\\\"string\\\",\\\"description\\\":\\\"The city and state, e.g. San Francisco, CA\\\"},\\\"format\\\":{\\\"type\\\":\\\"string\\\",\\\"enum\\\":[\\\"celsius\\\",\\\"fahrenheit\\\"],\\\"description\\\":\\\"The temperature unit to use. Infer this from the users location.\\\"}},\\\"required\\\":[\\\"location\\\",\\\"format\\\"]}\",\n \"description\": \"Get the current weather\",\n \"name\": \"get_current_weather\"\n },\n \"type\": \"function\"\n}\n\nWhat is the weather like in Brooklyn, New York?\n---\nThis default prompt will be used<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n".to_string();
|
||||||
assert_eq!(result.unwrap(), expected);
|
assert_eq!(result.unwrap(), expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_chat_template_with_special_system_prompt() {
|
||||||
|
// chat template from gemma3
|
||||||
|
let ct = ChatTemplate::new(
|
||||||
|
r#"{{ bos_token }}
|
||||||
|
{%- if messages[0]['role'] == 'system' -%}
|
||||||
|
{%- set first_user_prefix = messages[0]['content'][0]['text'] + '
|
||||||
|
|
||||||
|
' -%}
|
||||||
|
{%- set loop_messages = messages[1:] -%}
|
||||||
|
{%- else -%}
|
||||||
|
{%- set first_user_prefix = "" -%}
|
||||||
|
{%- set loop_messages = messages -%}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- for message in loop_messages -%}
|
||||||
|
{%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}
|
||||||
|
{{ raise_exception("Conversation roles must alternate user/assistant/user/assistant/...") }}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- if (message['role'] == 'assistant') -%}
|
||||||
|
{%- set role = "model" -%}
|
||||||
|
{%- else -%}
|
||||||
|
{%- set role = message['role'] -%}
|
||||||
|
{%- endif -%}
|
||||||
|
{{ '<start_of_turn>' + role + '
|
||||||
|
' + (first_user_prefix if loop.first else "") }}
|
||||||
|
{%- if message['content'] is string -%}
|
||||||
|
{{ message['content'] | trim }}
|
||||||
|
{%- elif message['content'] is iterable -%}
|
||||||
|
{%- for item in message['content'] -%}
|
||||||
|
{%- if item['type'] == 'image' -%}
|
||||||
|
{{ '<start_of_image>' }}
|
||||||
|
{%- elif item['type'] == 'text' -%}
|
||||||
|
{{ item['text'] | trim }}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- endfor -%}
|
||||||
|
{%- else -%}
|
||||||
|
{{ raise_exception("Invalid content type") }}
|
||||||
|
{%- endif -%}
|
||||||
|
{{ '<end_of_turn>
|
||||||
|
' }}
|
||||||
|
{%- endfor -%}
|
||||||
|
{%- if add_generation_prompt -%}
|
||||||
|
{{'<start_of_turn>model
|
||||||
|
'}}
|
||||||
|
{%- endif -%}
|
||||||
|
"#
|
||||||
|
.to_string(),
|
||||||
|
Some(TokenizerConfigToken::String("<bos>".to_string())),
|
||||||
|
Some(TokenizerConfigToken::String("</eos>".to_string())),
|
||||||
|
);
|
||||||
|
let msgs: Vec<Message> = vec![
|
||||||
|
Message {
|
||||||
|
name: None,
|
||||||
|
role: "system".to_string(),
|
||||||
|
body: MessageBody::Content {
|
||||||
|
content: MessageContent::MultipleChunks(vec![MessageChunk::Text {
|
||||||
|
text: "You are a helpful assistant.".to_string(),
|
||||||
|
}]),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Message {
|
||||||
|
name: None,
|
||||||
|
role: "user".to_string(),
|
||||||
|
body: MessageBody::Content {
|
||||||
|
content: MessageContent::MultipleChunks(vec![
|
||||||
|
MessageChunk::Text {
|
||||||
|
text: "I'm already using this supplement ".to_string(),
|
||||||
|
},
|
||||||
|
MessageChunk::ImageUrl {
|
||||||
|
image_url: Url {
|
||||||
|
url: "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/IMG_3018.JPG".to_string()
|
||||||
|
},
|
||||||
|
},
|
||||||
|
MessageChunk::Text {
|
||||||
|
text: "and I want to use this one too ".to_string()
|
||||||
|
},
|
||||||
|
MessageChunk::ImageUrl {
|
||||||
|
image_url: Url {
|
||||||
|
url: "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/IMG_3015.jpg".to_string()
|
||||||
|
},
|
||||||
|
},
|
||||||
|
MessageChunk::Text {
|
||||||
|
text: " what are cautions?".to_string()
|
||||||
|
},
|
||||||
|
]),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
let result = ct.apply(msgs, None);
|
||||||
|
let expected = "<bos><start_of_turn>user\nYou are a helpful assistant.\n\nI'm already using this supplement and I want to use this one too  what are cautions?<end_of_turn>\n<start_of_turn>model\n".to_string();
|
||||||
|
assert_eq!(result.unwrap(), expected);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -152,6 +152,11 @@ impl HubTokenizerConfig {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||||
|
pub struct ChatTemplateStandalone {
|
||||||
|
pub chat_template: ChatTemplateVersions,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
|
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
|
||||||
#[serde(untagged)]
|
#[serde(untagged)]
|
||||||
pub enum TokenizerConfigToken {
|
pub enum TokenizerConfigToken {
|
||||||
@ -173,6 +178,7 @@ impl TokenizerConfigToken {
|
|||||||
pub enum HubPreprocessorConfig {
|
pub enum HubPreprocessorConfig {
|
||||||
Idefics2Processor(Idefics2Preprocessor),
|
Idefics2Processor(Idefics2Preprocessor),
|
||||||
Idefics3Processor(Idefics2Preprocessor),
|
Idefics3Processor(Idefics2Preprocessor),
|
||||||
|
Gemma3Processor(Gemma3Processor),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl HubPreprocessorConfig {
|
impl HubPreprocessorConfig {
|
||||||
@ -188,6 +194,12 @@ pub struct Idefics2Preprocessor {
|
|||||||
do_image_splitting: bool,
|
do_image_splitting: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
|
pub struct Gemma3Processor {
|
||||||
|
#[serde(default)]
|
||||||
|
do_image_splitting: bool,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Default)]
|
#[derive(Debug, Clone, Deserialize, Default)]
|
||||||
pub struct HubProcessorConfig {
|
pub struct HubProcessorConfig {
|
||||||
pub chat_template: Option<ChatTemplateVersions>,
|
pub chat_template: Option<ChatTemplateVersions>,
|
||||||
|
@ -1532,6 +1532,7 @@ pub async fn run(
|
|||||||
tokenizer_config_filename,
|
tokenizer_config_filename,
|
||||||
preprocessor_config_filename,
|
preprocessor_config_filename,
|
||||||
processor_config_filename,
|
processor_config_filename,
|
||||||
|
chat_template_filename,
|
||||||
model_info,
|
model_info,
|
||||||
) = match api {
|
) = match api {
|
||||||
Type::None => (
|
Type::None => (
|
||||||
@ -1539,6 +1540,7 @@ pub async fn run(
|
|||||||
Some(local_path.join("tokenizer_config.json")),
|
Some(local_path.join("tokenizer_config.json")),
|
||||||
Some(local_path.join("preprocessor_config.json")),
|
Some(local_path.join("preprocessor_config.json")),
|
||||||
Some(local_path.join("processor_config.json")),
|
Some(local_path.join("processor_config.json")),
|
||||||
|
Some(local_path.join("chat_template.json")),
|
||||||
None,
|
None,
|
||||||
),
|
),
|
||||||
Type::Api(api) => {
|
Type::Api(api) => {
|
||||||
@ -1552,6 +1554,7 @@ pub async fn run(
|
|||||||
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
|
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
|
||||||
let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok();
|
let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok();
|
||||||
let processor_config_filename = api_repo.get("processor_config.json").await.ok();
|
let processor_config_filename = api_repo.get("processor_config.json").await.ok();
|
||||||
|
let chat_template_filename = api_repo.get("chat_template.json").await.ok();
|
||||||
|
|
||||||
let model_info = if let Some(model_info) = get_hub_model_info(&api_repo).await {
|
let model_info = if let Some(model_info) = get_hub_model_info(&api_repo).await {
|
||||||
Some(model_info)
|
Some(model_info)
|
||||||
@ -1564,6 +1567,7 @@ pub async fn run(
|
|||||||
tokenizer_config_filename,
|
tokenizer_config_filename,
|
||||||
preprocessor_config_filename,
|
preprocessor_config_filename,
|
||||||
processor_config_filename,
|
processor_config_filename,
|
||||||
|
chat_template_filename,
|
||||||
model_info,
|
model_info,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@ -1579,11 +1583,23 @@ pub async fn run(
|
|||||||
repo.get("tokenizer_config.json"),
|
repo.get("tokenizer_config.json"),
|
||||||
repo.get("preprocessor_config.json"),
|
repo.get("preprocessor_config.json"),
|
||||||
repo.get("processor_config.json"),
|
repo.get("processor_config.json"),
|
||||||
|
repo.get("chat_template.json"),
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// if chat_template_filename is present, load the chat template
|
||||||
|
let chat_template: Option<crate::ChatTemplateVersions> = chat_template_filename
|
||||||
|
.and_then(|f| std::fs::read_to_string(f).ok())
|
||||||
|
.and_then(|c| {
|
||||||
|
let res = serde_json::from_str::<crate::ChatTemplateStandalone>(&c);
|
||||||
|
if let Err(e) = &res {
|
||||||
|
tracing::warn!("Could not parse chat template {e:?}");
|
||||||
|
}
|
||||||
|
res.ok().map(|t| t.chat_template)
|
||||||
|
});
|
||||||
|
|
||||||
// Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
|
// Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
|
||||||
tracing::warn!("Tokenizer_config {tokenizer_config_path:?} - {tokenizer_config_filename:?}");
|
tracing::warn!("Tokenizer_config {tokenizer_config_path:?} - {tokenizer_config_filename:?}");
|
||||||
let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
|
let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
|
||||||
@ -1592,11 +1608,16 @@ pub async fn run(
|
|||||||
} else {
|
} else {
|
||||||
tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
|
tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
|
||||||
};
|
};
|
||||||
let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
|
let mut tokenizer_config = tokenizer_config.unwrap_or_else(|| {
|
||||||
tracing::warn!("Could not find tokenizer config locally and no API specified");
|
tracing::warn!("Could not find tokenizer config locally and no API specified");
|
||||||
HubTokenizerConfig::default()
|
HubTokenizerConfig::default()
|
||||||
});
|
});
|
||||||
|
|
||||||
|
if chat_template.is_some() {
|
||||||
|
tracing::info!("Using chat template from chat_template.json");
|
||||||
|
tokenizer_config.chat_template = chat_template;
|
||||||
|
}
|
||||||
|
|
||||||
let tokenizer: Result<Tokenizer, WebServerError> = {
|
let tokenizer: Result<Tokenizer, WebServerError> = {
|
||||||
use pyo3::prelude::*;
|
use pyo3::prelude::*;
|
||||||
Python::with_gil(|py| -> PyResult<()> {
|
Python::with_gil(|py| -> PyResult<()> {
|
||||||
|
@ -18,6 +18,7 @@ use std::sync::Arc;
|
|||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use tokio::sync::mpsc;
|
use tokio::sync::mpsc;
|
||||||
use tokio::sync::oneshot;
|
use tokio::sync::oneshot;
|
||||||
|
use tracing::warn;
|
||||||
use tracing::{instrument, Span};
|
use tracing::{instrument, Span};
|
||||||
use {once_cell::sync::Lazy, regex::Regex};
|
use {once_cell::sync::Lazy, regex::Regex};
|
||||||
|
|
||||||
@ -694,6 +695,14 @@ fn image_tokens(
|
|||||||
"<|vision_start|>{:?}<|vision_end|>",
|
"<|vision_start|>{:?}<|vision_end|>",
|
||||||
"<|image_pad|>".repeat(config.get_number_of_features(height, width))
|
"<|image_pad|>".repeat(config.get_number_of_features(height, width))
|
||||||
),
|
),
|
||||||
|
Gemma3(_config) => {
|
||||||
|
// TODO: prefer using the config to determine the number of features
|
||||||
|
let num_mm_soft_tokens_per_image = 256;
|
||||||
|
format!(
|
||||||
|
"\n\n<start_of_image>{:?}<end_of_image>\n\n",
|
||||||
|
"<image_soft_token>".repeat(num_mm_soft_tokens_per_image)
|
||||||
|
)
|
||||||
|
}
|
||||||
_ => unimplemented!("Images tokens are not supported for this model configuration"),
|
_ => unimplemented!("Images tokens are not supported for this model configuration"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -721,8 +730,8 @@ fn prepare_input<T: TokenizerTrait>(
|
|||||||
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
|
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
|
||||||
let (tokenizer_query, input_chunks) = match config {
|
let (tokenizer_query, input_chunks) = match config {
|
||||||
Some(
|
Some(
|
||||||
config @ (Idefics | Mllama | Idefics2(_) | Idefics3(_) | Paligemma(_) | LlavaNext(_)
|
config @ (Idefics | Mllama | Idefics2(_) | Idefics3(_) | Gemma3(_) | Paligemma(_)
|
||||||
| Qwen2Vl(_) | Qwen2_5Vl(_)),
|
| LlavaNext(_) | Qwen2Vl(_) | Qwen2_5Vl(_)),
|
||||||
) => {
|
) => {
|
||||||
let mut input_chunks = Vec::new();
|
let mut input_chunks = Vec::new();
|
||||||
let mut tokenizer_query = String::with_capacity(inputs.len());
|
let mut tokenizer_query = String::with_capacity(inputs.len());
|
||||||
|
@ -106,6 +106,17 @@ try:
|
|||||||
from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
|
from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
|
||||||
FlashGemma2ForCausalLM,
|
FlashGemma2ForCausalLM,
|
||||||
)
|
)
|
||||||
|
from text_generation_server.models.custom_modeling.flash_gemma3_modeling import (
|
||||||
|
FlashGemma3ForCausalLM,
|
||||||
|
Gemma3ForConditionalGeneration,
|
||||||
|
)
|
||||||
|
from text_generation_server.models.custom_modeling.gemma3.processing_gemma3 import (
|
||||||
|
Gemma3Processor,
|
||||||
|
)
|
||||||
|
from text_generation_server.models.custom_modeling.gemma3.configuration_gemma3 import (
|
||||||
|
Gemma3Config,
|
||||||
|
Gemma3TextConfig,
|
||||||
|
)
|
||||||
from text_generation_server.models.custom_modeling.flash_dbrx_modeling import (
|
from text_generation_server.models.custom_modeling.flash_dbrx_modeling import (
|
||||||
FlashDbrxForCausalLM,
|
FlashDbrxForCausalLM,
|
||||||
DbrxConfig,
|
DbrxConfig,
|
||||||
@ -258,6 +269,16 @@ class ModelType(enum.Enum):
|
|||||||
"name": "Gemma2",
|
"name": "Gemma2",
|
||||||
"url": "https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315",
|
"url": "https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315",
|
||||||
}
|
}
|
||||||
|
GEMMA3 = {
|
||||||
|
"type": "gemma3",
|
||||||
|
"name": "Gemma3",
|
||||||
|
"url": "https://huggingface.co/collections/google/gemma-3",
|
||||||
|
}
|
||||||
|
GEMMA3_TEXT = {
|
||||||
|
"type": "gemma3_text",
|
||||||
|
"name": "Gemma3 Text",
|
||||||
|
"url": "https://huggingface.co/collections/google/gemma-3",
|
||||||
|
}
|
||||||
COHERE = {
|
COHERE = {
|
||||||
"type": "cohere",
|
"type": "cohere",
|
||||||
"name": "Cohere",
|
"name": "Cohere",
|
||||||
@ -1094,6 +1115,83 @@ def get_model(
|
|||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
elif model_type == GEMMA3_TEXT:
|
||||||
|
if FLASH_ATTENTION:
|
||||||
|
return FlashCausalLM(
|
||||||
|
model_id=model_id,
|
||||||
|
model_class=FlashGemma3ForCausalLM,
|
||||||
|
revision=revision,
|
||||||
|
quantize=quantize,
|
||||||
|
speculator=speculator,
|
||||||
|
dtype=dtype,
|
||||||
|
kv_cache_dtype=kv_cache_dtype,
|
||||||
|
# TODO: once implemented in transformers, use the config class
|
||||||
|
# and processor class from there.
|
||||||
|
config_class=Gemma3TextConfig,
|
||||||
|
# Works better for these models
|
||||||
|
default_dtype=torch.bfloat16,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
lora_adapter_ids=lora_adapter_ids,
|
||||||
|
)
|
||||||
|
elif FLASH_TRANSFORMERS_BACKEND:
|
||||||
|
return TransformersFlashCausalLM.fallback(
|
||||||
|
model_id,
|
||||||
|
revision,
|
||||||
|
quantize=quantize,
|
||||||
|
speculator=speculator,
|
||||||
|
dtype=dtype,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
)
|
||||||
|
elif sharded:
|
||||||
|
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma3"))
|
||||||
|
else:
|
||||||
|
return CausalLM.fallback(
|
||||||
|
model_id,
|
||||||
|
revision,
|
||||||
|
quantize=quantize,
|
||||||
|
speculator=speculator,
|
||||||
|
dtype=dtype,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
)
|
||||||
|
elif model_type == GEMMA3:
|
||||||
|
if FLASH_ATTENTION:
|
||||||
|
# TODO: Use VlmCausalLM when image support is added.
|
||||||
|
return VlmCausalLM(
|
||||||
|
model_id=model_id,
|
||||||
|
model_class=Gemma3ForConditionalGeneration,
|
||||||
|
revision=revision,
|
||||||
|
quantize=quantize,
|
||||||
|
speculator=speculator,
|
||||||
|
dtype=dtype,
|
||||||
|
kv_cache_dtype=kv_cache_dtype,
|
||||||
|
# TODO: once implemented in transformers, use the config class
|
||||||
|
# and processor class from there.
|
||||||
|
config_class=Gemma3Config,
|
||||||
|
processor_class=Gemma3Processor,
|
||||||
|
default_dtype=torch.bfloat16,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
lora_adapter_ids=lora_adapter_ids,
|
||||||
|
)
|
||||||
|
elif FLASH_TRANSFORMERS_BACKEND:
|
||||||
|
return TransformersFlashCausalLM.fallback(
|
||||||
|
model_id,
|
||||||
|
revision,
|
||||||
|
quantize=quantize,
|
||||||
|
speculator=speculator,
|
||||||
|
dtype=dtype,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
)
|
||||||
|
elif sharded:
|
||||||
|
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma3"))
|
||||||
|
else:
|
||||||
|
return CausalLM.fallback(
|
||||||
|
model_id,
|
||||||
|
revision,
|
||||||
|
quantize=quantize,
|
||||||
|
speculator=speculator,
|
||||||
|
dtype=dtype,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
)
|
||||||
|
|
||||||
if model_type == COHERE:
|
if model_type == COHERE:
|
||||||
if FLASH_ATTENTION:
|
if FLASH_ATTENTION:
|
||||||
|
@ -0,0 +1,922 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed
|
||||||
|
from torch import nn
|
||||||
|
from typing import Optional, List, Tuple
|
||||||
|
import copy
|
||||||
|
|
||||||
|
from text_generation_server.layers import (
|
||||||
|
TensorParallelColumnLinear,
|
||||||
|
TensorParallelEmbedding,
|
||||||
|
TensorParallelRowLinear,
|
||||||
|
get_linear,
|
||||||
|
#
|
||||||
|
SpeculativeHead,
|
||||||
|
TensorParallelMultiAdapterLinear,
|
||||||
|
TensorParallelAdapterRowLinear,
|
||||||
|
)
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
from text_generation_server.models.custom_modeling.vlm import (
|
||||||
|
load_text_model,
|
||||||
|
load_vision_model,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
from text_generation_server.layers.attention.kv_cache import get_kv_scales
|
||||||
|
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
||||||
|
from text_generation_server.layers.layernorm import (
|
||||||
|
FastRMSNorm,
|
||||||
|
)
|
||||||
|
from text_generation_server.utils.weights import UnquantizedWeight
|
||||||
|
from transformers.activations import ACT2FN
|
||||||
|
from text_generation_server.layers.attention import (
|
||||||
|
paged_attention,
|
||||||
|
attention,
|
||||||
|
Seqlen,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
ATTENTION_TYPE_GLOBAL = "global"
|
||||||
|
ATTENTION_TYPE_LOCAL = "local_sliding"
|
||||||
|
|
||||||
|
|
||||||
|
class Gemma3FastRMSNorm(FastRMSNorm):
|
||||||
|
@classmethod
|
||||||
|
def load(cls, prefix: str, weights, eps=1e-6):
|
||||||
|
dtype = weights.dtype
|
||||||
|
weights.dtype = torch.float32
|
||||||
|
weight = weights.get_tensor(f"{prefix}.weight") + 1
|
||||||
|
weights.dtype = dtype
|
||||||
|
new = cls(weight, eps)
|
||||||
|
new.dtype = dtype
|
||||||
|
return new
|
||||||
|
|
||||||
|
# perform the multiplication in full precision and downcast after
|
||||||
|
def forward(self, hidden_states, residual=None):
|
||||||
|
if residual is not None:
|
||||||
|
hidden_states += residual
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = hidden_states.to(torch.float32)
|
||||||
|
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||||
|
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||||
|
hidden_states = hidden_states * self.weight
|
||||||
|
return hidden_states.to(self.dtype), residual
|
||||||
|
|
||||||
|
|
||||||
|
def load_attention(config, prefix: str, weights):
|
||||||
|
if config.num_attention_heads != config.num_key_value_heads:
|
||||||
|
return _load_gqa(config, prefix, weights)
|
||||||
|
else:
|
||||||
|
return TensorParallelColumnLinear.load_multi(
|
||||||
|
config,
|
||||||
|
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||||
|
dim=0,
|
||||||
|
weights=weights,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_gqa(config, prefix: str, weights):
|
||||||
|
assert config.num_attention_heads % weights.process_group.size() == 0
|
||||||
|
|
||||||
|
weight = weights.get_multi_weights_col(
|
||||||
|
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||||
|
dim=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(weight, UnquantizedWeight):
|
||||||
|
weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device)
|
||||||
|
|
||||||
|
head_size = config.head_dim
|
||||||
|
num_heads = config.num_attention_heads // weights.process_group.size()
|
||||||
|
num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
|
||||||
|
assert list(weight.weight.shape) == [
|
||||||
|
(num_heads + 2 * num_key_value_heads) * head_size,
|
||||||
|
config.hidden_size,
|
||||||
|
], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
|
||||||
|
|
||||||
|
return TensorParallelColumnLinear(get_linear(weight, bias=None))
|
||||||
|
|
||||||
|
|
||||||
|
class FlashGemma3Attention(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self, prefix: str, config, weights, layer_id, causal: bool, is_sliding: bool
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.num_heads = config.num_attention_heads
|
||||||
|
self.head_size = config.head_dim
|
||||||
|
self.causal = causal
|
||||||
|
if is_sliding:
|
||||||
|
self.window_size = config.sliding_window
|
||||||
|
# TODO: remove this hack to support local sliding window
|
||||||
|
config = copy.deepcopy(config)
|
||||||
|
config.rope_scaling = dict(rope_type="default")
|
||||||
|
self.rotary_emb = PositionRotaryEmbedding.static(
|
||||||
|
config=config,
|
||||||
|
dim=config.head_dim,
|
||||||
|
base=config.rope_local_base_freq,
|
||||||
|
device=weights.device,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.window_size = -1
|
||||||
|
self.rotary_emb = PositionRotaryEmbedding.static(
|
||||||
|
config=config,
|
||||||
|
dim=config.head_dim,
|
||||||
|
base=config.rope_theta,
|
||||||
|
device=weights.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.softmax_scale = (
|
||||||
|
config.query_pre_attn_scalar**-0.5
|
||||||
|
if config.query_pre_attn_scalar is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
if self.num_heads % weights.process_group.size() != 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
|
||||||
|
f"and `num_shards`: {weights.process_group.size()}"
|
||||||
|
)
|
||||||
|
self.num_heads = self.num_heads // weights.process_group.size()
|
||||||
|
self.num_key_value_heads = (
|
||||||
|
config.num_key_value_heads // weights.process_group.size()
|
||||||
|
)
|
||||||
|
self.softcap = None # config.attn_logit_softcapping
|
||||||
|
|
||||||
|
query_key_value = load_attention(config, prefix, weights)
|
||||||
|
self.query_key_value = TensorParallelMultiAdapterLinear.load(
|
||||||
|
query_key_value,
|
||||||
|
layer_id,
|
||||||
|
["q_proj", "k_proj", "v_proj"],
|
||||||
|
sizes=[
|
||||||
|
self.head_size * config.num_attention_heads,
|
||||||
|
self.head_size * config.num_key_value_heads,
|
||||||
|
self.head_size * config.num_key_value_heads,
|
||||||
|
],
|
||||||
|
process_group=weights.process_group,
|
||||||
|
)
|
||||||
|
self.kv_scales = get_kv_scales(weights, f"{prefix}")
|
||||||
|
|
||||||
|
o_proj = TensorParallelRowLinear.load(
|
||||||
|
config,
|
||||||
|
prefix=f"{prefix}.o_proj",
|
||||||
|
weights=weights,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
self.o_proj = TensorParallelAdapterRowLinear.load(
|
||||||
|
o_proj,
|
||||||
|
layer_id,
|
||||||
|
"o_proj",
|
||||||
|
process_group=weights.process_group,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.num_groups = self.num_heads // self.num_key_value_heads
|
||||||
|
self.kv_head_mapping = torch.arange(
|
||||||
|
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
|
||||||
|
).repeat_interleave(self.num_groups)
|
||||||
|
self.q_norm = Gemma3FastRMSNorm.load(
|
||||||
|
prefix=f"{prefix}.q_norm", weights=weights, eps=config.rms_norm_eps
|
||||||
|
)
|
||||||
|
self.k_norm = Gemma3FastRMSNorm.load(
|
||||||
|
prefix=f"{prefix}.k_norm", weights=weights, eps=config.rms_norm_eps
|
||||||
|
)
|
||||||
|
self.enable_gqa = self.num_heads != self.num_key_value_heads
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
cu_seqlen_prefill,
|
||||||
|
kv_cache,
|
||||||
|
block_tables,
|
||||||
|
slots,
|
||||||
|
seqlen,
|
||||||
|
max_s,
|
||||||
|
adapter_data,
|
||||||
|
attention_mask,
|
||||||
|
):
|
||||||
|
|
||||||
|
qkv = self.query_key_value(hidden_states, adapter_data)
|
||||||
|
query, kv = qkv.split(
|
||||||
|
[
|
||||||
|
self.head_size * self.num_heads,
|
||||||
|
2 * self.head_size * self.num_key_value_heads,
|
||||||
|
],
|
||||||
|
dim=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
kv = kv.view(-1, 2, self.num_key_value_heads * self.head_size)
|
||||||
|
key = kv[:, 0]
|
||||||
|
value = kv[:, 1]
|
||||||
|
|
||||||
|
query = query.reshape(-1, self.head_size)
|
||||||
|
key = key.reshape(-1, self.head_size)
|
||||||
|
|
||||||
|
query, _ = self.q_norm(query.contiguous())
|
||||||
|
key, _ = self.k_norm(key.contiguous())
|
||||||
|
|
||||||
|
query = query.view(-1, self.num_heads, self.head_size)
|
||||||
|
key = key.view(-1, self.num_key_value_heads, self.head_size)
|
||||||
|
value = value.view(-1, self.num_key_value_heads, self.head_size)
|
||||||
|
|
||||||
|
self.rotary_emb(query, key, cos, sin)
|
||||||
|
|
||||||
|
kv_cache.store(
|
||||||
|
key=key,
|
||||||
|
value=value,
|
||||||
|
slots=slots,
|
||||||
|
kv_scales=self.kv_scales,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prefill
|
||||||
|
if cu_seqlen_prefill is not None:
|
||||||
|
if attention_mask is None:
|
||||||
|
# flash attention
|
||||||
|
attn_output = attention(
|
||||||
|
query=query,
|
||||||
|
key=key,
|
||||||
|
value=value,
|
||||||
|
kv_cache=kv_cache,
|
||||||
|
kv_scales=self.kv_scales,
|
||||||
|
seqlen=seqlen,
|
||||||
|
block_tables=block_tables,
|
||||||
|
softmax_scale=self.softmax_scale,
|
||||||
|
window_size_left=self.window_size,
|
||||||
|
softcap=self.softcap,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
lengths = cu_seqlen_prefill[1:] - cu_seqlen_prefill[:-1]
|
||||||
|
|
||||||
|
# Split tensors using vectorized split
|
||||||
|
query_list = torch.split(query, lengths.tolist(), dim=0)
|
||||||
|
key_list = torch.split(key, lengths.tolist(), dim=0)
|
||||||
|
value_list = torch.split(value, lengths.tolist(), dim=0)
|
||||||
|
|
||||||
|
padded_query = torch.nn.utils.rnn.pad_sequence(
|
||||||
|
query_list, batch_first=True
|
||||||
|
)
|
||||||
|
padded_key = torch.nn.utils.rnn.pad_sequence(key_list, batch_first=True)
|
||||||
|
padded_value = torch.nn.utils.rnn.pad_sequence(
|
||||||
|
value_list, batch_first=True
|
||||||
|
)
|
||||||
|
|
||||||
|
padded_query = padded_query.transpose(1, 2).contiguous()
|
||||||
|
padded_key = padded_key.transpose(1, 2).contiguous()
|
||||||
|
padded_value = padded_value.transpose(1, 2).contiguous()
|
||||||
|
zeros_to_add = torch.zeros(
|
||||||
|
padded_key.size(0),
|
||||||
|
self.num_key_value_heads,
|
||||||
|
1,
|
||||||
|
self.head_size,
|
||||||
|
dtype=padded_key.dtype,
|
||||||
|
device=padded_key.device,
|
||||||
|
)
|
||||||
|
key_states = torch.cat([padded_key, zeros_to_add], dim=2)
|
||||||
|
value_states = torch.cat([padded_value, zeros_to_add], dim=2)
|
||||||
|
|
||||||
|
# Compute attention
|
||||||
|
attn_output = F.scaled_dot_product_attention(
|
||||||
|
padded_query,
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
attn_mask=attention_mask,
|
||||||
|
scale=self.softmax_scale,
|
||||||
|
enable_gqa=self.enable_gqa,
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = attn_output.transpose(
|
||||||
|
1, 2
|
||||||
|
) # [batch_size, seq_len, num_heads, head_dim]
|
||||||
|
max_seq_len = padded_query.size(2)
|
||||||
|
seq_range = torch.arange(
|
||||||
|
max_seq_len, device=padded_query.device
|
||||||
|
).unsqueeze(0)
|
||||||
|
lengths_tensor = torch.tensor(
|
||||||
|
lengths, device=padded_query.device
|
||||||
|
).unsqueeze(1)
|
||||||
|
mask = seq_range < lengths_tensor # [batch, max_seq_len]
|
||||||
|
attn_output = attn_output[mask] # [total_seq_len, num_heads, head_dim]
|
||||||
|
|
||||||
|
# Decode
|
||||||
|
else:
|
||||||
|
attn_output = paged_attention(
|
||||||
|
query,
|
||||||
|
kv_cache,
|
||||||
|
self.kv_head_mapping,
|
||||||
|
self.softmax_scale,
|
||||||
|
block_tables,
|
||||||
|
seqlen,
|
||||||
|
max_s,
|
||||||
|
softcap=self.softcap,
|
||||||
|
kv_scales=self.kv_scales,
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.o_proj(
|
||||||
|
attn_output.view(-1, self.num_heads * self.head_size), adapter_data
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Gemma3MLP(nn.Module):
|
||||||
|
def __init__(self, prefix, config, weights, layer_id):
|
||||||
|
super().__init__()
|
||||||
|
act = config.hidden_activation
|
||||||
|
self.act = (
|
||||||
|
ACT2FN[act]
|
||||||
|
if "gelu" not in act
|
||||||
|
else lambda x: torch.nn.functional.gelu(
|
||||||
|
x,
|
||||||
|
approximate=(
|
||||||
|
"tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# Fuse gate and up proj
|
||||||
|
gate_up_proj = TensorParallelColumnLinear.load_multi(
|
||||||
|
config,
|
||||||
|
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
|
||||||
|
weights=weights,
|
||||||
|
dim=0,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
|
||||||
|
gate_up_proj,
|
||||||
|
layer_id,
|
||||||
|
["gate_proj", "up_proj"],
|
||||||
|
sizes=[
|
||||||
|
config.intermediate_size,
|
||||||
|
config.intermediate_size,
|
||||||
|
],
|
||||||
|
process_group=weights.process_group,
|
||||||
|
)
|
||||||
|
|
||||||
|
down_proj = TensorParallelRowLinear.load(
|
||||||
|
config,
|
||||||
|
prefix=f"{prefix}.down_proj",
|
||||||
|
weights=weights,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
self.down_proj = TensorParallelAdapterRowLinear.load(
|
||||||
|
down_proj,
|
||||||
|
layer_id,
|
||||||
|
"down_proj",
|
||||||
|
process_group=weights.process_group,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.intermediate_size = (
|
||||||
|
config.intermediate_size // weights.process_group.size()
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, hidden_states, adapter_data):
|
||||||
|
gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
|
||||||
|
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
|
||||||
|
return self.down_proj(
|
||||||
|
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FlashGemma3Layer(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self, prefix: str, config, weights, layer_id, causal: bool, is_sliding: bool
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.self_attn = FlashGemma3Attention(
|
||||||
|
prefix=f"{prefix}.self_attn",
|
||||||
|
config=config,
|
||||||
|
weights=weights,
|
||||||
|
layer_id=layer_id,
|
||||||
|
causal=causal,
|
||||||
|
is_sliding=is_sliding,
|
||||||
|
)
|
||||||
|
self.mlp = Gemma3MLP(
|
||||||
|
prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id
|
||||||
|
)
|
||||||
|
|
||||||
|
self.input_layernorm = Gemma3FastRMSNorm.load(
|
||||||
|
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
|
||||||
|
)
|
||||||
|
self.post_attention_layernorm = Gemma3FastRMSNorm.load(
|
||||||
|
prefix=f"{prefix}.post_attention_layernorm",
|
||||||
|
weights=weights,
|
||||||
|
eps=config.rms_norm_eps,
|
||||||
|
)
|
||||||
|
self.pre_feedforward_layernorm = Gemma3FastRMSNorm.load(
|
||||||
|
prefix=f"{prefix}.pre_feedforward_layernorm",
|
||||||
|
weights=weights,
|
||||||
|
eps=config.rms_norm_eps,
|
||||||
|
)
|
||||||
|
self.post_feedforward_layernorm = Gemma3FastRMSNorm.load(
|
||||||
|
prefix=f"{prefix}.post_feedforward_layernorm",
|
||||||
|
weights=weights,
|
||||||
|
eps=config.rms_norm_eps,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
residual,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
cu_seqlen_prefill,
|
||||||
|
kv_cache,
|
||||||
|
block_tables,
|
||||||
|
slots,
|
||||||
|
seqlen,
|
||||||
|
max_s,
|
||||||
|
adapter_data,
|
||||||
|
attention_mask,
|
||||||
|
):
|
||||||
|
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
|
||||||
|
|
||||||
|
# Self Attention
|
||||||
|
attn_output = self.self_attn(
|
||||||
|
normed_hidden_states,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
cu_seqlen_prefill,
|
||||||
|
kv_cache,
|
||||||
|
block_tables,
|
||||||
|
slots,
|
||||||
|
seqlen,
|
||||||
|
max_s,
|
||||||
|
adapter_data,
|
||||||
|
attention_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
# faster post attention rms norm
|
||||||
|
normed_attn_res_output, _ = self.post_attention_layernorm(attn_output)
|
||||||
|
normed_attn_res_output = normed_attn_res_output + res
|
||||||
|
res = normed_attn_res_output
|
||||||
|
|
||||||
|
pre_normed, _ = self.pre_feedforward_layernorm(normed_attn_res_output)
|
||||||
|
mlp_output = self.mlp(pre_normed, adapter_data)
|
||||||
|
post_hidden_states, _ = self.post_feedforward_layernorm(mlp_output)
|
||||||
|
|
||||||
|
return post_hidden_states, normed_attn_res_output
|
||||||
|
|
||||||
|
|
||||||
|
class FlashGemma3Model(torch.nn.Module):
|
||||||
|
def __init__(self, prefix: str, config, weights, causal: bool):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
process_group = weights.process_group
|
||||||
|
self.tp_rank = process_group.rank()
|
||||||
|
self.tp_world_size = process_group.size()
|
||||||
|
|
||||||
|
self.layers = nn.ModuleList(
|
||||||
|
[
|
||||||
|
FlashGemma3Layer(
|
||||||
|
prefix=f"{prefix}.layers.{layer_id}",
|
||||||
|
config=config,
|
||||||
|
weights=weights,
|
||||||
|
layer_id=layer_id,
|
||||||
|
causal=causal,
|
||||||
|
is_sliding=bool((layer_id + 1) % config.sliding_window_pattern),
|
||||||
|
)
|
||||||
|
for layer_id in range(config.num_hidden_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.norm = Gemma3FastRMSNorm.load(
|
||||||
|
prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
|
||||||
|
)
|
||||||
|
|
||||||
|
self.head_size = self.layers[0].self_attn.head_size
|
||||||
|
self.num_heads = self.layers[0].self_attn.num_heads
|
||||||
|
self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
inputs_embeds: torch.Tensor,
|
||||||
|
position_ids: torch.Tensor,
|
||||||
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
|
block_tables: torch.Tensor,
|
||||||
|
slots: torch.Tensor,
|
||||||
|
seqlen: Seqlen,
|
||||||
|
max_s: int,
|
||||||
|
adapter_data: Optional[torch.Tensor] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
hidden_states = inputs_embeds
|
||||||
|
|
||||||
|
# Get rotary cos and sin for this forward
|
||||||
|
# Avoid to index in each layer
|
||||||
|
|
||||||
|
residual = None
|
||||||
|
for i, layer in enumerate(self.layers):
|
||||||
|
cos, sin = self.layers[i].self_attn.rotary_emb.get_cos_sin(
|
||||||
|
position_ids, max_s, hidden_states.dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
# apply sliding window mask if needed
|
||||||
|
if layer.self_attn.window_size > 0 and attention_mask is not None:
|
||||||
|
min_dtype = torch.finfo(hidden_states.dtype).min
|
||||||
|
# prefill may be larger than sliding window
|
||||||
|
effective_seq_len = max(
|
||||||
|
position_ids.shape[0], self.layers[i].self_attn.window_size
|
||||||
|
)
|
||||||
|
sliding_window_mask = torch.tril(
|
||||||
|
torch.ones_like(attention_mask, dtype=torch.bool),
|
||||||
|
diagonal=-self.layers[i].self_attn.window_size,
|
||||||
|
)
|
||||||
|
attention_mask = torch.where(
|
||||||
|
sliding_window_mask, min_dtype, attention_mask
|
||||||
|
)
|
||||||
|
offset = max(0, position_ids.shape[0] - effective_seq_len)
|
||||||
|
attention_mask = attention_mask[
|
||||||
|
:, :, offset : offset + effective_seq_len
|
||||||
|
]
|
||||||
|
|
||||||
|
hidden_states, residual = layer(
|
||||||
|
hidden_states,
|
||||||
|
residual,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
cu_seqlen_prefill,
|
||||||
|
kv_cache[i],
|
||||||
|
block_tables,
|
||||||
|
slots,
|
||||||
|
seqlen,
|
||||||
|
max_s,
|
||||||
|
adapter_data,
|
||||||
|
attention_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states, _ = self.norm(hidden_states, residual)
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class FlashGemma3ForCausalLM(torch.nn.Module):
|
||||||
|
def __init__(self, prefix: str, config, weights, *, causal: bool = True):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
embed_norm = config.hidden_size**0.5
|
||||||
|
if not prefix:
|
||||||
|
prefix = "model"
|
||||||
|
else:
|
||||||
|
prefix = f"{prefix}.model"
|
||||||
|
|
||||||
|
self.embed_tokens = TensorParallelEmbedding(
|
||||||
|
prefix=f"{prefix}.embed_tokens", weights=weights
|
||||||
|
)
|
||||||
|
self.embed_tokens.weight *= embed_norm
|
||||||
|
|
||||||
|
self.model = FlashGemma3Model(
|
||||||
|
prefix=prefix, config=config, weights=weights, causal=causal
|
||||||
|
)
|
||||||
|
self.lm_head = SpeculativeHead.load(
|
||||||
|
prefix=(
|
||||||
|
f"{prefix}.embed_tokens"
|
||||||
|
if config.tie_word_embeddings
|
||||||
|
else f"{prefix}.lm_head"
|
||||||
|
),
|
||||||
|
config=config,
|
||||||
|
weights=weights,
|
||||||
|
)
|
||||||
|
# self.softcap = config.attn_logit_softcapping
|
||||||
|
# assert isinstance(self.softcap, float)
|
||||||
|
self.softcap = None
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
position_ids: torch.Tensor,
|
||||||
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
|
block_tables: torch.Tensor,
|
||||||
|
slots: torch.Tensor,
|
||||||
|
seqlen: Seqlen,
|
||||||
|
max_s: int,
|
||||||
|
prefill_cache_indices: Optional[torch.Tensor],
|
||||||
|
lm_head_indices: Optional[torch.Tensor] = None,
|
||||||
|
adapter_data: Optional[torch.Tensor] = None,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
|
input_embeds = self.embed_tokens(input_ids)
|
||||||
|
|
||||||
|
hidden_states = self.model(
|
||||||
|
input_embeds,
|
||||||
|
position_ids,
|
||||||
|
cu_seqlen_prefill,
|
||||||
|
kv_cache,
|
||||||
|
block_tables,
|
||||||
|
slots,
|
||||||
|
seqlen,
|
||||||
|
max_s,
|
||||||
|
adapter_data,
|
||||||
|
)
|
||||||
|
if lm_head_indices is not None:
|
||||||
|
hidden_states = hidden_states[lm_head_indices]
|
||||||
|
logits, speculative_logits = self.lm_head(hidden_states)
|
||||||
|
|
||||||
|
return logits, speculative_logits
|
||||||
|
|
||||||
|
|
||||||
|
class Gemma3MultimodalInputProjection(torch.nn.Module):
|
||||||
|
def __init__(self, prefix, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.mm_input_projection_weight = weights.get_tensor(
|
||||||
|
"multi_modal_projector.mm_input_projection_weight"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.mm_soft_emb_norm = Gemma3FastRMSNorm.load(
|
||||||
|
prefix=f"{prefix}.mm_soft_emb_norm",
|
||||||
|
weights=weights,
|
||||||
|
eps=config.vision_config.layer_norm_eps,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.patches_per_image = int(
|
||||||
|
config.vision_config.image_size // config.vision_config.patch_size
|
||||||
|
)
|
||||||
|
self.tokens_per_side = int(config.mm_tokens_per_image**0.5)
|
||||||
|
self.kernel_size = self.patches_per_image // self.tokens_per_side
|
||||||
|
self.avg_pool = nn.AvgPool2d(
|
||||||
|
kernel_size=self.kernel_size, stride=self.kernel_size
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, vision_outputs: torch.Tensor):
|
||||||
|
batch_size, _, seq_length = vision_outputs.shape
|
||||||
|
|
||||||
|
reshaped_vision_outputs = vision_outputs.transpose(1, 2)
|
||||||
|
reshaped_vision_outputs = reshaped_vision_outputs.reshape(
|
||||||
|
batch_size, seq_length, self.patches_per_image, self.patches_per_image
|
||||||
|
)
|
||||||
|
reshaped_vision_outputs = reshaped_vision_outputs.contiguous()
|
||||||
|
|
||||||
|
pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs)
|
||||||
|
pooled_vision_outputs = pooled_vision_outputs.flatten(2)
|
||||||
|
pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2)
|
||||||
|
|
||||||
|
normed_vision_outputs, _ = self.mm_soft_emb_norm(pooled_vision_outputs)
|
||||||
|
|
||||||
|
projected_vision_outputs = torch.matmul(
|
||||||
|
normed_vision_outputs, self.mm_input_projection_weight
|
||||||
|
)
|
||||||
|
return projected_vision_outputs.type_as(vision_outputs)
|
||||||
|
|
||||||
|
|
||||||
|
class Gemma3ForConditionalGeneration(nn.Module):
|
||||||
|
def __init__(self, prefix, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
if config.vision_config is not None:
|
||||||
|
|
||||||
|
config.vision_config.quantize = config.quantize
|
||||||
|
|
||||||
|
self.post_vision_model_layernorm = nn.LayerNorm.load(
|
||||||
|
prefix="vision_tower.vision_model.post_layernorm",
|
||||||
|
weights=weights,
|
||||||
|
eps=config.vision_config.layer_norm_eps,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.multimodal_projector = Gemma3MultimodalInputProjection(
|
||||||
|
prefix="multi_modal_projector",
|
||||||
|
config=config,
|
||||||
|
weights=weights,
|
||||||
|
)
|
||||||
|
|
||||||
|
text_config = config.text_config
|
||||||
|
text_config.speculator = config.speculator
|
||||||
|
text_config.quantize = config.quantize
|
||||||
|
|
||||||
|
self.vision_model = load_vision_model(
|
||||||
|
prefix="vision_tower" if not prefix else f"{prefix}.vision_tower",
|
||||||
|
config=config.vision_config,
|
||||||
|
weights=weights,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.text_model = load_text_model(
|
||||||
|
prefix="language_model" if not prefix else f"{prefix}.language_model",
|
||||||
|
config=config.text_config,
|
||||||
|
weights=weights,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
config.text_config.quantize = config.quantize
|
||||||
|
config.text_config.speculator = config.speculator
|
||||||
|
self.text_model = load_text_model(
|
||||||
|
prefix=prefix,
|
||||||
|
config=config.text_config,
|
||||||
|
weights=weights,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.pad_token_id = (
|
||||||
|
config.pad_token_id if config.pad_token_id is not None else -1
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_image_token_mask(self, input_ids):
|
||||||
|
device = input_ids.device
|
||||||
|
|
||||||
|
start_token_id = self.config.boi_token_index
|
||||||
|
K = self.config.mm_tokens_per_image
|
||||||
|
|
||||||
|
mask = torch.zeros_like(input_ids, dtype=torch.bool, device=device)
|
||||||
|
start_positions = (input_ids == start_token_id).nonzero(as_tuple=True)[0]
|
||||||
|
mask_indices = start_positions.unsqueeze(1) + torch.arange(
|
||||||
|
1, K + 1, device=device
|
||||||
|
).unsqueeze(0)
|
||||||
|
|
||||||
|
valid_mask = mask_indices < input_ids.size(0)
|
||||||
|
mask_indices = mask_indices[valid_mask]
|
||||||
|
mask[mask_indices] = True
|
||||||
|
|
||||||
|
return mask
|
||||||
|
|
||||||
|
def get_attention_mask(
|
||||||
|
self, input_ids, max_s, cu_seqlen_prefill, dtype, image_token_mask
|
||||||
|
):
|
||||||
|
device = input_ids.device
|
||||||
|
min_dtype = torch.finfo(dtype).min
|
||||||
|
|
||||||
|
lengths = (cu_seqlen_prefill[1:] - cu_seqlen_prefill[:-1]).tolist()
|
||||||
|
batch_size = len(lengths)
|
||||||
|
|
||||||
|
sequence_length = max(lengths)
|
||||||
|
target_length = max_s
|
||||||
|
# Create the padding mask from the computed lengths.
|
||||||
|
# pad_mask: [batch, sequence_length] where True indicates valid tokens.
|
||||||
|
seq_range = torch.arange(sequence_length, device=device).unsqueeze(0)
|
||||||
|
lengths_tensor = torch.tensor(lengths, device=device).unsqueeze(1)
|
||||||
|
pad_mask = seq_range < lengths_tensor # shape: [batch, sequence_length]
|
||||||
|
|
||||||
|
# Build the base causal mask (for non-image tokens):
|
||||||
|
causal_mask = torch.tril(
|
||||||
|
torch.ones(
|
||||||
|
(sequence_length, sequence_length), dtype=torch.bool, device=device
|
||||||
|
)
|
||||||
|
)
|
||||||
|
base_mask = pad_mask.unsqueeze(2) & pad_mask.unsqueeze(
|
||||||
|
1
|
||||||
|
) # [batch, sequence_length, sequence_length]
|
||||||
|
base_mask = base_mask & causal_mask.unsqueeze(0) # apply causal constraint
|
||||||
|
|
||||||
|
image_token_mask = torch.nn.utils.rnn.pad_sequence(
|
||||||
|
torch.split(image_token_mask, lengths), batch_first=True, padding_value=0
|
||||||
|
)
|
||||||
|
bidirectional_mask = image_token_mask.unsqueeze(2) & image_token_mask.unsqueeze(
|
||||||
|
1
|
||||||
|
)
|
||||||
|
|
||||||
|
# Combine the causal base mask and the bidirectional mask.
|
||||||
|
combined_mask = torch.logical_or(
|
||||||
|
base_mask.unsqueeze(1), bidirectional_mask.unsqueeze(1)
|
||||||
|
).to(device)
|
||||||
|
# combined_mask now has shape [batch, 1, sequence_length, sequence_length]
|
||||||
|
|
||||||
|
full_attention_mask = torch.zeros(
|
||||||
|
(batch_size, 1, sequence_length, target_length),
|
||||||
|
device=device,
|
||||||
|
dtype=torch.bool,
|
||||||
|
)
|
||||||
|
full_attention_mask[:, :, :, :sequence_length] = combined_mask
|
||||||
|
|
||||||
|
final_attention_mask = torch.where(full_attention_mask, 0, min_dtype).to(device)
|
||||||
|
|
||||||
|
return final_attention_mask
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
position_ids: torch.Tensor,
|
||||||
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
|
block_tables: torch.Tensor,
|
||||||
|
slots: torch.Tensor,
|
||||||
|
seqlen: Seqlen,
|
||||||
|
max_s: int,
|
||||||
|
prefill_cache_indices: Optional[torch.Tensor] = None,
|
||||||
|
lm_head_indices: Optional[torch.Tensor] = None,
|
||||||
|
pixel_values: torch.FloatTensor = None,
|
||||||
|
# Unused here
|
||||||
|
attention_mask: Optional[torch.BoolTensor] = None,
|
||||||
|
pixel_attention_mask: Optional[torch.BoolTensor] = None,
|
||||||
|
image_sizes: Optional[torch.Tensor] = None,
|
||||||
|
adapter_data: Optional[torch.Tensor] = None,
|
||||||
|
image_grid_thw: Optional[torch.LongTensor] = None,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
|
inputs_embeds = self.text_model.embed_tokens(input_ids)
|
||||||
|
if cu_seqlen_prefill is not None:
|
||||||
|
max_s += 1
|
||||||
|
position_ids += 1
|
||||||
|
|
||||||
|
if pixel_values is not None:
|
||||||
|
pixel_values = pixel_values.to(dtype=inputs_embeds.dtype)
|
||||||
|
image_outputs = self.vision_model(pixel_values)
|
||||||
|
vision_outputs = self.post_vision_model_layernorm(
|
||||||
|
image_outputs.last_hidden_state
|
||||||
|
)
|
||||||
|
image_features = self.multimodal_projector(vision_outputs)
|
||||||
|
|
||||||
|
image_token_mask = (input_ids == self.config.image_token_index).to(
|
||||||
|
input_ids.device
|
||||||
|
)
|
||||||
|
inputs_embeds[image_token_mask] = image_features.view(
|
||||||
|
-1, image_features.shape[-1]
|
||||||
|
)
|
||||||
|
attention_mask = self.get_attention_mask(
|
||||||
|
input_ids,
|
||||||
|
max_s,
|
||||||
|
cu_seqlen_prefill,
|
||||||
|
inputs_embeds.dtype,
|
||||||
|
image_token_mask,
|
||||||
|
)
|
||||||
|
# Use flash attention for text-only input
|
||||||
|
# else:
|
||||||
|
# if cu_seqlen_prefill is not None:
|
||||||
|
# min_dtype = torch.finfo(inputs_embeds.dtype).min
|
||||||
|
# lengths = (cu_seqlen_prefill[1:] - cu_seqlen_prefill[:-1]).tolist()
|
||||||
|
|
||||||
|
# # Determine the maximum sequence length (after padding) from query.
|
||||||
|
# sequence_length = max(lengths)
|
||||||
|
# target_length = max_s
|
||||||
|
|
||||||
|
# # Create the padding mask from the computed lengths.
|
||||||
|
# # pad_mask: [batch, sequence_length] where True indicates valid tokens.
|
||||||
|
# seq_range = torch.arange(
|
||||||
|
# sequence_length, device=input_ids.device
|
||||||
|
# ).unsqueeze(0)
|
||||||
|
# lengths_tensor = torch.tensor(
|
||||||
|
# lengths, device=input_ids.device
|
||||||
|
# ).unsqueeze(1)
|
||||||
|
# pad_mask = seq_range < lengths_tensor # shape: [batch, sequence_length]
|
||||||
|
|
||||||
|
# # Build the base causal mask (for non-image tokens):
|
||||||
|
# causal_mask = torch.tril(
|
||||||
|
# torch.ones(
|
||||||
|
# (sequence_length, sequence_length),
|
||||||
|
# dtype=torch.bool,
|
||||||
|
# device=input_ids.device,
|
||||||
|
# )
|
||||||
|
# )
|
||||||
|
# base_mask = pad_mask.unsqueeze(2) & pad_mask.unsqueeze(
|
||||||
|
# 1
|
||||||
|
# ) # [batch, sequence_length, sequence_length]
|
||||||
|
# base_mask = base_mask & causal_mask.unsqueeze(0)
|
||||||
|
# attention_mask = base_mask.unsqueeze(
|
||||||
|
# 1
|
||||||
|
# ) # [batch, 1, sequence_length, sequence_length]
|
||||||
|
# full_attention_mask = torch.zeros(
|
||||||
|
# (len(lengths), 1, sequence_length, target_length),
|
||||||
|
# device=input_ids.device,
|
||||||
|
# dtype=torch.bool,
|
||||||
|
# )
|
||||||
|
# full_attention_mask[:, :, :, :sequence_length] = attention_mask
|
||||||
|
|
||||||
|
# attention_mask = torch.where(full_attention_mask, 0, min_dtype).to(
|
||||||
|
# input_ids.device
|
||||||
|
# )
|
||||||
|
|
||||||
|
hidden_states = self.text_model.model(
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
position_ids=position_ids,
|
||||||
|
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||||
|
kv_cache=kv_cache,
|
||||||
|
block_tables=block_tables,
|
||||||
|
slots=slots,
|
||||||
|
seqlen=seqlen,
|
||||||
|
max_s=max_s,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
if lm_head_indices is not None:
|
||||||
|
hidden_states = hidden_states[lm_head_indices]
|
||||||
|
logits, speculative_logits = self.text_model.lm_head(hidden_states)
|
||||||
|
|
||||||
|
# pad logit with 1 zero logit for the image token
|
||||||
|
if pixel_values is not None:
|
||||||
|
logits = torch.cat(
|
||||||
|
[logits, torch.zeros(logits.size(0), 1, device=logits.device)], dim=1
|
||||||
|
)
|
||||||
|
if speculative_logits is not None:
|
||||||
|
speculative_logits = torch.cat(
|
||||||
|
[
|
||||||
|
speculative_logits,
|
||||||
|
torch.zeros(
|
||||||
|
speculative_logits.size(0),
|
||||||
|
1,
|
||||||
|
device=speculative_logits.device,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
dim=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
return logits, speculative_logits
|
@ -31,7 +31,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
config.vision_config.quantize = config.quantize
|
config.vision_config.quantize = config.quantize
|
||||||
self.vision_tower = load_vision_model(
|
self.vision_tower = load_vision_model(
|
||||||
prefix="vision_tower" if not prefix else f"{prefix}.vision_tower",
|
prefix="vision_model" if not prefix else f"{prefix}.vision_model",
|
||||||
config=config.vision_config,
|
config=config.vision_config,
|
||||||
weights=weights,
|
weights=weights,
|
||||||
)
|
)
|
||||||
|
@ -0,0 +1,313 @@
|
|||||||
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
|
# This file was automatically generated from src/transformers/models/gemma3/modular_gemma3.py.
|
||||||
|
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||||
|
# the file from the modular. If any change should be done, please apply the change to the
|
||||||
|
# modular_gemma3.py file directly. One of our CI enforces this.
|
||||||
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
|
from transformers.modeling_rope_utils import rope_config_validation
|
||||||
|
from transformers.utils import logging
|
||||||
|
from transformers import SiglipVisionConfig
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class Gemma3TextConfig(PretrainedConfig):
|
||||||
|
r"""
|
||||||
|
This is the configuration class to store the configuration of a [`Gemma3Model`]. It is used to instantiate a Gemma3
|
||||||
|
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
||||||
|
defaults will yield a similar configuration to that of the Gemma3-4B.
|
||||||
|
e.g. [google/gemma-3-4b](https://huggingface.co/google/gemma-3-4b)
|
||||||
|
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||||
|
documentation from [`PretrainedConfig`] for more information.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vocab_size (`int`, *optional*, defaults to 262144):
|
||||||
|
Vocabulary size of the Gemma3 model. Defines the number of different tokens that can be represented by the
|
||||||
|
`inputs_ids` passed when calling [`Gemma3Model`]
|
||||||
|
hidden_size (`int`, *optional*, defaults to 2304):
|
||||||
|
Dimension of the hidden representations.
|
||||||
|
intermediate_size (`int`, *optional*, defaults to 9216):
|
||||||
|
Dimension of the MLP representations.
|
||||||
|
num_hidden_layers (`int`, *optional*, defaults to 26):
|
||||||
|
Number of hidden layers in the Transformer decoder.
|
||||||
|
num_attention_heads (`int`, *optional*, defaults to 8):
|
||||||
|
Number of attention heads for each attention layer in the Transformer decoder.
|
||||||
|
num_key_value_heads (`int`, *optional*, defaults to 4):
|
||||||
|
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
||||||
|
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
||||||
|
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
||||||
|
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
||||||
|
by meanpooling all the original heads within that group. For more details checkout [this
|
||||||
|
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
|
||||||
|
`num_attention_heads`.
|
||||||
|
head_dim (`int`, *optional*, defaults to 256):
|
||||||
|
The attention head dimension.
|
||||||
|
sliding_window (`int`, *optional*, defaults to 4096): in Gemma3, every other layer uses sliding window
|
||||||
|
attention. This is the size of the sliding window.
|
||||||
|
query_pre_attn_scalar (`float`, *optional*):
|
||||||
|
The scaling factor used on the attention scores, not that
|
||||||
|
rope_theta (`float`, *optional*, defaults to 1000000.0):
|
||||||
|
The base period of the RoPE embeddings used for global attention.
|
||||||
|
rope_scaling (`Dict`, *optional*):
|
||||||
|
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
|
||||||
|
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
||||||
|
accordingly.
|
||||||
|
Expected contents:
|
||||||
|
`rope_type` (`str`):
|
||||||
|
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
||||||
|
'llama3'], with 'default' being the original RoPE implementation.
|
||||||
|
`factor` (`float`, *optional*):
|
||||||
|
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
||||||
|
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
||||||
|
original maximum pre-trained length.
|
||||||
|
`original_max_position_embeddings` (`int`, *optional*):
|
||||||
|
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
|
||||||
|
pretraining.
|
||||||
|
`attention_factor` (`float`, *optional*):
|
||||||
|
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
||||||
|
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
||||||
|
`factor` field to infer the suggested value.
|
||||||
|
`beta_fast` (`float`, *optional*):
|
||||||
|
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
|
||||||
|
ramp function. If unspecified, it defaults to 32.
|
||||||
|
`beta_slow` (`float`, *optional*):
|
||||||
|
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
|
||||||
|
ramp function. If unspecified, it defaults to 1.
|
||||||
|
`short_factor` (`List[float]`, *optional*):
|
||||||
|
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
||||||
|
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||||
|
size divided by the number of attention heads divided by 2
|
||||||
|
`long_factor` (`List[float]`, *optional*):
|
||||||
|
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
|
||||||
|
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||||
|
size divided by the number of attention heads divided by 2
|
||||||
|
`low_freq_factor` (`float`, *optional*):
|
||||||
|
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
|
||||||
|
`high_freq_factor` (`float`, *optional*):
|
||||||
|
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
|
||||||
|
rope_local_base_freq (float, *optional*, defaults to 10000.0):
|
||||||
|
The base period of the RoPE embeddings for local attention.
|
||||||
|
sliding_window_pattern (`int`, *optional*, defaults to 6):
|
||||||
|
Pattern for the sliding window attention.
|
||||||
|
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||||
|
The epsilon used by the rms normalization layers.
|
||||||
|
hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
|
||||||
|
The non-linear activation function (function or string) in the decoder. Will default to
|
||||||
|
`"gelu_pytorch_tanh"` if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"`
|
||||||
|
activation function.
|
||||||
|
pad_token_id (`int`, *optional*, defaults to 0):
|
||||||
|
Padding token id.
|
||||||
|
eos_token_id (`int`, *optional*, defaults to 1):
|
||||||
|
End of stream token id.
|
||||||
|
bos_token_id (`int`, *optional*, defaults to 2):
|
||||||
|
Beginning of stream token id.
|
||||||
|
tie_word_embeddings (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to tie weight embeddings
|
||||||
|
max_position_embeddings (`int`, *optional*, defaults to 131072):
|
||||||
|
The maximum sequence length that this model might ever be used with.
|
||||||
|
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||||
|
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||||
|
attention_bias (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
||||||
|
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||||
|
The dropout ratio for the attention probabilities.
|
||||||
|
use_cache (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||||
|
relevant if `config.is_decoder=True`.
|
||||||
|
final_logit_softcapping (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to apply logit softcapping or nor
|
||||||
|
attn_logit_softcapping (`float`, *optional*, defaults to 50.0):
|
||||||
|
Scaling factor when applying tanh soft-capping on the attention scorexs.
|
||||||
|
cache_implementation (`str`, *optional*, defaults to `"hybrid"`):
|
||||||
|
The cache type to be used with `generate`.
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import Gemma3Model, Gemma3TextConfig
|
||||||
|
>>> # Initializing a Gemma3 gemma3-4b style configuration
|
||||||
|
>>> configuration = Gemma3Config()
|
||||||
|
>>> # Initializing a model from the gemma3-4b style configuration
|
||||||
|
>>> model = Gemma3Model(configuration)
|
||||||
|
>>> # Accessing the model configuration
|
||||||
|
>>> configuration = model.config
|
||||||
|
```"""
|
||||||
|
|
||||||
|
model_type = "gemma3_text"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocab_size: int = 262_144,
|
||||||
|
hidden_size: int = 2304,
|
||||||
|
intermediate_size: int = 9216,
|
||||||
|
num_hidden_layers: int = 26,
|
||||||
|
num_attention_heads: int = 8,
|
||||||
|
num_key_value_heads: int = 4,
|
||||||
|
head_dim: int = 256,
|
||||||
|
sliding_window: int = 4096,
|
||||||
|
query_pre_attn_scalar: Optional[float] = 256,
|
||||||
|
rope_theta: float = 1_000_000.0,
|
||||||
|
rope_scaling=None,
|
||||||
|
rope_local_base_freq: float = 10_000.0,
|
||||||
|
sliding_window_pattern: int = 6,
|
||||||
|
rms_norm_eps: float = 1e-6,
|
||||||
|
hidden_activation: str = "gelu_pytorch_tanh",
|
||||||
|
pad_token_id: int = 0,
|
||||||
|
eos_token_id: int = 1,
|
||||||
|
bos_token_id: int = 2,
|
||||||
|
tie_word_embeddings: bool = True,
|
||||||
|
max_position_embeddings: int = 131_072,
|
||||||
|
initializer_range: float = 0.02,
|
||||||
|
attention_bias: bool = False,
|
||||||
|
attention_dropout: float = 0.0,
|
||||||
|
use_cache: bool = True,
|
||||||
|
final_logit_softcapping=None,
|
||||||
|
attn_logit_softcapping=None,
|
||||||
|
cache_implementation: str = "hybrid",
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
pad_token_id=pad_token_id,
|
||||||
|
bos_token_id=bos_token_id,
|
||||||
|
eos_token_id=eos_token_id,
|
||||||
|
tie_word_embeddings=tie_word_embeddings,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.intermediate_size = intermediate_size
|
||||||
|
self.num_hidden_layers = num_hidden_layers
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
self.head_dim = head_dim
|
||||||
|
self.num_key_value_heads = num_key_value_heads
|
||||||
|
self.initializer_range = initializer_range
|
||||||
|
self.rms_norm_eps = rms_norm_eps
|
||||||
|
self.use_cache = use_cache
|
||||||
|
self.rope_theta = rope_theta
|
||||||
|
self.rope_scaling = rope_scaling
|
||||||
|
self.rope_local_base_freq = rope_local_base_freq
|
||||||
|
# For configuring HybridCache to work with 5:1 attention pattern
|
||||||
|
self.sliding_window_pattern = sliding_window_pattern
|
||||||
|
self.attention_bias = attention_bias
|
||||||
|
self.attention_dropout = attention_dropout
|
||||||
|
self.hidden_activation = hidden_activation
|
||||||
|
self.query_pre_attn_scalar = query_pre_attn_scalar
|
||||||
|
self.sliding_window = sliding_window
|
||||||
|
self.final_logit_softcapping = final_logit_softcapping
|
||||||
|
self.attn_logit_softcapping = attn_logit_softcapping
|
||||||
|
self.cache_implementation = cache_implementation
|
||||||
|
rope_config_validation(self)
|
||||||
|
|
||||||
|
|
||||||
|
class Gemma3Config(PretrainedConfig):
|
||||||
|
r"""
|
||||||
|
This is the configuration class to store the configuration of a [`Gemma3ForConditionalGeneration`]. It is used to instantiate an
|
||||||
|
Gemma3ForConditionalGeneration according to the specified arguments, defining the model architecture. Instantiating a configuration
|
||||||
|
with the defaults will yield a similar configuration to that of the PaliGemma-2B.
|
||||||
|
|
||||||
|
e.g. [google/gemma-3-4b](https://huggingface.co/google/gemma-3-4b)
|
||||||
|
|
||||||
|
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||||
|
documentation from [`PretrainedConfig`] for more information.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text_config (`Union[Gemma3TextConfig, dict]`, *optional*):
|
||||||
|
The config object of the text backbone.
|
||||||
|
vision_config (`Union[AutoConfig, dict]`, *optional*):
|
||||||
|
Custom vision config or dict.
|
||||||
|
mm_tokens_per_image (`int`, *optional*, defaults to 256):
|
||||||
|
The number of tokens per image embedding.
|
||||||
|
boi_token_index (`int`, *optional*, defaults to 255999):
|
||||||
|
The begin-of-image token index to wrap the image prompt.
|
||||||
|
eoi_token_index (`int`, *optional*, defaults to 256000):
|
||||||
|
The end-of-image token index to wrap the image prompt.
|
||||||
|
image_token_index (`int`, *optional*, defaults to 262144):
|
||||||
|
The image token index to encode the image prompt.
|
||||||
|
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||||
|
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import Gemma3ForConditionalGeneration, Gemma3Config, SiglipVisionConfig, Gemma3TextConfig
|
||||||
|
|
||||||
|
>>> # Initializing a Siglip-like vision config
|
||||||
|
>>> vision_config = SiglipVisionConfig()
|
||||||
|
|
||||||
|
>>> # Initializing a Gemma3 Text config
|
||||||
|
>>> text_config = Gemma3TextConfig()
|
||||||
|
|
||||||
|
>>> # Initializing a Gemma3 gemma-3-4b style configuration
|
||||||
|
>>> configuration = Gemma3Config(vision_config, text_config)
|
||||||
|
|
||||||
|
>>> # Initializing a model from the gemma-3-4b style configuration
|
||||||
|
>>> model = Gemma3TextConfig(configuration)
|
||||||
|
|
||||||
|
>>> # Accessing the model configuration
|
||||||
|
>>> configuration = model.config
|
||||||
|
```"""
|
||||||
|
|
||||||
|
model_type = "gemma3"
|
||||||
|
sub_configs = {
|
||||||
|
"text_config": Gemma3TextConfig,
|
||||||
|
"vision_config": SiglipVisionConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
text_config: Optional[Gemma3TextConfig] = None,
|
||||||
|
vision_config: Optional[SiglipVisionConfig] = None,
|
||||||
|
mm_tokens_per_image: int = 256,
|
||||||
|
boi_token_index: int = 255_999,
|
||||||
|
eoi_token_index: int = 256_000,
|
||||||
|
image_token_index: int = 262_144,
|
||||||
|
initializer_range: float = 0.02,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
if text_config is None:
|
||||||
|
text_config = Gemma3TextConfig()
|
||||||
|
logger.info(
|
||||||
|
"text_config is None, using default Gemma3TextConfig vision config."
|
||||||
|
)
|
||||||
|
elif isinstance(text_config, dict):
|
||||||
|
text_config = Gemma3TextConfig(**text_config)
|
||||||
|
|
||||||
|
if isinstance(vision_config, dict):
|
||||||
|
vision_config = SiglipVisionConfig(**vision_config)
|
||||||
|
else:
|
||||||
|
vision_config = SiglipVisionConfig()
|
||||||
|
logger.info(
|
||||||
|
"vision_config is None or incompatible with Gemma3VisionConfig intialization. Gemma3 will be limited "
|
||||||
|
"to text tasks."
|
||||||
|
)
|
||||||
|
|
||||||
|
self.text_config = text_config
|
||||||
|
self.vision_config = vision_config
|
||||||
|
self.mm_tokens_per_image = mm_tokens_per_image
|
||||||
|
self.boi_token_index = boi_token_index
|
||||||
|
self.eoi_token_index = eoi_token_index
|
||||||
|
self.image_token_index = image_token_index
|
||||||
|
self.initializer_range = initializer_range
|
||||||
|
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["Gemma3Config", "Gemma3TextConfig"]
|
@ -0,0 +1,463 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Image processor class for Gemma3."""
|
||||||
|
|
||||||
|
import itertools
|
||||||
|
import math
|
||||||
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from transformers.image_processing_utils import (
|
||||||
|
BaseImageProcessor,
|
||||||
|
BatchFeature,
|
||||||
|
get_size_dict,
|
||||||
|
)
|
||||||
|
from transformers.image_transforms import (
|
||||||
|
convert_to_rgb,
|
||||||
|
resize,
|
||||||
|
to_channel_dimension_format,
|
||||||
|
)
|
||||||
|
from transformers.image_utils import (
|
||||||
|
IMAGENET_STANDARD_MEAN,
|
||||||
|
IMAGENET_STANDARD_STD,
|
||||||
|
ChannelDimension,
|
||||||
|
ImageInput,
|
||||||
|
PILImageResampling,
|
||||||
|
get_image_size,
|
||||||
|
infer_channel_dimension_format,
|
||||||
|
is_scaled_image,
|
||||||
|
to_numpy_array,
|
||||||
|
valid_images,
|
||||||
|
validate_preprocess_arguments,
|
||||||
|
)
|
||||||
|
from transformers.utils import (
|
||||||
|
TensorType,
|
||||||
|
filter_out_non_signature_kwargs,
|
||||||
|
is_vision_available,
|
||||||
|
logging,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .utils import make_nested_list_of_images
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
if is_vision_available():
|
||||||
|
import PIL
|
||||||
|
|
||||||
|
|
||||||
|
class Gemma3ImageProcessor(BaseImageProcessor):
|
||||||
|
r"""
|
||||||
|
Constructs a SigLIP image processor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
do_resize (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
|
||||||
|
`do_resize` in the `preprocess` method.
|
||||||
|
size (`Dict[str, int]` *optional*, defaults to `{"height": 224, "width": 224}`):
|
||||||
|
Size of the image after resizing. Can be overridden by `size` in the `preprocess` method.
|
||||||
|
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
|
||||||
|
Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
|
||||||
|
do_rescale (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
|
||||||
|
the `preprocess` method.
|
||||||
|
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
|
||||||
|
Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
|
||||||
|
method.
|
||||||
|
do_normalize (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to normalize the image by the specified mean and standard deviation. Can be overridden by
|
||||||
|
`do_normalize` in the `preprocess` method.
|
||||||
|
image_mean (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
|
||||||
|
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
|
||||||
|
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
|
||||||
|
image_std (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
|
||||||
|
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
|
||||||
|
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
|
||||||
|
Can be overridden by the `image_std` parameter in the `preprocess` method.
|
||||||
|
do_convert_rgb (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to convert the image to RGB.
|
||||||
|
do_pan_and_scan (`bool`, *optional*):
|
||||||
|
Whether to apply `pan_and_scan` to images.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_input_names = ["pixel_values", "num_crops"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
do_resize: bool = True,
|
||||||
|
size: Dict[str, int] = None,
|
||||||
|
resample: PILImageResampling = PILImageResampling.BILINEAR,
|
||||||
|
do_rescale: bool = False,
|
||||||
|
rescale_factor: Union[int, float] = 1 / 255,
|
||||||
|
do_normalize: bool = True,
|
||||||
|
image_mean: Optional[Union[float, List[float]]] = None,
|
||||||
|
image_std: Optional[Union[float, List[float]]] = None,
|
||||||
|
do_convert_rgb: bool = None,
|
||||||
|
do_pan_and_scan: bool = None,
|
||||||
|
pan_and_scan_min_crop_size: int = None,
|
||||||
|
pan_and_scan_max_num_crops: int = None,
|
||||||
|
pan_and_scan_min_ratio_to_activate: float = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
size = size if size is not None else {"height": 224, "width": 224}
|
||||||
|
image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
|
||||||
|
image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
|
||||||
|
|
||||||
|
self.do_resize = do_resize
|
||||||
|
self.size = size
|
||||||
|
self.resample = resample
|
||||||
|
self.do_rescale = do_rescale
|
||||||
|
self.rescale_factor = rescale_factor
|
||||||
|
self.do_normalize = do_normalize
|
||||||
|
self.image_mean = image_mean
|
||||||
|
self.image_std = image_std
|
||||||
|
self.do_convert_rgb = do_convert_rgb
|
||||||
|
self.do_pan_and_scan = do_pan_and_scan
|
||||||
|
self.pan_and_scan_min_crop_size = pan_and_scan_min_crop_size
|
||||||
|
self.pan_and_scan_max_num_crops = pan_and_scan_max_num_crops
|
||||||
|
self.pan_and_scan_min_ratio_to_activate = pan_and_scan_min_ratio_to_activate
|
||||||
|
|
||||||
|
def pan_and_scan(
|
||||||
|
self,
|
||||||
|
image: np.ndarray,
|
||||||
|
pan_and_scan_min_crop_size: int,
|
||||||
|
pan_and_scan_max_num_crops: int,
|
||||||
|
pan_and_scan_min_ratio_to_activate: float,
|
||||||
|
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||||
|
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Pan and Scan and image, whatever it means. TODO: write-up docs
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (`np.ndarray`):
|
||||||
|
Image to resize.
|
||||||
|
pan_and_scan_min_crop_size (`int`):
|
||||||
|
Size of pan_and_scan_min_crop_size.
|
||||||
|
pan_and_scan_max_num_crops (`int`):
|
||||||
|
pan_and_scan_max_num_crops for the image.
|
||||||
|
pan_and_scan_min_ratio_to_activate (`int`):
|
||||||
|
pan_and_scan_min_ratio_to_activate for the image..
|
||||||
|
data_format (`str` or `ChannelDimension`, *optional*):
|
||||||
|
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
||||||
|
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||||
|
The channel dimension format of the input image. If not provided, it will be inferred.
|
||||||
|
"""
|
||||||
|
height, width = get_image_size(image)
|
||||||
|
|
||||||
|
# Square or landscape image.
|
||||||
|
if width >= height:
|
||||||
|
# Only apply PaS if the image is sufficiently exaggerated
|
||||||
|
if width / height < pan_and_scan_min_ratio_to_activate:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Select ideal number of crops close to the image aspect ratio and such that crop_size > min_crop_size.
|
||||||
|
num_crops_w = int(
|
||||||
|
math.floor(width / height + 0.5)
|
||||||
|
) # Half round up rounding.
|
||||||
|
num_crops_w = min(
|
||||||
|
int(math.floor(width / pan_and_scan_min_crop_size)), num_crops_w
|
||||||
|
)
|
||||||
|
|
||||||
|
# Make sure the number of crops is in range [2, pan_and_scan_max_num_crops].
|
||||||
|
num_crops_w = max(2, num_crops_w)
|
||||||
|
num_crops_w = min(pan_and_scan_max_num_crops, num_crops_w)
|
||||||
|
num_crops_h = 1
|
||||||
|
|
||||||
|
# Portrait image.
|
||||||
|
else:
|
||||||
|
# Only apply PaS if the image is sufficiently exaggerated
|
||||||
|
if height / width < pan_and_scan_min_ratio_to_activate:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Select ideal number of crops close to the image aspect ratio and such that crop_size > min_crop_size.
|
||||||
|
num_crops_h = int(math.floor(height / width + 0.5))
|
||||||
|
num_crops_h = min(
|
||||||
|
int(math.floor(height / pan_and_scan_min_crop_size)), num_crops_h
|
||||||
|
)
|
||||||
|
|
||||||
|
# Make sure the number of crops is in range [2, pan_and_scan_max_num_crops].
|
||||||
|
num_crops_h = max(2, num_crops_h)
|
||||||
|
num_crops_h = min(pan_and_scan_max_num_crops, num_crops_h)
|
||||||
|
num_crops_w = 1
|
||||||
|
|
||||||
|
crop_size_w = int(math.ceil(width / num_crops_w))
|
||||||
|
crop_size_h = int(math.ceil(height / num_crops_h))
|
||||||
|
|
||||||
|
# Don't apply PaS if crop size is too small.
|
||||||
|
if min(crop_size_w, crop_size_h) < pan_and_scan_min_crop_size:
|
||||||
|
return []
|
||||||
|
|
||||||
|
crop_positions_w = [crop_size_w * i for i in range(num_crops_w)]
|
||||||
|
crop_positions_h = [crop_size_h * i for i in range(num_crops_h)]
|
||||||
|
|
||||||
|
if input_data_format == ChannelDimension.LAST:
|
||||||
|
image_crops = [
|
||||||
|
image[pos_h : pos_h + crop_size_h, pos_w : pos_w + crop_size_w]
|
||||||
|
for pos_h, pos_w in itertools.product(
|
||||||
|
crop_positions_h, crop_positions_w
|
||||||
|
)
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
image_crops = [
|
||||||
|
image[:, pos_h : pos_h + crop_size_h, pos_w : pos_w + crop_size_w]
|
||||||
|
for pos_h, pos_w in itertools.product(
|
||||||
|
crop_positions_h, crop_positions_w
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
return image_crops
|
||||||
|
|
||||||
|
def _process_images_for_pas(
|
||||||
|
self,
|
||||||
|
images: List[np.ndarray],
|
||||||
|
do_pan_and_scan: bool,
|
||||||
|
pan_and_scan_min_crop_size: int,
|
||||||
|
pan_and_scan_max_num_crops: int,
|
||||||
|
pan_and_scan_min_ratio_to_activate: float,
|
||||||
|
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||||
|
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||||
|
):
|
||||||
|
pas_images_list = []
|
||||||
|
num_crops = []
|
||||||
|
for image in images:
|
||||||
|
pas_images = self.pan_and_scan(
|
||||||
|
image=image,
|
||||||
|
pan_and_scan_min_crop_size=pan_and_scan_min_crop_size,
|
||||||
|
pan_and_scan_max_num_crops=pan_and_scan_max_num_crops,
|
||||||
|
pan_and_scan_min_ratio_to_activate=pan_and_scan_min_ratio_to_activate,
|
||||||
|
data_format=data_format,
|
||||||
|
input_data_format=input_data_format,
|
||||||
|
)
|
||||||
|
pas_images_list.extend([image] + pas_images)
|
||||||
|
num_crops.append(len(pas_images))
|
||||||
|
return pas_images_list, num_crops
|
||||||
|
|
||||||
|
@filter_out_non_signature_kwargs()
|
||||||
|
def preprocess(
|
||||||
|
self,
|
||||||
|
images: ImageInput,
|
||||||
|
do_resize: bool = None,
|
||||||
|
size: Dict[str, int] = None,
|
||||||
|
resample: PILImageResampling = None,
|
||||||
|
do_rescale: bool = None,
|
||||||
|
rescale_factor: float = None,
|
||||||
|
do_normalize: bool = None,
|
||||||
|
image_mean: Optional[Union[float, List[float]]] = None,
|
||||||
|
image_std: Optional[Union[float, List[float]]] = None,
|
||||||
|
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||||
|
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
||||||
|
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||||
|
do_convert_rgb: bool = None,
|
||||||
|
do_pan_and_scan: bool = None,
|
||||||
|
pan_and_scan_min_crop_size: int = None,
|
||||||
|
pan_and_scan_max_num_crops: int = None,
|
||||||
|
pan_and_scan_min_ratio_to_activate: float = None,
|
||||||
|
) -> PIL.Image.Image:
|
||||||
|
"""
|
||||||
|
Preprocess an image or batch of images.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
images (`ImageInput`):
|
||||||
|
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
|
||||||
|
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
|
||||||
|
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
||||||
|
Whether to resize the image.
|
||||||
|
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
|
||||||
|
Size of the image after resizing.
|
||||||
|
resample (`int`, *optional*, defaults to `self.resample`):
|
||||||
|
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
|
||||||
|
has an effect if `do_resize` is set to `True`.
|
||||||
|
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
||||||
|
Whether to rescale the image.
|
||||||
|
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
||||||
|
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
|
||||||
|
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
||||||
|
Whether to normalize the image.
|
||||||
|
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
||||||
|
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
|
||||||
|
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
||||||
|
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
|
||||||
|
`True`.
|
||||||
|
return_tensors (`str` or `TensorType`, *optional*):
|
||||||
|
The type of tensors to return. Can be one of:
|
||||||
|
- Unset: Return a list of `np.ndarray`.
|
||||||
|
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
||||||
|
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
||||||
|
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
||||||
|
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
||||||
|
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
||||||
|
The channel dimension format for the output image. Can be one of:
|
||||||
|
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||||
|
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||||
|
- Unset: Use the channel dimension format of the input image.
|
||||||
|
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||||
|
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||||
|
from the input image. Can be one of:
|
||||||
|
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||||
|
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||||
|
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||||
|
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
||||||
|
Whether to convert the image to RGB.
|
||||||
|
do_pan_and_scan (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
||||||
|
Whether to apply `pan_and_scan` to images.
|
||||||
|
"""
|
||||||
|
do_resize = do_resize if do_resize is not None else self.do_resize
|
||||||
|
size = size if size is not None else self.size
|
||||||
|
size = get_size_dict(size, param_name="size", default_to_square=False)
|
||||||
|
resample = resample if resample is not None else self.resample
|
||||||
|
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
||||||
|
rescale_factor = (
|
||||||
|
rescale_factor if rescale_factor is not None else self.rescale_factor
|
||||||
|
)
|
||||||
|
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
||||||
|
image_mean = image_mean if image_mean is not None else self.image_mean
|
||||||
|
image_std = image_std if image_std is not None else self.image_std
|
||||||
|
do_convert_rgb = (
|
||||||
|
do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
|
||||||
|
)
|
||||||
|
do_pan_and_scan = (
|
||||||
|
do_pan_and_scan if do_pan_and_scan is not None else self.do_pan_and_scan
|
||||||
|
)
|
||||||
|
pan_and_scan_min_crop_size = (
|
||||||
|
pan_and_scan_min_crop_size
|
||||||
|
if pan_and_scan_min_crop_size is not None
|
||||||
|
else self.pan_and_scan_min_crop_size
|
||||||
|
)
|
||||||
|
pan_and_scan_max_num_crops = (
|
||||||
|
pan_and_scan_max_num_crops
|
||||||
|
if pan_and_scan_max_num_crops is not None
|
||||||
|
else self.pan_and_scan_max_num_crops
|
||||||
|
)
|
||||||
|
pan_and_scan_min_ratio_to_activate = (
|
||||||
|
pan_and_scan_min_ratio_to_activate
|
||||||
|
if pan_and_scan_min_ratio_to_activate is not None
|
||||||
|
else self.pan_and_scan_min_ratio_to_activate
|
||||||
|
)
|
||||||
|
|
||||||
|
images_list = make_nested_list_of_images(images)
|
||||||
|
|
||||||
|
if not valid_images(images_list[0]):
|
||||||
|
raise ValueError(
|
||||||
|
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
||||||
|
"torch.Tensor, tf.Tensor or jax.ndarray."
|
||||||
|
)
|
||||||
|
|
||||||
|
validate_preprocess_arguments(
|
||||||
|
do_rescale=do_rescale,
|
||||||
|
rescale_factor=rescale_factor,
|
||||||
|
do_normalize=do_normalize,
|
||||||
|
image_mean=image_mean,
|
||||||
|
image_std=image_std,
|
||||||
|
do_resize=do_resize,
|
||||||
|
size=size,
|
||||||
|
resample=resample,
|
||||||
|
)
|
||||||
|
if do_convert_rgb:
|
||||||
|
images_list = [
|
||||||
|
[convert_to_rgb(image) for image in images] for images in images_list
|
||||||
|
]
|
||||||
|
|
||||||
|
# All transformations expect numpy arrays.
|
||||||
|
images_list = [
|
||||||
|
[to_numpy_array(image) for image in images] for images in images_list
|
||||||
|
]
|
||||||
|
|
||||||
|
if do_rescale and is_scaled_image(images_list[0][0]):
|
||||||
|
logger.warning_once(
|
||||||
|
"It looks like you are trying to rescale already rescaled images. If the input"
|
||||||
|
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
|
||||||
|
)
|
||||||
|
|
||||||
|
if input_data_format is None:
|
||||||
|
# We assume that all images have the same channel dimension format.
|
||||||
|
input_data_format = infer_channel_dimension_format(images_list[0][0])
|
||||||
|
|
||||||
|
if do_pan_and_scan:
|
||||||
|
images_list_and_num_crops = [
|
||||||
|
self._process_images_for_pas(
|
||||||
|
images=images,
|
||||||
|
do_pan_and_scan=do_pan_and_scan,
|
||||||
|
pan_and_scan_min_crop_size=pan_and_scan_min_crop_size,
|
||||||
|
pan_and_scan_max_num_crops=pan_and_scan_max_num_crops,
|
||||||
|
pan_and_scan_min_ratio_to_activate=pan_and_scan_min_ratio_to_activate,
|
||||||
|
data_format=data_format,
|
||||||
|
input_data_format=input_data_format,
|
||||||
|
)
|
||||||
|
for images in images_list
|
||||||
|
]
|
||||||
|
images_list = [images for images, _ in images_list_and_num_crops]
|
||||||
|
num_crops = [num_crops for _, num_crops in images_list_and_num_crops]
|
||||||
|
else:
|
||||||
|
num_crops = [[0] for images in images_list]
|
||||||
|
|
||||||
|
if do_resize:
|
||||||
|
height, width = size["height"], size["width"]
|
||||||
|
images_list = [
|
||||||
|
[
|
||||||
|
resize(
|
||||||
|
image=image,
|
||||||
|
size=(height, width),
|
||||||
|
resample=resample,
|
||||||
|
input_data_format=input_data_format,
|
||||||
|
)
|
||||||
|
for image in images
|
||||||
|
]
|
||||||
|
for images in images_list
|
||||||
|
]
|
||||||
|
|
||||||
|
if do_rescale:
|
||||||
|
images_list = [
|
||||||
|
[
|
||||||
|
self.rescale(
|
||||||
|
image=image,
|
||||||
|
scale=rescale_factor,
|
||||||
|
input_data_format=input_data_format,
|
||||||
|
)
|
||||||
|
for image in images
|
||||||
|
]
|
||||||
|
for images in images_list
|
||||||
|
]
|
||||||
|
|
||||||
|
if do_normalize:
|
||||||
|
images_list = [
|
||||||
|
[
|
||||||
|
self.normalize(
|
||||||
|
image=image,
|
||||||
|
mean=image_mean,
|
||||||
|
std=image_std,
|
||||||
|
input_data_format=input_data_format,
|
||||||
|
)
|
||||||
|
for image in images
|
||||||
|
]
|
||||||
|
for images in images_list
|
||||||
|
]
|
||||||
|
|
||||||
|
images = [
|
||||||
|
to_channel_dimension_format(
|
||||||
|
image, data_format, input_channel_dim=input_data_format
|
||||||
|
)
|
||||||
|
for images in images_list
|
||||||
|
for image in images
|
||||||
|
]
|
||||||
|
|
||||||
|
data = {"pixel_values": images, "num_crops": num_crops}
|
||||||
|
return BatchFeature(data=data, tensor_type=return_tensors)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["Gemma3ImageProcessor"]
|
@ -0,0 +1,206 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import re
|
||||||
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
|
from transformers.feature_extraction_utils import BatchFeature
|
||||||
|
from transformers.image_utils import ImageInput
|
||||||
|
from transformers.processing_utils import (
|
||||||
|
ImagesKwargs,
|
||||||
|
ProcessingKwargs,
|
||||||
|
ProcessorMixin,
|
||||||
|
Unpack,
|
||||||
|
)
|
||||||
|
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
|
||||||
|
from transformers.utils import to_py_obj
|
||||||
|
from text_generation_server.models.custom_modeling.gemma3.image_processing_gemma3 import (
|
||||||
|
Gemma3ImageProcessor,
|
||||||
|
)
|
||||||
|
|
||||||
|
from transformers.image_utils import PILImageResampling
|
||||||
|
|
||||||
|
from .utils import make_nested_list_of_images
|
||||||
|
|
||||||
|
|
||||||
|
class Gemma3ImagesKwargs(ImagesKwargs):
|
||||||
|
do_pan_and_scan: Optional[bool]
|
||||||
|
pan_and_scan_min_crop_size: Optional[int]
|
||||||
|
pan_and_scan_max_num_crops: Optional[int]
|
||||||
|
pan_and_scan_min_ratio_to_activate: Optional[float]
|
||||||
|
do_convert_rgb: Optional[bool]
|
||||||
|
|
||||||
|
|
||||||
|
class Gemma3ProcessorKwargs(ProcessingKwargs, total=False):
|
||||||
|
_defaults = {
|
||||||
|
"text_kwargs": {
|
||||||
|
"padding": False,
|
||||||
|
},
|
||||||
|
"images_kwargs": {
|
||||||
|
"do_pan_and_scan": False,
|
||||||
|
"pan_and_scan_min_crop_size": 256,
|
||||||
|
"pan_and_scan_max_num_crops": 4,
|
||||||
|
"pan_and_scan_min_ratio_to_activate": 1.2,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class Gemma3Processor(ProcessorMixin):
|
||||||
|
attributes = ["image_processor", "tokenizer"]
|
||||||
|
valid_kwargs = ["chat_template"]
|
||||||
|
# # image_processor_class = "Gemma3ImageProcessor"
|
||||||
|
image_processor_class = "AutoProcessor"
|
||||||
|
tokenizer_class = "AutoTokenizer"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
image_processor,
|
||||||
|
tokenizer,
|
||||||
|
chat_template=None,
|
||||||
|
num_mm_soft_tokens_per_image: int = 256,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
num_mm_soft_tokens_per_image = 256
|
||||||
|
chat_template = None
|
||||||
|
|
||||||
|
image_processor = Gemma3ImageProcessor(
|
||||||
|
image_mean=(127.5,) * 3,
|
||||||
|
image_std=(127.5,) * 3,
|
||||||
|
size={"height": 896, "width": 896},
|
||||||
|
do_rescale=False,
|
||||||
|
resample=PILImageResampling.BILINEAR,
|
||||||
|
)
|
||||||
|
# import ipdb; ipdb.set_trace()
|
||||||
|
self.image_token_id = tokenizer.image_token_id
|
||||||
|
image_tokens_expanded = "".join(
|
||||||
|
[tokenizer.image_token] * num_mm_soft_tokens_per_image
|
||||||
|
)
|
||||||
|
self.full_image_sequence = (
|
||||||
|
f"\n\n{tokenizer.boi_token}{image_tokens_expanded}{tokenizer.eoi_token}\n\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
# import ipdb; ipdb.set_trace()
|
||||||
|
|
||||||
|
self.image_processor = image_processor
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.chat_template = chat_template
|
||||||
|
|
||||||
|
# super().__init__(
|
||||||
|
# image_processor=image_processor,
|
||||||
|
# tokenizer=tokenizer,
|
||||||
|
# chat_template=chat_template,
|
||||||
|
# **kwargs,
|
||||||
|
# )
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
images: ImageInput = None,
|
||||||
|
text: Union[
|
||||||
|
TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]
|
||||||
|
] = None,
|
||||||
|
videos=None,
|
||||||
|
audio=None,
|
||||||
|
**kwargs: Unpack[Gemma3ProcessorKwargs],
|
||||||
|
) -> BatchFeature:
|
||||||
|
if text is None and images is None:
|
||||||
|
raise ValueError("Provide at least one of `text` or `images`.")
|
||||||
|
|
||||||
|
output_kwargs = self._merge_kwargs(
|
||||||
|
Gemma3ProcessorKwargs,
|
||||||
|
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(text, str):
|
||||||
|
text = [text]
|
||||||
|
elif not isinstance(text, list) and not isinstance(text[0], str):
|
||||||
|
raise ValueError(
|
||||||
|
"Invalid input text. Please provide a string, or a list of strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
image_inputs = {}
|
||||||
|
if images is not None:
|
||||||
|
batched_images = make_nested_list_of_images(images)
|
||||||
|
image_inputs = self.image_processor(
|
||||||
|
batched_images, **output_kwargs["images_kwargs"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create empty text to be replaced with placeholders
|
||||||
|
if not text:
|
||||||
|
text = [
|
||||||
|
" ".join(["<image>"] * len(images)) for images in batched_images
|
||||||
|
]
|
||||||
|
|
||||||
|
if len(batched_images) != len(text):
|
||||||
|
raise ValueError(
|
||||||
|
f"Received inconsistently sized batches of images ({len(batched_images)}) and text ({len(text)})."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Replace image tokens by the full expanded sequence
|
||||||
|
batch_num_crops = to_py_obj(image_inputs.pop("num_crops"))
|
||||||
|
for prompt, images, num_crops in zip(text, batched_images, batch_num_crops):
|
||||||
|
image_indexes = [m.start() for m in re.finditer("<image>", prompt)]
|
||||||
|
|
||||||
|
if len(images) != len(image_indexes):
|
||||||
|
raise ValueError(
|
||||||
|
f"Prompt contained {len(image_indexes)} image tokens but received {len(images)} images."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Insert additional image tokens for Pan-and-Scan crops
|
||||||
|
for num, idx in reversed(list(zip(num_crops, image_indexes))):
|
||||||
|
if num:
|
||||||
|
formatted_image_text = (
|
||||||
|
"Here is the original image <image> and here are some crops to help you see better "
|
||||||
|
+ " ".join(["<image>"] * num)
|
||||||
|
)
|
||||||
|
prompt = (
|
||||||
|
prompt[:idx]
|
||||||
|
+ formatted_image_text
|
||||||
|
+ prompt[idx + len("<image>") :]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Expand placeholder image tokens to the full image token sequence
|
||||||
|
text = [
|
||||||
|
prompt.replace("<image>", self.full_image_sequence) for prompt in text
|
||||||
|
]
|
||||||
|
|
||||||
|
text_input = self.tokenizer(text=text, **output_kwargs["text_kwargs"])
|
||||||
|
return BatchFeature(data={**text_input, **image_inputs})
|
||||||
|
|
||||||
|
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Gemma
|
||||||
|
def batch_decode(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
||||||
|
refer to the docstring of this method for more information.
|
||||||
|
"""
|
||||||
|
return self.tokenizer.batch_decode(*args, **kwargs)
|
||||||
|
|
||||||
|
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Gemma
|
||||||
|
def decode(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
|
||||||
|
the docstring of this method for more information.
|
||||||
|
"""
|
||||||
|
return self.tokenizer.decode(*args, **kwargs)
|
||||||
|
|
||||||
|
@property
|
||||||
|
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names with CLIP->PaliGemma
|
||||||
|
def model_input_names(self):
|
||||||
|
tokenizer_input_names = self.tokenizer.model_input_names
|
||||||
|
image_processor_input_names = self.image_processor.model_input_names
|
||||||
|
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["Gemma3Processor"]
|
@ -0,0 +1,61 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import List, Union
|
||||||
|
|
||||||
|
|
||||||
|
from transformers.image_utils import ImageInput, is_valid_image, is_pil_image
|
||||||
|
|
||||||
|
|
||||||
|
def is_valid_list_of_images(images: List):
|
||||||
|
return images and all(is_valid_image(image) for image in images)
|
||||||
|
|
||||||
|
|
||||||
|
def make_nested_list_of_images(
|
||||||
|
images: Union[List[ImageInput], ImageInput],
|
||||||
|
) -> ImageInput:
|
||||||
|
"""
|
||||||
|
Ensure that the output is a nested list of images.
|
||||||
|
Args:
|
||||||
|
images (`Union[List[ImageInput], ImageInput]`):
|
||||||
|
The input image.
|
||||||
|
Returns:
|
||||||
|
list: A list of list of images or a list of 4d array of images.
|
||||||
|
"""
|
||||||
|
# If it's a list of batches, it's already in the right format
|
||||||
|
if (
|
||||||
|
isinstance(images, (list, tuple))
|
||||||
|
and all(isinstance(images_i, (list, tuple)) for images_i in images)
|
||||||
|
and all(is_valid_list_of_images(images_i) for images_i in images)
|
||||||
|
):
|
||||||
|
return images
|
||||||
|
|
||||||
|
# If it's a list of images, it's a single batch, so convert it to a list of lists
|
||||||
|
if isinstance(images, (list, tuple)) and is_valid_list_of_images(images):
|
||||||
|
if is_pil_image(images[0]) or images[0].ndim == 3:
|
||||||
|
return [images]
|
||||||
|
if images[0].ndim == 4:
|
||||||
|
return [list(image) for image in images]
|
||||||
|
|
||||||
|
# If it's a single image, convert it to a list of lists
|
||||||
|
if is_valid_image(images):
|
||||||
|
if is_pil_image(images) or images.ndim == 3:
|
||||||
|
return [[images]]
|
||||||
|
if images.ndim == 4:
|
||||||
|
return [list(images)]
|
||||||
|
|
||||||
|
raise ValueError(
|
||||||
|
"Invalid input type. Must be a single image, a list of images, or a list of batches of images."
|
||||||
|
)
|
@ -23,6 +23,13 @@ def load_text_model(prefix, config, weights, name=None):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return FlashGemma2ForCausalLM(prefix, config, weights)
|
return FlashGemma2ForCausalLM(prefix, config, weights)
|
||||||
|
|
||||||
|
elif config.model_type == "gemma3" or config.model_type == "gemma3_text":
|
||||||
|
from text_generation_server.models.custom_modeling.flash_gemma3_modeling import (
|
||||||
|
FlashGemma3ForCausalLM,
|
||||||
|
)
|
||||||
|
|
||||||
|
return FlashGemma3ForCausalLM(prefix, config, weights)
|
||||||
elif config.model_type == "paligemma":
|
elif config.model_type == "paligemma":
|
||||||
from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
|
from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
|
||||||
FlashGemmaForCausalLM,
|
FlashGemmaForCausalLM,
|
||||||
@ -42,13 +49,21 @@ def load_vision_model(prefix, config, weights):
|
|||||||
return CLIPVisionTransformer(
|
return CLIPVisionTransformer(
|
||||||
prefix=f"{prefix}.vision_model", config=config, weights=weights
|
prefix=f"{prefix}.vision_model", config=config, weights=weights
|
||||||
)
|
)
|
||||||
if config.model_type == "siglip_vision_model":
|
if (
|
||||||
|
config.model_type == "siglip_vision_model"
|
||||||
|
or config.model_type == "gemma3_vision"
|
||||||
|
):
|
||||||
from text_generation_server.models.custom_modeling.siglip import (
|
from text_generation_server.models.custom_modeling.siglip import (
|
||||||
SiglipVisionTransformer,
|
SiglipVisionTransformer,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# TODO: ensure that using the prefix doesn't break any existing models
|
||||||
|
# that rely on the old prefix (update the old models if necessary)
|
||||||
return SiglipVisionTransformer(
|
return SiglipVisionTransformer(
|
||||||
prefix="vision_tower.vision_model", config=config, weights=weights
|
# prefix="vision_model.vision_model", config=config, weights=weights
|
||||||
|
prefix=f"{prefix}.vision_model",
|
||||||
|
config=config,
|
||||||
|
weights=weights,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(f"Unsupported model type {config.model_type}")
|
raise RuntimeError(f"Unsupported model type {config.model_type}")
|
||||||
|
@ -128,6 +128,12 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
|
|||||||
num_pads = grid_t * grid_h * grid_w // 4
|
num_pads = grid_t * grid_h * grid_w // 4
|
||||||
padding = "<|image_pad|>" * num_pads
|
padding = "<|image_pad|>" * num_pads
|
||||||
return f"<|vision_start|>{padding}<|vision_end|>"
|
return f"<|vision_start|>{padding}<|vision_end|>"
|
||||||
|
elif config.model_type == "gemma3":
|
||||||
|
# TODO: get correct number of features via reviewing the Gemma3 architecture
|
||||||
|
# and calculating the number of image tokens
|
||||||
|
num_pads = 256
|
||||||
|
padding = "<image_soft_token>" * num_pads
|
||||||
|
return f"\n\n<start_of_image>{padding}<end_of_image>\n\n"
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
|
raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
|
||||||
|
|
||||||
@ -244,6 +250,8 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
|
|||||||
|
|
||||||
if config.model_type == "llava_next":
|
if config.model_type == "llava_next":
|
||||||
images.append(image)
|
images.append(image)
|
||||||
|
elif config.model_type == "gemma3":
|
||||||
|
images.append(image)
|
||||||
else:
|
else:
|
||||||
images.append([image])
|
images.append([image])
|
||||||
else:
|
else:
|
||||||
|
Loading…
Reference in New Issue
Block a user