Add gemma3 model

This commit is contained in:
Mohit Sharma 2025-03-12 07:01:46 +00:00
parent ae4451c3da
commit 587e5dea22
21 changed files with 3145 additions and 9 deletions

View File

@ -14,6 +14,8 @@ Text Generation Inference enables serving optimized models. The following sectio
- [Gemma](https://huggingface.co/google/gemma-7b)
- [PaliGemma](https://huggingface.co/google/paligemma-3b-pt-224)
- [Gemma2](https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315)
- [Gemma3](https://huggingface.co/collections/google/gemma-3)
- [Gemma3 Text](https://huggingface.co/collections/google/gemma-3)
- [Cohere](https://huggingface.co/CohereForAI/c4ai-command-r-plus)
- [Dbrx](https://huggingface.co/databricks/dbrx-instruct)
- [Mamba](https://huggingface.co/state-spaces/mamba-2.8b-slimpj)

View File

@ -0,0 +1,133 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 20,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 236764,
"logprob": -0.44726562,
"special": false,
"text": ","
},
{
"id": 236743,
"logprob": -0.011413574,
"special": false,
"text": " "
},
{
"id": 236812,
"logprob": -0.09814453,
"special": false,
"text": "4"
},
{
"id": 236764,
"logprob": -0.044189453,
"special": false,
"text": ","
},
{
"id": 236743,
"logprob": -0.15625,
"special": false,
"text": " "
},
{
"id": 236810,
"logprob": -0.010864258,
"special": false,
"text": "5"
},
{
"id": 236764,
"logprob": -0.040039062,
"special": false,
"text": ","
},
{
"id": 236743,
"logprob": -0.26757812,
"special": false,
"text": " "
},
{
"id": 236825,
"logprob": -0.0047302246,
"special": false,
"text": "6"
},
{
"id": 236764,
"logprob": -0.026123047,
"special": false,
"text": ","
},
{
"id": 236743,
"logprob": -0.265625,
"special": false,
"text": " "
},
{
"id": 236832,
"logprob": -0.014160156,
"special": false,
"text": "7"
},
{
"id": 236764,
"logprob": -0.013977051,
"special": false,
"text": ","
},
{
"id": 236743,
"logprob": -0.103515625,
"special": false,
"text": " "
},
{
"id": 236828,
"logprob": -0.008178711,
"special": false,
"text": "8"
},
{
"id": 236764,
"logprob": -0.030151367,
"special": false,
"text": ","
},
{
"id": 236743,
"logprob": -0.39453125,
"special": false,
"text": " "
},
{
"id": 236819,
"logprob": -0.008728027,
"special": false,
"text": "9"
},
{
"id": 236764,
"logprob": -0.020629883,
"special": false,
"text": ","
},
{
"id": 236743,
"logprob": -0.08154297,
"special": false,
"text": " "
}
],
"top_tokens": null
},
"generated_text": ", 4, 5, 6, 7, 8, 9, "
}

View File

@ -0,0 +1,613 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 100,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 1331,
"logprob": -0.32421875,
"special": false,
"text": " people"
},
{
"id": 8390,
"logprob": -0.15332031,
"special": false,
"text": " died"
},
{
"id": 528,
"logprob": -1.140625,
"special": false,
"text": " in"
},
{
"id": 506,
"logprob": -0.42578125,
"special": false,
"text": " the"
},
{
"id": 3640,
"logprob": -0.64453125,
"special": false,
"text": " United"
},
{
"id": 4184,
"logprob": -0.0027770996,
"special": false,
"text": " States"
},
{
"id": 236761,
"logprob": -0.37890625,
"special": false,
"text": "."
},
{
"id": 108,
"logprob": -0.08300781,
"special": false,
"text": "\n\n"
},
{
"id": 818,
"logprob": -1.1796875,
"special": false,
"text": "The"
},
{
"id": 6816,
"logprob": -1.765625,
"special": false,
"text": " generally"
},
{
"id": 10951,
"logprob": -0.14550781,
"special": false,
"text": " accepted"
},
{
"id": 10967,
"logprob": -0.90625,
"special": false,
"text": " estimate"
},
{
"id": 563,
"logprob": -0.49414062,
"special": false,
"text": " is"
},
{
"id": 600,
"logprob": -0.65625,
"special": false,
"text": " that"
},
{
"id": 236743,
"logprob": -1.1796875,
"special": false,
"text": " "
},
{
"id": 236825,
"logprob": -0.0009918213,
"special": false,
"text": "6"
},
{
"id": 236832,
"logprob": -6.532669e-05,
"special": false,
"text": "7"
},
{
"id": 236810,
"logprob": -4.863739e-05,
"special": false,
"text": "5"
},
{
"id": 236764,
"logprob": -0.00017929077,
"special": false,
"text": ","
},
{
"id": 236771,
"logprob": -1.2397766e-05,
"special": false,
"text": "0"
},
{
"id": 236771,
"logprob": -2.1457672e-06,
"special": false,
"text": "0"
},
{
"id": 236771,
"logprob": 0.0,
"special": false,
"text": "0"
},
{
"id": 1331,
"logprob": -0.50390625,
"special": false,
"text": " people"
},
{
"id": 8390,
"logprob": -0.011474609,
"special": false,
"text": " died"
},
{
"id": 528,
"logprob": -0.08496094,
"special": false,
"text": " in"
},
{
"id": 506,
"logprob": -0.0003299713,
"special": false,
"text": " the"
},
{
"id": 3640,
"logprob": -0.028442383,
"special": false,
"text": " United"
},
{
"id": 4184,
"logprob": -0.00011014938,
"special": false,
"text": " States"
},
{
"id": 236761,
"logprob": -1.1796875,
"special": false,
"text": "."
},
{
"id": 3153,
"logprob": -0.104003906,
"special": false,
"text": " However"
},
{
"id": 236764,
"logprob": -0.009094238,
"special": false,
"text": ","
},
{
"id": 1070,
"logprob": -0.88671875,
"special": false,
"text": " some"
},
{
"id": 61806,
"logprob": -0.84765625,
"special": false,
"text": " historians"
},
{
"id": 4646,
"logprob": -1.34375,
"special": false,
"text": " believe"
},
{
"id": 506,
"logprob": -0.59375,
"special": false,
"text": " the"
},
{
"id": 5396,
"logprob": -0.8046875,
"special": false,
"text": " actual"
},
{
"id": 1548,
"logprob": -0.04321289,
"special": false,
"text": " number"
},
{
"id": 1451,
"logprob": -0.60546875,
"special": false,
"text": " could"
},
{
"id": 577,
"logprob": -0.091308594,
"special": false,
"text": " be"
},
{
"id": 618,
"logprob": -0.61328125,
"special": false,
"text": " as"
},
{
"id": 1494,
"logprob": -0.00033569336,
"special": false,
"text": " high"
},
{
"id": 618,
"logprob": -0.0001411438,
"special": false,
"text": " as"
},
{
"id": 236743,
"logprob": -0.001045227,
"special": false,
"text": " "
},
{
"id": 236770,
"logprob": -0.21289062,
"special": false,
"text": "1"
},
{
"id": 236771,
"logprob": -0.13378906,
"special": false,
"text": "0"
},
{
"id": 3625,
"logprob": -0.0087890625,
"special": false,
"text": " million"
},
{
"id": 236761,
"logprob": -0.2109375,
"special": false,
"text": "."
},
{
"id": 108,
"logprob": -0.39453125,
"special": false,
"text": "\n\n"
},
{
"id": 236777,
"logprob": -1.1328125,
"special": false,
"text": "I"
},
{
"id": 1006,
"logprob": -1.4140625,
"special": false,
"text": " am"
},
{
"id": 3182,
"logprob": -1.15625,
"special": false,
"text": " looking"
},
{
"id": 573,
"logprob": -0.035888672,
"special": false,
"text": " for"
},
{
"id": 919,
"logprob": -1.2734375,
"special": false,
"text": " more"
},
{
"id": 1938,
"logprob": -1.2265625,
"special": false,
"text": " information"
},
{
"id": 580,
"logprob": -0.7734375,
"special": false,
"text": " on"
},
{
"id": 672,
"logprob": -0.77734375,
"special": false,
"text": " this"
},
{
"id": 59725,
"logprob": -0.70703125,
"special": false,
"text": " discrepancy"
},
{
"id": 532,
"logprob": -0.8515625,
"special": false,
"text": " and"
},
{
"id": 506,
"logprob": -0.65625,
"special": false,
"text": " the"
},
{
"id": 5872,
"logprob": -1.15625,
"special": false,
"text": " factors"
},
{
"id": 600,
"logprob": -0.2265625,
"special": false,
"text": " that"
},
{
"id": 19263,
"logprob": -1.125,
"special": false,
"text": " contributed"
},
{
"id": 531,
"logprob": -0.001083374,
"special": false,
"text": " to"
},
{
"id": 506,
"logprob": -0.2109375,
"special": false,
"text": " the"
},
{
"id": 5777,
"logprob": -1.21875,
"special": false,
"text": " wide"
},
{
"id": 2644,
"logprob": -0.018310547,
"special": false,
"text": " range"
},
{
"id": 529,
"logprob": -0.12988281,
"special": false,
"text": " of"
},
{
"id": 14287,
"logprob": -0.03564453,
"special": false,
"text": " estimates"
},
{
"id": 236761,
"logprob": -0.010314941,
"special": false,
"text": "."
},
{
"id": 108,
"logprob": -0.060546875,
"special": false,
"text": "\n\n"
},
{
"id": 8291,
"logprob": -0.734375,
"special": false,
"text": "Here"
},
{
"id": 236789,
"logprob": -0.26367188,
"special": false,
"text": "'"
},
{
"id": 236751,
"logprob": -1.1920929e-06,
"special": false,
"text": "s"
},
{
"id": 496,
"logprob": -0.15527344,
"special": false,
"text": " a"
},
{
"id": 25890,
"logprob": -0.08886719,
"special": false,
"text": " breakdown"
},
{
"id": 529,
"logprob": -0.0020446777,
"special": false,
"text": " of"
},
{
"id": 506,
"logprob": -0.17871094,
"special": false,
"text": " the"
},
{
"id": 5872,
"logprob": -0.90234375,
"special": false,
"text": " factors"
},
{
"id": 20894,
"logprob": -0.25976562,
"special": false,
"text": " contributing"
},
{
"id": 531,
"logprob": -8.34465e-05,
"special": false,
"text": " to"
},
{
"id": 506,
"logprob": -0.008544922,
"special": false,
"text": " the"
},
{
"id": 5777,
"logprob": -0.62109375,
"special": false,
"text": " wide"
},
{
"id": 2644,
"logprob": -0.0023345947,
"special": false,
"text": " range"
},
{
"id": 529,
"logprob": -0.016723633,
"special": false,
"text": " of"
},
{
"id": 14287,
"logprob": -0.011291504,
"special": false,
"text": " estimates"
},
{
"id": 573,
"logprob": -0.29101562,
"special": false,
"text": " for"
},
{
"id": 506,
"logprob": -0.21484375,
"special": false,
"text": " the"
},
{
"id": 236743,
"logprob": -0.2890625,
"special": false,
"text": " "
},
{
"id": 236770,
"logprob": -3.5762787e-07,
"special": false,
"text": "1"
},
{
"id": 236819,
"logprob": 0.0,
"special": false,
"text": "9"
},
{
"id": 236770,
"logprob": 0.0,
"special": false,
"text": "1"
},
{
"id": 236828,
"logprob": 0.0,
"special": false,
"text": "8"
},
{
"id": 7745,
"logprob": -0.70703125,
"special": false,
"text": " flu"
},
{
"id": 10248,
"logprob": -0.01953125,
"special": false,
"text": " pandemic"
},
{
"id": 4355,
"logprob": -0.78515625,
"special": false,
"text": " death"
},
{
"id": 25363,
"logprob": -6.771088e-05,
"special": false,
"text": " toll"
},
{
"id": 528,
"logprob": -0.08496094,
"special": false,
"text": " in"
},
{
"id": 506,
"logprob": -7.033348e-06,
"special": false,
"text": " the"
},
{
"id": 3640,
"logprob": -0.0067443848,
"special": false,
"text": " United"
},
{
"id": 4184,
"logprob": 0.0,
"special": false,
"text": " States"
}
],
"top_tokens": null
},
"generated_text": " people died in the United States.\n\nThe generally accepted estimate is that 675,000 people died in the United States. However, some historians believe the actual number could be as high as 10 million.\n\nI am looking for more information on this discrepancy and the factors that contributed to the wide range of estimates.\n\nHere's a breakdown of the factors contributing to the wide range of estimates for the 1918 flu pandemic death toll in the United States"
}

View File

@ -0,0 +1,26 @@
{
"choices": [
{
"finish_reason": "stop",
"index": 0,
"logprobs": null,
"message": {
"content": "Here's a description of what's shown in the image:\n\nThe image depicts a brown cow standing on a sandy beach. The beach has turquoise water and a distant island visible in the background. The sky is bright blue with some white clouds. \n\nIt's a humorous and unexpected sight of a cow enjoying a tropical beach!",
"name": null,
"role": "assistant",
"tool_calls": null
},
"usage": null
}
],
"created": 1741703756,
"id": "",
"model": "gg-hf-g/gemma-3-4b-it",
"object": "chat.completion",
"system_fingerprint": "3.1.2-dev0-native",
"usage": {
"completion_tokens": 70,
"prompt_tokens": 277,
"total_tokens": 347
}
}

View File

@ -0,0 +1,26 @@
{
"choices": [
{
"finish_reason": "stop",
"index": 0,
"logprobs": null,
"message": {
"content": "Based on the image, the animal is a cow, not a dog! \n\nIt appears to be a **Brazilian cattle breed** known as a **Gir Cow**. They are recognized for their reddish-brown color and distinctive markings.",
"name": null,
"role": "assistant",
"tool_calls": null
},
"usage": null
}
],
"created": 1741703753,
"id": "",
"model": "gg-hf-g/gemma-3-4b-it",
"object": "chat.completion",
"system_fingerprint": "3.1.2-dev0-native",
"usage": {
"completion_tokens": 48,
"prompt_tokens": 281,
"total_tokens": 329
}
}

View File

@ -0,0 +1,90 @@
import pytest
@pytest.fixture(scope="module")
def flash_gemma3_handle(launcher):
with launcher("gg-hf-g/gemma-3-4b-it", num_shard=2) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_gemma3(flash_gemma3_handle):
await flash_gemma3_handle.health(300)
return flash_gemma3_handle.client
async def test_flash_gemma3(flash_gemma3, response_snapshot):
response = await flash_gemma3.generate(
"Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many",
seed=42,
max_new_tokens=100,
)
assert (
response.generated_text
== " people died in the United States.\n\nThe generally accepted estimate is that 675,000 people died in the United States. However, some historians believe the actual number could be as high as 10 million.\n\nI am looking for more information on this discrepancy and the factors that contributed to the wide range of estimates.\n\nHere's a breakdown of the factors contributing to the wide range of estimates for the 1918 flu pandemic death toll in the United States"
)
assert response.details.generated_tokens == 100
assert response == response_snapshot
async def test_flash_gemma3_image_cow_dog(flash_gemma3, response_snapshot):
image_url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png"
response = await flash_gemma3.chat(
seed=42,
messages=[
{
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": image_url}},
{
"type": "text",
"text": "What is the breed of the dog in the image?",
},
],
},
],
max_tokens=100,
)
assert (
response.choices[0].message.content
== "Based on the image, the animal is a cow, not a dog! \n\nIt appears to be a **Brazilian cattle breed** known as a **Gir Cow**. They are recognized for their reddish-brown color and distinctive markings."
)
assert response.usage["completion_tokens"] == 48
assert response == response_snapshot
async def test_flash_gemma3_image_cow(flash_gemma3, response_snapshot):
image_url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png"
response = await flash_gemma3.chat(
seed=42,
messages=[
{
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": image_url}},
{"type": "text", "text": "What is shown in this image?"},
],
},
],
max_tokens=100,
)
assert (
response.choices[0].message.content
== "Here's a description of what's shown in the image:\n\nThe image depicts a brown cow standing on a sandy beach. The beach has turquoise water and a distant island visible in the background. The sky is bright blue with some white clouds. \n\nIt's a humorous and unexpected sight of a cow enjoying a tropical beach!"
)
assert response.usage["completion_tokens"] == 70
assert response == response_snapshot
async def test_exceed_window(flash_gemma3, response_snapshot):
response = await flash_gemma3.generate(
"This is a nice place. " * 800 + "Now count: 1, 2, 3",
seed=42,
max_new_tokens=20,
)
assert response.generated_text == ", 4, 5, 6, 7, 8, 9, "
assert response.details.generated_tokens == 20
assert response == response_snapshot

View File

@ -2064,6 +2064,7 @@ fn main() -> Result<(), LauncherError> {
let default_optimal = match config {
Some(ref config) => match config.model_type.as_deref() {
Some("qwen2_vl") | Some("qwen2_5_vl") => 10_000,
Some("gemma3") => 8000,
_ => 4096,
},
None => 4096,

View File

@ -216,6 +216,19 @@ impl Qwen2_5Vl {
}
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct Gemma3VisionConfig {
pub(crate) image_size: usize,
pub(crate) patch_size: usize,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct Gemma3 {
vision_config: Gemma3VisionConfig,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "model_type")]
#[serde(rename_all = "snake_case")]
@ -249,6 +262,7 @@ pub enum Config {
Paligemma(Paligemma),
Gemma,
Gemma2,
Gemma3(Gemma3),
Cohere,
Drbx,
Falcon,

View File

@ -33,7 +33,16 @@ impl ChatTemplate {
let mut env = Box::new(Environment::new());
// enable things like .strip() or .capitalize()
env.set_unknown_method_callback(pycompat::unknown_method_callback);
let template_str = template.into_boxed_str();
// TODO: replace with better solution
// hack to adjust gemma3 template for debug
// replace 'messages[0]['content'][0]['text']' with 'messages[0]['content']'
let mutated_template = template.replace(
"messages[0]['content'][0]['text']",
"messages[0]['content']",
);
let template_str = mutated_template.into_boxed_str();
env.add_function("raise_exception", raise_exception);
env.add_function("strftime_now", strftime_now);
tracing::debug!("Loading template: {}", template_str);
@ -123,8 +132,8 @@ mod tests {
use crate::infer::chat_template::{raise_exception, strftime_now};
use crate::infer::ChatTemplate;
use crate::{
ChatTemplateInputs, Message, MessageBody, MessageContent, TextMessage,
TokenizerConfigToken, Tool,
ChatTemplateInputs, Message, MessageBody, MessageChunk, MessageContent, TextMessage,
TokenizerConfigToken, Tool, Url,
};
use chrono::Local;
use minijinja::Environment;
@ -1230,4 +1239,98 @@ TOOL CALL ID: 0
let expected = "<s><|start_header_id|>system<|end_header_id|>\n\nEnvironment: ipython\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYoure a helpful assistant! Answer the users question best you can.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGiven the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n\nRespond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.Do not use variables.\n\n{\n \"function\": {\n \"arguments\": \"{\\\"type\\\":\\\"object\\\",\\\"properties\\\":{\\\"location\\\":{\\\"type\\\":\\\"string\\\",\\\"description\\\":\\\"The city and state, e.g. San Francisco, CA\\\"},\\\"format\\\":{\\\"type\\\":\\\"string\\\",\\\"enum\\\":[\\\"celsius\\\",\\\"fahrenheit\\\"],\\\"description\\\":\\\"The temperature unit to use. Infer this from the users location.\\\"}},\\\"required\\\":[\\\"location\\\",\\\"format\\\"]}\",\n \"description\": \"Get the current weather\",\n \"name\": \"get_current_weather\"\n },\n \"type\": \"function\"\n}\n\nWhat is the weather like in Brooklyn, New York?\n---\nThis default prompt will be used<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n".to_string();
assert_eq!(result.unwrap(), expected);
}
#[test]
fn test_chat_template_with_special_system_prompt() {
// chat template from gemma3
let ct = ChatTemplate::new(
r#"{{ bos_token }}
{%- if messages[0]['role'] == 'system' -%}
{%- set first_user_prefix = messages[0]['content'][0]['text'] + '
' -%}
{%- set loop_messages = messages[1:] -%}
{%- else -%}
{%- set first_user_prefix = "" -%}
{%- set loop_messages = messages -%}
{%- endif -%}
{%- for message in loop_messages -%}
{%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}
{{ raise_exception("Conversation roles must alternate user/assistant/user/assistant/...") }}
{%- endif -%}
{%- if (message['role'] == 'assistant') -%}
{%- set role = "model" -%}
{%- else -%}
{%- set role = message['role'] -%}
{%- endif -%}
{{ '<start_of_turn>' + role + '
' + (first_user_prefix if loop.first else "") }}
{%- if message['content'] is string -%}
{{ message['content'] | trim }}
{%- elif message['content'] is iterable -%}
{%- for item in message['content'] -%}
{%- if item['type'] == 'image' -%}
{{ '<start_of_image>' }}
{%- elif item['type'] == 'text' -%}
{{ item['text'] | trim }}
{%- endif -%}
{%- endfor -%}
{%- else -%}
{{ raise_exception("Invalid content type") }}
{%- endif -%}
{{ '<end_of_turn>
' }}
{%- endfor -%}
{%- if add_generation_prompt -%}
{{'<start_of_turn>model
'}}
{%- endif -%}
"#
.to_string(),
Some(TokenizerConfigToken::String("<bos>".to_string())),
Some(TokenizerConfigToken::String("</eos>".to_string())),
);
let msgs: Vec<Message> = vec![
Message {
name: None,
role: "system".to_string(),
body: MessageBody::Content {
content: MessageContent::MultipleChunks(vec![MessageChunk::Text {
text: "You are a helpful assistant.".to_string(),
}]),
},
},
Message {
name: None,
role: "user".to_string(),
body: MessageBody::Content {
content: MessageContent::MultipleChunks(vec![
MessageChunk::Text {
text: "I'm already using this supplement ".to_string(),
},
MessageChunk::ImageUrl {
image_url: Url {
url: "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/IMG_3018.JPG".to_string()
},
},
MessageChunk::Text {
text: "and I want to use this one too ".to_string()
},
MessageChunk::ImageUrl {
image_url: Url {
url: "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/IMG_3015.jpg".to_string()
},
},
MessageChunk::Text {
text: " what are cautions?".to_string()
},
]),
},
},
];
let result = ct.apply(msgs, None);
let expected = "<bos><start_of_turn>user\nYou are a helpful assistant.\n\nI'm already using this supplement ![](https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/IMG_3018.JPG)and I want to use this one too ![](https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/IMG_3015.jpg) what are cautions?<end_of_turn>\n<start_of_turn>model\n".to_string();
assert_eq!(result.unwrap(), expected);
}
}

View File

@ -150,6 +150,11 @@ impl HubTokenizerConfig {
}
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ChatTemplateStandalone {
pub chat_template: ChatTemplateVersions,
}
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
#[serde(untagged)]
pub enum TokenizerConfigToken {
@ -171,6 +176,7 @@ impl TokenizerConfigToken {
pub enum HubPreprocessorConfig {
Idefics2Processor(Idefics2Preprocessor),
Idefics3Processor(Idefics2Preprocessor),
Gemma3Processor(Gemma3Processor),
}
impl HubPreprocessorConfig {
@ -186,6 +192,12 @@ pub struct Idefics2Preprocessor {
do_image_splitting: bool,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Gemma3Processor {
#[serde(default)]
do_image_splitting: bool,
}
#[derive(Debug, Clone, Deserialize, Default)]
pub struct HubProcessorConfig {
pub chat_template: Option<ChatTemplateVersions>,

View File

@ -1781,6 +1781,7 @@ pub async fn run(
tokenizer_config_filename,
preprocessor_config_filename,
processor_config_filename,
chat_template_filename,
model_info,
) = match api {
Type::None => (
@ -1788,6 +1789,7 @@ pub async fn run(
Some(local_path.join("tokenizer_config.json")),
Some(local_path.join("preprocessor_config.json")),
Some(local_path.join("processor_config.json")),
Some(local_path.join("chat_template.json")),
None,
),
Type::Api(api) => {
@ -1801,6 +1803,7 @@ pub async fn run(
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok();
let processor_config_filename = api_repo.get("processor_config.json").await.ok();
let chat_template_filename = api_repo.get("chat_template.json").await.ok();
let model_info = if let Some(model_info) = get_hub_model_info(&api_repo).await {
Some(model_info)
@ -1813,6 +1816,7 @@ pub async fn run(
tokenizer_config_filename,
preprocessor_config_filename,
processor_config_filename,
chat_template_filename,
model_info,
)
}
@ -1827,11 +1831,23 @@ pub async fn run(
repo.get("tokenizer_config.json"),
repo.get("preprocessor_config.json"),
repo.get("processor_config.json"),
repo.get("chat_template.json"),
None,
)
}
};
// if chat_template_filename is present, load the chat template
let chat_template: Option<crate::ChatTemplateVersions> = chat_template_filename
.and_then(|f| std::fs::read_to_string(f).ok())
.and_then(|c| {
let res = serde_json::from_str::<crate::ChatTemplateStandalone>(&c);
if let Err(e) = &res {
tracing::warn!("Could not parse chat template {e:?}");
}
res.ok().map(|t| t.chat_template)
});
// Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
{
@ -1839,11 +1855,16 @@ pub async fn run(
} else {
tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
};
let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
let mut tokenizer_config = tokenizer_config.unwrap_or_else(|| {
tracing::warn!("Could not find tokenizer config locally and no API specified");
HubTokenizerConfig::default()
});
if chat_template.is_some() {
tracing::info!("Using chat template from chat_template.json");
tokenizer_config.chat_template = chat_template;
}
let tokenizer: Result<Tokenizer, WebServerError> = {
use pyo3::prelude::*;
Python::with_gil(|py| -> PyResult<()> {

View File

@ -18,6 +18,7 @@ use std::sync::Arc;
use thiserror::Error;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tracing::warn;
use tracing::{instrument, Span};
use {once_cell::sync::Lazy, regex::Regex};
@ -694,6 +695,14 @@ fn image_tokens(
"<|vision_start|>{:?}<|vision_end|>",
"<|image_pad|>".repeat(config.get_number_of_features(height, width))
),
Gemma3(_config) => {
// TODO: prefer using the config to determine the number of features
let num_mm_soft_tokens_per_image = 256;
format!(
"\n\n<start_of_image>{:?}<end_of_image>\n\n",
"<image_soft_token>".repeat(num_mm_soft_tokens_per_image)
)
}
_ => unimplemented!("Images tokens are not supported for this model configuration"),
}
}
@ -721,8 +730,8 @@ fn prepare_input<T: TokenizerTrait>(
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
let (tokenizer_query, input_chunks) = match config {
Some(
config @ (Idefics | Mllama | Idefics2(_) | Idefics3(_) | Paligemma(_) | LlavaNext(_)
| Qwen2Vl(_) | Qwen2_5Vl(_)),
config @ (Idefics | Mllama | Idefics2(_) | Idefics3(_) | Gemma3(_) | Paligemma(_)
| LlavaNext(_) | Qwen2Vl(_) | Qwen2_5Vl(_)),
) => {
let mut input_chunks = Vec::new();
let mut tokenizer_query = String::with_capacity(inputs.len());

View File

@ -106,6 +106,17 @@ try:
from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
FlashGemma2ForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_gemma3_modeling import (
FlashGemma3ForCausalLM,
Gemma3ForConditionalGeneration,
)
from text_generation_server.models.custom_modeling.gemma3.processing_gemma3 import (
Gemma3Processor,
)
from text_generation_server.models.custom_modeling.gemma3.configuration_gemma3 import (
Gemma3Config,
Gemma3TextConfig,
)
from text_generation_server.models.custom_modeling.flash_dbrx_modeling import (
FlashDbrxForCausalLM,
DbrxConfig,
@ -258,6 +269,16 @@ class ModelType(enum.Enum):
"name": "Gemma2",
"url": "https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315",
}
GEMMA3 = {
"type": "gemma3",
"name": "Gemma3",
"url": "https://huggingface.co/collections/google/gemma-3",
}
GEMMA3_TEXT = {
"type": "gemma3_text",
"name": "Gemma3 Text",
"url": "https://huggingface.co/collections/google/gemma-3",
}
COHERE = {
"type": "cohere",
"name": "Cohere",
@ -1094,6 +1115,83 @@ def get_model(
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif model_type == GEMMA3_TEXT:
if FLASH_ATTENTION:
return FlashCausalLM(
model_id=model_id,
model_class=FlashGemma3ForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
# TODO: once implemented in transformers, use the config class
# and processor class from there.
config_class=Gemma3TextConfig,
# Works better for these models
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif FLASH_TRANSFORMERS_BACKEND:
return TransformersFlashCausalLM.fallback(
model_id,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma3"))
else:
return CausalLM.fallback(
model_id,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif model_type == GEMMA3:
if FLASH_ATTENTION:
# TODO: Use VlmCausalLM when image support is added.
return VlmCausalLM(
model_id=model_id,
model_class=Gemma3ForConditionalGeneration,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
# TODO: once implemented in transformers, use the config class
# and processor class from there.
config_class=Gemma3Config,
processor_class=Gemma3Processor,
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif FLASH_TRANSFORMERS_BACKEND:
return TransformersFlashCausalLM.fallback(
model_id,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma3"))
else:
return CausalLM.fallback(
model_id,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
if model_type == COHERE:
if FLASH_ATTENTION:

View File

@ -0,0 +1,922 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.distributed
from torch import nn
from typing import Optional, List, Tuple
import copy
from text_generation_server.layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
get_linear,
#
SpeculativeHead,
TensorParallelMultiAdapterLinear,
TensorParallelAdapterRowLinear,
)
import torch
import torch.nn.functional as F
from text_generation_server.models.custom_modeling.vlm import (
load_text_model,
load_vision_model,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
FastRMSNorm,
)
from text_generation_server.utils.weights import UnquantizedWeight
from transformers.activations import ACT2FN
from text_generation_server.layers.attention import (
paged_attention,
attention,
Seqlen,
)
ATTENTION_TYPE_GLOBAL = "global"
ATTENTION_TYPE_LOCAL = "local_sliding"
class Gemma3FastRMSNorm(FastRMSNorm):
@classmethod
def load(cls, prefix: str, weights, eps=1e-6):
dtype = weights.dtype
weights.dtype = torch.float32
weight = weights.get_tensor(f"{prefix}.weight") + 1
weights.dtype = dtype
new = cls(weight, eps)
new.dtype = dtype
return new
# perform the multiplication in full precision and downcast after
def forward(self, hidden_states, residual=None):
if residual is not None:
hidden_states += residual
residual = hidden_states
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
hidden_states = hidden_states * self.weight
return hidden_states.to(self.dtype), residual
def load_attention(config, prefix: str, weights):
if config.num_attention_heads != config.num_key_value_heads:
return _load_gqa(config, prefix, weights)
else:
return TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
weights=weights,
bias=False,
)
def _load_gqa(config, prefix: str, weights):
assert config.num_attention_heads % weights.process_group.size() == 0
weight = weights.get_multi_weights_col(
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
)
if isinstance(weight, UnquantizedWeight):
weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device)
head_size = config.head_dim
num_heads = config.num_attention_heads // weights.process_group.size()
num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
assert list(weight.weight.shape) == [
(num_heads + 2 * num_key_value_heads) * head_size,
config.hidden_size,
], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
return TensorParallelColumnLinear(get_linear(weight, bias=None))
class FlashGemma3Attention(torch.nn.Module):
def __init__(
self, prefix: str, config, weights, layer_id, causal: bool, is_sliding: bool
):
super().__init__()
self.num_heads = config.num_attention_heads
self.head_size = config.head_dim
self.causal = causal
if is_sliding:
self.window_size = config.sliding_window
# TODO: remove this hack to support local sliding window
config = copy.deepcopy(config)
config.rope_scaling = dict(rope_type="default")
self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=config.head_dim,
base=config.rope_local_base_freq,
device=weights.device,
)
else:
self.window_size = -1
self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=config.head_dim,
base=config.rope_theta,
device=weights.device,
)
self.softmax_scale = (
config.query_pre_attn_scalar**-0.5
if config.query_pre_attn_scalar is not None
else None
)
if self.num_heads % weights.process_group.size() != 0:
raise ValueError(
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
f"and `num_shards`: {weights.process_group.size()}"
)
self.num_heads = self.num_heads // weights.process_group.size()
self.num_key_value_heads = (
config.num_key_value_heads // weights.process_group.size()
)
self.softcap = None # config.attn_logit_softcapping
query_key_value = load_attention(config, prefix, weights)
self.query_key_value = TensorParallelMultiAdapterLinear.load(
query_key_value,
layer_id,
["q_proj", "k_proj", "v_proj"],
sizes=[
self.head_size * config.num_attention_heads,
self.head_size * config.num_key_value_heads,
self.head_size * config.num_key_value_heads,
],
process_group=weights.process_group,
)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
o_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.o_proj",
weights=weights,
bias=False,
)
self.o_proj = TensorParallelAdapterRowLinear.load(
o_proj,
layer_id,
"o_proj",
process_group=weights.process_group,
)
self.num_groups = self.num_heads // self.num_key_value_heads
self.kv_head_mapping = torch.arange(
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
).repeat_interleave(self.num_groups)
self.q_norm = Gemma3FastRMSNorm.load(
prefix=f"{prefix}.q_norm", weights=weights, eps=config.rms_norm_eps
)
self.k_norm = Gemma3FastRMSNorm.load(
prefix=f"{prefix}.k_norm", weights=weights, eps=config.rms_norm_eps
)
self.enable_gqa = self.num_heads != self.num_key_value_heads
def forward(
self,
hidden_states,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
max_s,
adapter_data,
attention_mask,
):
qkv = self.query_key_value(hidden_states, adapter_data)
query, kv = qkv.split(
[
self.head_size * self.num_heads,
2 * self.head_size * self.num_key_value_heads,
],
dim=1,
)
kv = kv.view(-1, 2, self.num_key_value_heads * self.head_size)
key = kv[:, 0]
value = kv[:, 1]
query = query.reshape(-1, self.head_size)
key = key.reshape(-1, self.head_size)
query, _ = self.q_norm(query.contiguous())
key, _ = self.k_norm(key.contiguous())
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_key_value_heads, self.head_size)
value = value.view(-1, self.num_key_value_heads, self.head_size)
self.rotary_emb(query, key, cos, sin)
kv_cache.store(
key=key,
value=value,
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill
if cu_seqlen_prefill is not None:
if attention_mask is None:
# flash attention
attn_output = attention(
query=query,
key=key,
value=value,
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
window_size_left=self.window_size,
softcap=self.softcap,
)
else:
lengths = cu_seqlen_prefill[1:] - cu_seqlen_prefill[:-1]
# Split tensors using vectorized split
query_list = torch.split(query, lengths.tolist(), dim=0)
key_list = torch.split(key, lengths.tolist(), dim=0)
value_list = torch.split(value, lengths.tolist(), dim=0)
padded_query = torch.nn.utils.rnn.pad_sequence(
query_list, batch_first=True
)
padded_key = torch.nn.utils.rnn.pad_sequence(key_list, batch_first=True)
padded_value = torch.nn.utils.rnn.pad_sequence(
value_list, batch_first=True
)
padded_query = padded_query.transpose(1, 2).contiguous()
padded_key = padded_key.transpose(1, 2).contiguous()
padded_value = padded_value.transpose(1, 2).contiguous()
zeros_to_add = torch.zeros(
padded_key.size(0),
self.num_key_value_heads,
1,
self.head_size,
dtype=padded_key.dtype,
device=padded_key.device,
)
key_states = torch.cat([padded_key, zeros_to_add], dim=2)
value_states = torch.cat([padded_value, zeros_to_add], dim=2)
# Compute attention
attn_output = F.scaled_dot_product_attention(
padded_query,
key_states,
value_states,
attn_mask=attention_mask,
scale=self.softmax_scale,
enable_gqa=self.enable_gqa,
)
attn_output = attn_output.transpose(
1, 2
) # [batch_size, seq_len, num_heads, head_dim]
max_seq_len = padded_query.size(2)
seq_range = torch.arange(
max_seq_len, device=padded_query.device
).unsqueeze(0)
lengths_tensor = torch.tensor(
lengths, device=padded_query.device
).unsqueeze(1)
mask = seq_range < lengths_tensor # [batch, max_seq_len]
attn_output = attn_output[mask] # [total_seq_len, num_heads, head_dim]
# Decode
else:
attn_output = paged_attention(
query,
kv_cache,
self.kv_head_mapping,
self.softmax_scale,
block_tables,
seqlen,
max_s,
softcap=self.softcap,
kv_scales=self.kv_scales,
)
return self.o_proj(
attn_output.view(-1, self.num_heads * self.head_size), adapter_data
)
class Gemma3MLP(nn.Module):
def __init__(self, prefix, config, weights, layer_id):
super().__init__()
act = config.hidden_activation
self.act = (
ACT2FN[act]
if "gelu" not in act
else lambda x: torch.nn.functional.gelu(
x,
approximate=(
"tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
),
)
)
# Fuse gate and up proj
gate_up_proj = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
weights=weights,
dim=0,
bias=False,
)
self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
gate_up_proj,
layer_id,
["gate_proj", "up_proj"],
sizes=[
config.intermediate_size,
config.intermediate_size,
],
process_group=weights.process_group,
)
down_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.down_proj",
weights=weights,
bias=False,
)
self.down_proj = TensorParallelAdapterRowLinear.load(
down_proj,
layer_id,
"down_proj",
process_group=weights.process_group,
)
self.intermediate_size = (
config.intermediate_size // weights.process_group.size()
)
def forward(self, hidden_states, adapter_data):
gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
return self.down_proj(
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
)
class FlashGemma3Layer(nn.Module):
def __init__(
self, prefix: str, config, weights, layer_id, causal: bool, is_sliding: bool
):
super().__init__()
self.self_attn = FlashGemma3Attention(
prefix=f"{prefix}.self_attn",
config=config,
weights=weights,
layer_id=layer_id,
causal=causal,
is_sliding=is_sliding,
)
self.mlp = Gemma3MLP(
prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id
)
self.input_layernorm = Gemma3FastRMSNorm.load(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
)
self.post_attention_layernorm = Gemma3FastRMSNorm.load(
prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
eps=config.rms_norm_eps,
)
self.pre_feedforward_layernorm = Gemma3FastRMSNorm.load(
prefix=f"{prefix}.pre_feedforward_layernorm",
weights=weights,
eps=config.rms_norm_eps,
)
self.post_feedforward_layernorm = Gemma3FastRMSNorm.load(
prefix=f"{prefix}.post_feedforward_layernorm",
weights=weights,
eps=config.rms_norm_eps,
)
def forward(
self,
hidden_states,
residual,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
max_s,
adapter_data,
attention_mask,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
# Self Attention
attn_output = self.self_attn(
normed_hidden_states,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
max_s,
adapter_data,
attention_mask,
)
# faster post attention rms norm
normed_attn_res_output, _ = self.post_attention_layernorm(attn_output)
normed_attn_res_output = normed_attn_res_output + res
res = normed_attn_res_output
pre_normed, _ = self.pre_feedforward_layernorm(normed_attn_res_output)
mlp_output = self.mlp(pre_normed, adapter_data)
post_hidden_states, _ = self.post_feedforward_layernorm(mlp_output)
return post_hidden_states, normed_attn_res_output
class FlashGemma3Model(torch.nn.Module):
def __init__(self, prefix: str, config, weights, causal: bool):
super().__init__()
process_group = weights.process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
self.layers = nn.ModuleList(
[
FlashGemma3Layer(
prefix=f"{prefix}.layers.{layer_id}",
config=config,
weights=weights,
layer_id=layer_id,
causal=causal,
is_sliding=bool((layer_id + 1) % config.sliding_window_pattern),
)
for layer_id in range(config.num_hidden_layers)
]
)
self.norm = Gemma3FastRMSNorm.load(
prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
)
self.head_size = self.layers[0].self_attn.head_size
self.num_heads = self.layers[0].self_attn.num_heads
self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
def forward(
self,
inputs_embeds: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
adapter_data: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
hidden_states = inputs_embeds
# Get rotary cos and sin for this forward
# Avoid to index in each layer
residual = None
for i, layer in enumerate(self.layers):
cos, sin = self.layers[i].self_attn.rotary_emb.get_cos_sin(
position_ids, max_s, hidden_states.dtype
)
# apply sliding window mask if needed
if layer.self_attn.window_size > 0 and attention_mask is not None:
min_dtype = torch.finfo(hidden_states.dtype).min
# prefill may be larger than sliding window
effective_seq_len = max(
position_ids.shape[0], self.layers[i].self_attn.window_size
)
sliding_window_mask = torch.tril(
torch.ones_like(attention_mask, dtype=torch.bool),
diagonal=-self.layers[i].self_attn.window_size,
)
attention_mask = torch.where(
sliding_window_mask, min_dtype, attention_mask
)
offset = max(0, position_ids.shape[0] - effective_seq_len)
attention_mask = attention_mask[
:, :, offset : offset + effective_seq_len
]
hidden_states, residual = layer(
hidden_states,
residual,
cos,
sin,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
seqlen,
max_s,
adapter_data,
attention_mask,
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class FlashGemma3ForCausalLM(torch.nn.Module):
def __init__(self, prefix: str, config, weights, *, causal: bool = True):
super().__init__()
embed_norm = config.hidden_size**0.5
if not prefix:
prefix = "model"
else:
prefix = f"{prefix}.model"
self.embed_tokens = TensorParallelEmbedding(
prefix=f"{prefix}.embed_tokens", weights=weights
)
self.embed_tokens.weight *= embed_norm
self.model = FlashGemma3Model(
prefix=prefix, config=config, weights=weights, causal=causal
)
self.lm_head = SpeculativeHead.load(
prefix=(
f"{prefix}.embed_tokens"
if config.tie_word_embeddings
else f"{prefix}.lm_head"
),
config=config,
weights=weights,
)
# self.softcap = config.attn_logit_softcapping
# assert isinstance(self.softcap, float)
self.softcap = None
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
input_embeds = self.embed_tokens(input_ids)
hidden_states = self.model(
input_embeds,
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
max_s,
adapter_data,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.lm_head(hidden_states)
return logits, speculative_logits
class Gemma3MultimodalInputProjection(torch.nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.mm_input_projection_weight = weights.get_tensor(
"multi_modal_projector.mm_input_projection_weight"
)
self.mm_soft_emb_norm = Gemma3FastRMSNorm.load(
prefix=f"{prefix}.mm_soft_emb_norm",
weights=weights,
eps=config.vision_config.layer_norm_eps,
)
self.patches_per_image = int(
config.vision_config.image_size // config.vision_config.patch_size
)
self.tokens_per_side = int(config.mm_tokens_per_image**0.5)
self.kernel_size = self.patches_per_image // self.tokens_per_side
self.avg_pool = nn.AvgPool2d(
kernel_size=self.kernel_size, stride=self.kernel_size
)
def forward(self, vision_outputs: torch.Tensor):
batch_size, _, seq_length = vision_outputs.shape
reshaped_vision_outputs = vision_outputs.transpose(1, 2)
reshaped_vision_outputs = reshaped_vision_outputs.reshape(
batch_size, seq_length, self.patches_per_image, self.patches_per_image
)
reshaped_vision_outputs = reshaped_vision_outputs.contiguous()
pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs)
pooled_vision_outputs = pooled_vision_outputs.flatten(2)
pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2)
normed_vision_outputs, _ = self.mm_soft_emb_norm(pooled_vision_outputs)
projected_vision_outputs = torch.matmul(
normed_vision_outputs, self.mm_input_projection_weight
)
return projected_vision_outputs.type_as(vision_outputs)
class Gemma3ForConditionalGeneration(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.config = config
if config.vision_config is not None:
config.vision_config.quantize = config.quantize
self.post_vision_model_layernorm = nn.LayerNorm.load(
prefix="vision_tower.vision_model.post_layernorm",
weights=weights,
eps=config.vision_config.layer_norm_eps,
)
self.multimodal_projector = Gemma3MultimodalInputProjection(
prefix="multi_modal_projector",
config=config,
weights=weights,
)
text_config = config.text_config
text_config.speculator = config.speculator
text_config.quantize = config.quantize
self.vision_model = load_vision_model(
prefix="vision_tower" if not prefix else f"{prefix}.vision_tower",
config=config.vision_config,
weights=weights,
)
self.text_model = load_text_model(
prefix="language_model" if not prefix else f"{prefix}.language_model",
config=config.text_config,
weights=weights,
)
else:
config.text_config.quantize = config.quantize
config.text_config.speculator = config.speculator
self.text_model = load_text_model(
prefix=prefix,
config=config.text_config,
weights=weights,
)
self.pad_token_id = (
config.pad_token_id if config.pad_token_id is not None else -1
)
def get_image_token_mask(self, input_ids):
device = input_ids.device
start_token_id = self.config.boi_token_index
K = self.config.mm_tokens_per_image
mask = torch.zeros_like(input_ids, dtype=torch.bool, device=device)
start_positions = (input_ids == start_token_id).nonzero(as_tuple=True)[0]
mask_indices = start_positions.unsqueeze(1) + torch.arange(
1, K + 1, device=device
).unsqueeze(0)
valid_mask = mask_indices < input_ids.size(0)
mask_indices = mask_indices[valid_mask]
mask[mask_indices] = True
return mask
def get_attention_mask(
self, input_ids, max_s, cu_seqlen_prefill, dtype, image_token_mask
):
device = input_ids.device
min_dtype = torch.finfo(dtype).min
lengths = (cu_seqlen_prefill[1:] - cu_seqlen_prefill[:-1]).tolist()
batch_size = len(lengths)
sequence_length = max(lengths)
target_length = max_s
# Create the padding mask from the computed lengths.
# pad_mask: [batch, sequence_length] where True indicates valid tokens.
seq_range = torch.arange(sequence_length, device=device).unsqueeze(0)
lengths_tensor = torch.tensor(lengths, device=device).unsqueeze(1)
pad_mask = seq_range < lengths_tensor # shape: [batch, sequence_length]
# Build the base causal mask (for non-image tokens):
causal_mask = torch.tril(
torch.ones(
(sequence_length, sequence_length), dtype=torch.bool, device=device
)
)
base_mask = pad_mask.unsqueeze(2) & pad_mask.unsqueeze(
1
) # [batch, sequence_length, sequence_length]
base_mask = base_mask & causal_mask.unsqueeze(0) # apply causal constraint
image_token_mask = torch.nn.utils.rnn.pad_sequence(
torch.split(image_token_mask, lengths), batch_first=True, padding_value=0
)
bidirectional_mask = image_token_mask.unsqueeze(2) & image_token_mask.unsqueeze(
1
)
# Combine the causal base mask and the bidirectional mask.
combined_mask = torch.logical_or(
base_mask.unsqueeze(1), bidirectional_mask.unsqueeze(1)
).to(device)
# combined_mask now has shape [batch, 1, sequence_length, sequence_length]
full_attention_mask = torch.zeros(
(batch_size, 1, sequence_length, target_length),
device=device,
dtype=torch.bool,
)
full_attention_mask[:, :, :, :sequence_length] = combined_mask
final_attention_mask = torch.where(full_attention_mask, 0, min_dtype).to(device)
return final_attention_mask
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor = None,
# Unused here
attention_mask: Optional[torch.BoolTensor] = None,
pixel_attention_mask: Optional[torch.BoolTensor] = None,
image_sizes: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
inputs_embeds = self.text_model.embed_tokens(input_ids)
if cu_seqlen_prefill is not None:
max_s += 1
position_ids += 1
if pixel_values is not None:
pixel_values = pixel_values.to(dtype=inputs_embeds.dtype)
image_outputs = self.vision_model(pixel_values)
vision_outputs = self.post_vision_model_layernorm(
image_outputs.last_hidden_state
)
image_features = self.multimodal_projector(vision_outputs)
image_token_mask = (input_ids == self.config.image_token_index).to(
input_ids.device
)
inputs_embeds[image_token_mask] = image_features.view(
-1, image_features.shape[-1]
)
attention_mask = self.get_attention_mask(
input_ids,
max_s,
cu_seqlen_prefill,
inputs_embeds.dtype,
image_token_mask,
)
# Use flash attention for text-only input
# else:
# if cu_seqlen_prefill is not None:
# min_dtype = torch.finfo(inputs_embeds.dtype).min
# lengths = (cu_seqlen_prefill[1:] - cu_seqlen_prefill[:-1]).tolist()
# # Determine the maximum sequence length (after padding) from query.
# sequence_length = max(lengths)
# target_length = max_s
# # Create the padding mask from the computed lengths.
# # pad_mask: [batch, sequence_length] where True indicates valid tokens.
# seq_range = torch.arange(
# sequence_length, device=input_ids.device
# ).unsqueeze(0)
# lengths_tensor = torch.tensor(
# lengths, device=input_ids.device
# ).unsqueeze(1)
# pad_mask = seq_range < lengths_tensor # shape: [batch, sequence_length]
# # Build the base causal mask (for non-image tokens):
# causal_mask = torch.tril(
# torch.ones(
# (sequence_length, sequence_length),
# dtype=torch.bool,
# device=input_ids.device,
# )
# )
# base_mask = pad_mask.unsqueeze(2) & pad_mask.unsqueeze(
# 1
# ) # [batch, sequence_length, sequence_length]
# base_mask = base_mask & causal_mask.unsqueeze(0)
# attention_mask = base_mask.unsqueeze(
# 1
# ) # [batch, 1, sequence_length, sequence_length]
# full_attention_mask = torch.zeros(
# (len(lengths), 1, sequence_length, target_length),
# device=input_ids.device,
# dtype=torch.bool,
# )
# full_attention_mask[:, :, :, :sequence_length] = attention_mask
# attention_mask = torch.where(full_attention_mask, 0, min_dtype).to(
# input_ids.device
# )
hidden_states = self.text_model.model(
inputs_embeds=inputs_embeds,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
seqlen=seqlen,
max_s=max_s,
attention_mask=attention_mask,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.text_model.lm_head(hidden_states)
# pad logit with 1 zero logit for the image token
if pixel_values is not None:
logits = torch.cat(
[logits, torch.zeros(logits.size(0), 1, device=logits.device)], dim=1
)
if speculative_logits is not None:
speculative_logits = torch.cat(
[
speculative_logits,
torch.zeros(
speculative_logits.size(0),
1,
device=speculative_logits.device,
),
],
dim=1,
)
return logits, speculative_logits

View File

@ -31,7 +31,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
super().__init__()
config.vision_config.quantize = config.quantize
self.vision_tower = load_vision_model(
prefix="vision_tower" if not prefix else f"{prefix}.vision_tower",
prefix="vision_model" if not prefix else f"{prefix}.vision_model",
config=config.vision_config,
weights=weights,
)

View File

@ -0,0 +1,313 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/gemma3/modular_gemma3.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_gemma3.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging
from transformers import SiglipVisionConfig
logger = logging.get_logger(__name__)
class Gemma3TextConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Gemma3Model`]. It is used to instantiate a Gemma3
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the Gemma3-4B.
e.g. [google/gemma-3-4b](https://huggingface.co/google/gemma-3-4b)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 262144):
Vocabulary size of the Gemma3 model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`Gemma3Model`]
hidden_size (`int`, *optional*, defaults to 2304):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 9216):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 26):
Number of hidden layers in the Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to 8):
Number of attention heads for each attention layer in the Transformer decoder.
num_key_value_heads (`int`, *optional*, defaults to 4):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
`num_attention_heads`.
head_dim (`int`, *optional*, defaults to 256):
The attention head dimension.
sliding_window (`int`, *optional*, defaults to 4096): in Gemma3, every other layer uses sliding window
attention. This is the size of the sliding window.
query_pre_attn_scalar (`float`, *optional*):
The scaling factor used on the attention scores, not that
rope_theta (`float`, *optional*, defaults to 1000000.0):
The base period of the RoPE embeddings used for global attention.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
accordingly.
Expected contents:
`rope_type` (`str`):
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
'llama3'], with 'default' being the original RoPE implementation.
`factor` (`float`, *optional*):
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
original maximum pre-trained length.
`original_max_position_embeddings` (`int`, *optional*):
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
pretraining.
`attention_factor` (`float`, *optional*):
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
computation. If unspecified, it defaults to value recommended by the implementation, using the
`factor` field to infer the suggested value.
`beta_fast` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
ramp function. If unspecified, it defaults to 32.
`beta_slow` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
ramp function. If unspecified, it defaults to 1.
`short_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`long_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`low_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
`high_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
rope_local_base_freq (float, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings for local attention.
sliding_window_pattern (`int`, *optional*, defaults to 6):
Pattern for the sliding window attention.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
The non-linear activation function (function or string) in the decoder. Will default to
`"gelu_pytorch_tanh"` if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"`
activation function.
pad_token_id (`int`, *optional*, defaults to 0):
Padding token id.
eos_token_id (`int`, *optional*, defaults to 1):
End of stream token id.
bos_token_id (`int`, *optional*, defaults to 2):
Beginning of stream token id.
tie_word_embeddings (`bool`, *optional*, defaults to `True`):
Whether to tie weight embeddings
max_position_embeddings (`int`, *optional*, defaults to 131072):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
attention_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
final_logit_softcapping (`bool`, *optional*, defaults to `True`):
Whether to apply logit softcapping or nor
attn_logit_softcapping (`float`, *optional*, defaults to 50.0):
Scaling factor when applying tanh soft-capping on the attention scorexs.
cache_implementation (`str`, *optional*, defaults to `"hybrid"`):
The cache type to be used with `generate`.
```python
>>> from transformers import Gemma3Model, Gemma3TextConfig
>>> # Initializing a Gemma3 gemma3-4b style configuration
>>> configuration = Gemma3Config()
>>> # Initializing a model from the gemma3-4b style configuration
>>> model = Gemma3Model(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "gemma3_text"
def __init__(
self,
vocab_size: int = 262_144,
hidden_size: int = 2304,
intermediate_size: int = 9216,
num_hidden_layers: int = 26,
num_attention_heads: int = 8,
num_key_value_heads: int = 4,
head_dim: int = 256,
sliding_window: int = 4096,
query_pre_attn_scalar: Optional[float] = 256,
rope_theta: float = 1_000_000.0,
rope_scaling=None,
rope_local_base_freq: float = 10_000.0,
sliding_window_pattern: int = 6,
rms_norm_eps: float = 1e-6,
hidden_activation: str = "gelu_pytorch_tanh",
pad_token_id: int = 0,
eos_token_id: int = 1,
bos_token_id: int = 2,
tie_word_embeddings: bool = True,
max_position_embeddings: int = 131_072,
initializer_range: float = 0.02,
attention_bias: bool = False,
attention_dropout: float = 0.0,
use_cache: bool = True,
final_logit_softcapping=None,
attn_logit_softcapping=None,
cache_implementation: str = "hybrid",
**kwargs,
):
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.head_dim = head_dim
self.num_key_value_heads = num_key_value_heads
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.rope_local_base_freq = rope_local_base_freq
# For configuring HybridCache to work with 5:1 attention pattern
self.sliding_window_pattern = sliding_window_pattern
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.hidden_activation = hidden_activation
self.query_pre_attn_scalar = query_pre_attn_scalar
self.sliding_window = sliding_window
self.final_logit_softcapping = final_logit_softcapping
self.attn_logit_softcapping = attn_logit_softcapping
self.cache_implementation = cache_implementation
rope_config_validation(self)
class Gemma3Config(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Gemma3ForConditionalGeneration`]. It is used to instantiate an
Gemma3ForConditionalGeneration according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the PaliGemma-2B.
e.g. [google/gemma-3-4b](https://huggingface.co/google/gemma-3-4b)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
text_config (`Union[Gemma3TextConfig, dict]`, *optional*):
The config object of the text backbone.
vision_config (`Union[AutoConfig, dict]`, *optional*):
Custom vision config or dict.
mm_tokens_per_image (`int`, *optional*, defaults to 256):
The number of tokens per image embedding.
boi_token_index (`int`, *optional*, defaults to 255999):
The begin-of-image token index to wrap the image prompt.
eoi_token_index (`int`, *optional*, defaults to 256000):
The end-of-image token index to wrap the image prompt.
image_token_index (`int`, *optional*, defaults to 262144):
The image token index to encode the image prompt.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
Example:
```python
>>> from transformers import Gemma3ForConditionalGeneration, Gemma3Config, SiglipVisionConfig, Gemma3TextConfig
>>> # Initializing a Siglip-like vision config
>>> vision_config = SiglipVisionConfig()
>>> # Initializing a Gemma3 Text config
>>> text_config = Gemma3TextConfig()
>>> # Initializing a Gemma3 gemma-3-4b style configuration
>>> configuration = Gemma3Config(vision_config, text_config)
>>> # Initializing a model from the gemma-3-4b style configuration
>>> model = Gemma3TextConfig(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "gemma3"
sub_configs = {
"text_config": Gemma3TextConfig,
"vision_config": SiglipVisionConfig,
}
def __init__(
self,
text_config: Optional[Gemma3TextConfig] = None,
vision_config: Optional[SiglipVisionConfig] = None,
mm_tokens_per_image: int = 256,
boi_token_index: int = 255_999,
eoi_token_index: int = 256_000,
image_token_index: int = 262_144,
initializer_range: float = 0.02,
**kwargs,
):
if text_config is None:
text_config = Gemma3TextConfig()
logger.info(
"text_config is None, using default Gemma3TextConfig vision config."
)
elif isinstance(text_config, dict):
text_config = Gemma3TextConfig(**text_config)
if isinstance(vision_config, dict):
vision_config = SiglipVisionConfig(**vision_config)
else:
vision_config = SiglipVisionConfig()
logger.info(
"vision_config is None or incompatible with Gemma3VisionConfig intialization. Gemma3 will be limited "
"to text tasks."
)
self.text_config = text_config
self.vision_config = vision_config
self.mm_tokens_per_image = mm_tokens_per_image
self.boi_token_index = boi_token_index
self.eoi_token_index = eoi_token_index
self.image_token_index = image_token_index
self.initializer_range = initializer_range
super().__init__(**kwargs)
__all__ = ["Gemma3Config", "Gemma3TextConfig"]

View File

@ -0,0 +1,463 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for Gemma3."""
import itertools
import math
from typing import Dict, List, Optional, Union
import numpy as np
from transformers.image_processing_utils import (
BaseImageProcessor,
BatchFeature,
get_size_dict,
)
from transformers.image_transforms import (
convert_to_rgb,
resize,
to_channel_dimension_format,
)
from transformers.image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
get_image_size,
infer_channel_dimension_format,
is_scaled_image,
to_numpy_array,
valid_images,
validate_preprocess_arguments,
)
from transformers.utils import (
TensorType,
filter_out_non_signature_kwargs,
is_vision_available,
logging,
)
from .utils import make_nested_list_of_images
logger = logging.get_logger(__name__)
if is_vision_available():
import PIL
class Gemma3ImageProcessor(BaseImageProcessor):
r"""
Constructs a SigLIP image processor.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
`do_resize` in the `preprocess` method.
size (`Dict[str, int]` *optional*, defaults to `{"height": 224, "width": 224}`):
Size of the image after resizing. Can be overridden by `size` in the `preprocess` method.
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
do_rescale (`bool`, *optional*, defaults to `True`):
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
the `preprocess` method.
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
method.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether to normalize the image by the specified mean and standard deviation. Can be overridden by
`do_normalize` in the `preprocess` method.
image_mean (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
image_std (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
Can be overridden by the `image_std` parameter in the `preprocess` method.
do_convert_rgb (`bool`, *optional*, defaults to `True`):
Whether to convert the image to RGB.
do_pan_and_scan (`bool`, *optional*):
Whether to apply `pan_and_scan` to images.
"""
model_input_names = ["pixel_values", "num_crops"]
def __init__(
self,
do_resize: bool = True,
size: Dict[str, int] = None,
resample: PILImageResampling = PILImageResampling.BILINEAR,
do_rescale: bool = False,
rescale_factor: Union[int, float] = 1 / 255,
do_normalize: bool = True,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_convert_rgb: bool = None,
do_pan_and_scan: bool = None,
pan_and_scan_min_crop_size: int = None,
pan_and_scan_max_num_crops: int = None,
pan_and_scan_min_ratio_to_activate: float = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
size = size if size is not None else {"height": 224, "width": 224}
image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
self.do_resize = do_resize
self.size = size
self.resample = resample
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize
self.image_mean = image_mean
self.image_std = image_std
self.do_convert_rgb = do_convert_rgb
self.do_pan_and_scan = do_pan_and_scan
self.pan_and_scan_min_crop_size = pan_and_scan_min_crop_size
self.pan_and_scan_max_num_crops = pan_and_scan_max_num_crops
self.pan_and_scan_min_ratio_to_activate = pan_and_scan_min_ratio_to_activate
def pan_and_scan(
self,
image: np.ndarray,
pan_and_scan_min_crop_size: int,
pan_and_scan_max_num_crops: int,
pan_and_scan_min_ratio_to_activate: float,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
):
"""
Pan and Scan and image, whatever it means. TODO: write-up docs
Args:
image (`np.ndarray`):
Image to resize.
pan_and_scan_min_crop_size (`int`):
Size of pan_and_scan_min_crop_size.
pan_and_scan_max_num_crops (`int`):
pan_and_scan_max_num_crops for the image.
pan_and_scan_min_ratio_to_activate (`int`):
pan_and_scan_min_ratio_to_activate for the image..
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
"""
height, width = get_image_size(image)
# Square or landscape image.
if width >= height:
# Only apply PaS if the image is sufficiently exaggerated
if width / height < pan_and_scan_min_ratio_to_activate:
return []
# Select ideal number of crops close to the image aspect ratio and such that crop_size > min_crop_size.
num_crops_w = int(
math.floor(width / height + 0.5)
) # Half round up rounding.
num_crops_w = min(
int(math.floor(width / pan_and_scan_min_crop_size)), num_crops_w
)
# Make sure the number of crops is in range [2, pan_and_scan_max_num_crops].
num_crops_w = max(2, num_crops_w)
num_crops_w = min(pan_and_scan_max_num_crops, num_crops_w)
num_crops_h = 1
# Portrait image.
else:
# Only apply PaS if the image is sufficiently exaggerated
if height / width < pan_and_scan_min_ratio_to_activate:
return []
# Select ideal number of crops close to the image aspect ratio and such that crop_size > min_crop_size.
num_crops_h = int(math.floor(height / width + 0.5))
num_crops_h = min(
int(math.floor(height / pan_and_scan_min_crop_size)), num_crops_h
)
# Make sure the number of crops is in range [2, pan_and_scan_max_num_crops].
num_crops_h = max(2, num_crops_h)
num_crops_h = min(pan_and_scan_max_num_crops, num_crops_h)
num_crops_w = 1
crop_size_w = int(math.ceil(width / num_crops_w))
crop_size_h = int(math.ceil(height / num_crops_h))
# Don't apply PaS if crop size is too small.
if min(crop_size_w, crop_size_h) < pan_and_scan_min_crop_size:
return []
crop_positions_w = [crop_size_w * i for i in range(num_crops_w)]
crop_positions_h = [crop_size_h * i for i in range(num_crops_h)]
if input_data_format == ChannelDimension.LAST:
image_crops = [
image[pos_h : pos_h + crop_size_h, pos_w : pos_w + crop_size_w]
for pos_h, pos_w in itertools.product(
crop_positions_h, crop_positions_w
)
]
else:
image_crops = [
image[:, pos_h : pos_h + crop_size_h, pos_w : pos_w + crop_size_w]
for pos_h, pos_w in itertools.product(
crop_positions_h, crop_positions_w
)
]
return image_crops
def _process_images_for_pas(
self,
images: List[np.ndarray],
do_pan_and_scan: bool,
pan_and_scan_min_crop_size: int,
pan_and_scan_max_num_crops: int,
pan_and_scan_min_ratio_to_activate: float,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
):
pas_images_list = []
num_crops = []
for image in images:
pas_images = self.pan_and_scan(
image=image,
pan_and_scan_min_crop_size=pan_and_scan_min_crop_size,
pan_and_scan_max_num_crops=pan_and_scan_max_num_crops,
pan_and_scan_min_ratio_to_activate=pan_and_scan_min_ratio_to_activate,
data_format=data_format,
input_data_format=input_data_format,
)
pas_images_list.extend([image] + pas_images)
num_crops.append(len(pas_images))
return pas_images_list, num_crops
@filter_out_non_signature_kwargs()
def preprocess(
self,
images: ImageInput,
do_resize: bool = None,
size: Dict[str, int] = None,
resample: PILImageResampling = None,
do_rescale: bool = None,
rescale_factor: float = None,
do_normalize: bool = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
do_convert_rgb: bool = None,
do_pan_and_scan: bool = None,
pan_and_scan_min_crop_size: int = None,
pan_and_scan_max_num_crops: int = None,
pan_and_scan_min_ratio_to_activate: float = None,
) -> PIL.Image.Image:
"""
Preprocess an image or batch of images.
Args:
images (`ImageInput`):
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the image.
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
Size of the image after resizing.
resample (`int`, *optional*, defaults to `self.resample`):
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
has an effect if `do_resize` is set to `True`.
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
Whether to rescale the image.
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
Whether to normalize the image.
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
`True`.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- Unset: Use the channel dimension format of the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
Whether to convert the image to RGB.
do_pan_and_scan (`bool`, *optional*, defaults to `self.do_convert_rgb`):
Whether to apply `pan_and_scan` to images.
"""
do_resize = do_resize if do_resize is not None else self.do_resize
size = size if size is not None else self.size
size = get_size_dict(size, param_name="size", default_to_square=False)
resample = resample if resample is not None else self.resample
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
rescale_factor = (
rescale_factor if rescale_factor is not None else self.rescale_factor
)
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
do_convert_rgb = (
do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
)
do_pan_and_scan = (
do_pan_and_scan if do_pan_and_scan is not None else self.do_pan_and_scan
)
pan_and_scan_min_crop_size = (
pan_and_scan_min_crop_size
if pan_and_scan_min_crop_size is not None
else self.pan_and_scan_min_crop_size
)
pan_and_scan_max_num_crops = (
pan_and_scan_max_num_crops
if pan_and_scan_max_num_crops is not None
else self.pan_and_scan_max_num_crops
)
pan_and_scan_min_ratio_to_activate = (
pan_and_scan_min_ratio_to_activate
if pan_and_scan_min_ratio_to_activate is not None
else self.pan_and_scan_min_ratio_to_activate
)
images_list = make_nested_list_of_images(images)
if not valid_images(images_list[0]):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
validate_preprocess_arguments(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_resize=do_resize,
size=size,
resample=resample,
)
if do_convert_rgb:
images_list = [
[convert_to_rgb(image) for image in images] for images in images_list
]
# All transformations expect numpy arrays.
images_list = [
[to_numpy_array(image) for image in images] for images in images_list
]
if do_rescale and is_scaled_image(images_list[0][0]):
logger.warning_once(
"It looks like you are trying to rescale already rescaled images. If the input"
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
)
if input_data_format is None:
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images_list[0][0])
if do_pan_and_scan:
images_list_and_num_crops = [
self._process_images_for_pas(
images=images,
do_pan_and_scan=do_pan_and_scan,
pan_and_scan_min_crop_size=pan_and_scan_min_crop_size,
pan_and_scan_max_num_crops=pan_and_scan_max_num_crops,
pan_and_scan_min_ratio_to_activate=pan_and_scan_min_ratio_to_activate,
data_format=data_format,
input_data_format=input_data_format,
)
for images in images_list
]
images_list = [images for images, _ in images_list_and_num_crops]
num_crops = [num_crops for _, num_crops in images_list_and_num_crops]
else:
num_crops = [[0] for images in images_list]
if do_resize:
height, width = size["height"], size["width"]
images_list = [
[
resize(
image=image,
size=(height, width),
resample=resample,
input_data_format=input_data_format,
)
for image in images
]
for images in images_list
]
if do_rescale:
images_list = [
[
self.rescale(
image=image,
scale=rescale_factor,
input_data_format=input_data_format,
)
for image in images
]
for images in images_list
]
if do_normalize:
images_list = [
[
self.normalize(
image=image,
mean=image_mean,
std=image_std,
input_data_format=input_data_format,
)
for image in images
]
for images in images_list
]
images = [
to_channel_dimension_format(
image, data_format, input_channel_dim=input_data_format
)
for images in images_list
for image in images
]
data = {"pixel_values": images, "num_crops": num_crops}
return BatchFeature(data=data, tensor_type=return_tensors)
__all__ = ["Gemma3ImageProcessor"]

View File

@ -0,0 +1,206 @@
# coding=utf-8
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from typing import List, Optional, Union
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput
from transformers.processing_utils import (
ImagesKwargs,
ProcessingKwargs,
ProcessorMixin,
Unpack,
)
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from transformers.utils import to_py_obj
from text_generation_server.models.custom_modeling.gemma3.image_processing_gemma3 import (
Gemma3ImageProcessor,
)
from transformers.image_utils import PILImageResampling
from .utils import make_nested_list_of_images
class Gemma3ImagesKwargs(ImagesKwargs):
do_pan_and_scan: Optional[bool]
pan_and_scan_min_crop_size: Optional[int]
pan_and_scan_max_num_crops: Optional[int]
pan_and_scan_min_ratio_to_activate: Optional[float]
do_convert_rgb: Optional[bool]
class Gemma3ProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {
"text_kwargs": {
"padding": False,
},
"images_kwargs": {
"do_pan_and_scan": False,
"pan_and_scan_min_crop_size": 256,
"pan_and_scan_max_num_crops": 4,
"pan_and_scan_min_ratio_to_activate": 1.2,
},
}
class Gemma3Processor(ProcessorMixin):
attributes = ["image_processor", "tokenizer"]
valid_kwargs = ["chat_template"]
# # image_processor_class = "Gemma3ImageProcessor"
image_processor_class = "AutoProcessor"
tokenizer_class = "AutoTokenizer"
def __init__(
self,
image_processor,
tokenizer,
chat_template=None,
num_mm_soft_tokens_per_image: int = 256,
**kwargs,
):
num_mm_soft_tokens_per_image = 256
chat_template = None
image_processor = Gemma3ImageProcessor(
image_mean=(127.5,) * 3,
image_std=(127.5,) * 3,
size={"height": 896, "width": 896},
do_rescale=False,
resample=PILImageResampling.BILINEAR,
)
# import ipdb; ipdb.set_trace()
self.image_token_id = tokenizer.image_token_id
image_tokens_expanded = "".join(
[tokenizer.image_token] * num_mm_soft_tokens_per_image
)
self.full_image_sequence = (
f"\n\n{tokenizer.boi_token}{image_tokens_expanded}{tokenizer.eoi_token}\n\n"
)
# import ipdb; ipdb.set_trace()
self.image_processor = image_processor
self.tokenizer = tokenizer
self.chat_template = chat_template
# super().__init__(
# image_processor=image_processor,
# tokenizer=tokenizer,
# chat_template=chat_template,
# **kwargs,
# )
def __call__(
self,
images: ImageInput = None,
text: Union[
TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]
] = None,
videos=None,
audio=None,
**kwargs: Unpack[Gemma3ProcessorKwargs],
) -> BatchFeature:
if text is None and images is None:
raise ValueError("Provide at least one of `text` or `images`.")
output_kwargs = self._merge_kwargs(
Gemma3ProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
if isinstance(text, str):
text = [text]
elif not isinstance(text, list) and not isinstance(text[0], str):
raise ValueError(
"Invalid input text. Please provide a string, or a list of strings"
)
image_inputs = {}
if images is not None:
batched_images = make_nested_list_of_images(images)
image_inputs = self.image_processor(
batched_images, **output_kwargs["images_kwargs"]
)
# Create empty text to be replaced with placeholders
if not text:
text = [
" ".join(["<image>"] * len(images)) for images in batched_images
]
if len(batched_images) != len(text):
raise ValueError(
f"Received inconsistently sized batches of images ({len(batched_images)}) and text ({len(text)})."
)
# Replace image tokens by the full expanded sequence
batch_num_crops = to_py_obj(image_inputs.pop("num_crops"))
for prompt, images, num_crops in zip(text, batched_images, batch_num_crops):
image_indexes = [m.start() for m in re.finditer("<image>", prompt)]
if len(images) != len(image_indexes):
raise ValueError(
f"Prompt contained {len(image_indexes)} image tokens but received {len(images)} images."
)
# Insert additional image tokens for Pan-and-Scan crops
for num, idx in reversed(list(zip(num_crops, image_indexes))):
if num:
formatted_image_text = (
"Here is the original image <image> and here are some crops to help you see better "
+ " ".join(["<image>"] * num)
)
prompt = (
prompt[:idx]
+ formatted_image_text
+ prompt[idx + len("<image>") :]
)
# Expand placeholder image tokens to the full image token sequence
text = [
prompt.replace("<image>", self.full_image_sequence) for prompt in text
]
text_input = self.tokenizer(text=text, **output_kwargs["text_kwargs"])
return BatchFeature(data={**text_input, **image_inputs})
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Gemma
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
refer to the docstring of this method for more information.
"""
return self.tokenizer.batch_decode(*args, **kwargs)
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Gemma
def decode(self, *args, **kwargs):
"""
This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
the docstring of this method for more information.
"""
return self.tokenizer.decode(*args, **kwargs)
@property
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names with CLIP->PaliGemma
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
image_processor_input_names = self.image_processor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
__all__ = ["Gemma3Processor"]

View File

@ -0,0 +1,61 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Union
from transformers.image_utils import ImageInput, is_valid_image, is_pil_image
def is_valid_list_of_images(images: List):
return images and all(is_valid_image(image) for image in images)
def make_nested_list_of_images(
images: Union[List[ImageInput], ImageInput],
) -> ImageInput:
"""
Ensure that the output is a nested list of images.
Args:
images (`Union[List[ImageInput], ImageInput]`):
The input image.
Returns:
list: A list of list of images or a list of 4d array of images.
"""
# If it's a list of batches, it's already in the right format
if (
isinstance(images, (list, tuple))
and all(isinstance(images_i, (list, tuple)) for images_i in images)
and all(is_valid_list_of_images(images_i) for images_i in images)
):
return images
# If it's a list of images, it's a single batch, so convert it to a list of lists
if isinstance(images, (list, tuple)) and is_valid_list_of_images(images):
if is_pil_image(images[0]) or images[0].ndim == 3:
return [images]
if images[0].ndim == 4:
return [list(image) for image in images]
# If it's a single image, convert it to a list of lists
if is_valid_image(images):
if is_pil_image(images) or images.ndim == 3:
return [[images]]
if images.ndim == 4:
return [list(images)]
raise ValueError(
"Invalid input type. Must be a single image, a list of images, or a list of batches of images."
)

View File

@ -23,6 +23,13 @@ def load_text_model(prefix, config, weights, name=None):
)
return FlashGemma2ForCausalLM(prefix, config, weights)
elif config.model_type == "gemma3" or config.model_type == "gemma3_text":
from text_generation_server.models.custom_modeling.flash_gemma3_modeling import (
FlashGemma3ForCausalLM,
)
return FlashGemma3ForCausalLM(prefix, config, weights)
elif config.model_type == "paligemma":
from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
FlashGemmaForCausalLM,
@ -42,13 +49,21 @@ def load_vision_model(prefix, config, weights):
return CLIPVisionTransformer(
prefix=f"{prefix}.vision_model", config=config, weights=weights
)
if config.model_type == "siglip_vision_model":
if (
config.model_type == "siglip_vision_model"
or config.model_type == "gemma3_vision"
):
from text_generation_server.models.custom_modeling.siglip import (
SiglipVisionTransformer,
)
# TODO: ensure that using the prefix doesn't break any existing models
# that rely on the old prefix (update the old models if necessary)
return SiglipVisionTransformer(
prefix="vision_tower.vision_model", config=config, weights=weights
# prefix="vision_model.vision_model", config=config, weights=weights
prefix=f"{prefix}.vision_model",
config=config,
weights=weights,
)
else:
raise RuntimeError(f"Unsupported model type {config.model_type}")

View File

@ -128,6 +128,12 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
num_pads = grid_t * grid_h * grid_w // 4
padding = "<|image_pad|>" * num_pads
return f"<|vision_start|>{padding}<|vision_end|>"
elif config.model_type == "gemma3":
# TODO: get correct number of features via reviewing the Gemma3 architecture
# and calculating the number of image tokens
num_pads = 256
padding = "<image_soft_token>" * num_pads
return f"\n\n<start_of_image>{padding}<end_of_image>\n\n"
else:
raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
@ -244,6 +250,8 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
if config.model_type == "llava_next":
images.append(image)
elif config.model_type == "gemma3":
images.append(image)
else:
images.append([image])
else: