add fix: take Llama4 max_chunks from the preprocessor config's max_patches instead of hardcoding 15
This commit is contained in:
Mohit Sharma 2025-04-14 22:13:53 +05:30 committed by GitHub
parent fe56f760df
commit 73e797528d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 12 additions and 6 deletions

View File

@ -229,10 +229,13 @@ impl Llama4 {
pub fn pixel_shuffle_ratio(&self) -> f64 { pub fn pixel_shuffle_ratio(&self) -> f64 {
self.vision_config.pixel_shuffle_ratio self.vision_config.pixel_shuffle_ratio
} }
pub fn get_aspect_ratios(&self, height: usize, width: usize) -> (usize, usize) { pub fn get_aspect_ratios(
&self,
height: usize,
width: usize,
max_chunks: usize,
) -> (usize, usize) {
let patch_size = self.vision_config.image_size; let patch_size = self.vision_config.image_size;
// How to avoid hardcoding this?
let max_chunks = 15;
let supported = find_supported_resolutions(max_chunks, patch_size); let supported = find_supported_resolutions(max_chunks, patch_size);
let (target_h, target_w) = get_best_fit(height, width, &supported, false); let (target_h, target_w) = get_best_fit(height, width, &supported, false);
(target_h / patch_size, target_w / patch_size) (target_h / patch_size, target_w / patch_size)

View File

@ -204,7 +204,7 @@ pub struct Gemma3Processor {
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Llama4Processor { pub struct Llama4Processor {
#[serde(default)] #[serde(default)]
do_image_splitting: bool, max_patches: usize,
} }
#[derive(Debug, Clone, Deserialize, Default)] #[derive(Debug, Clone, Deserialize, Default)]

View File

@ -698,10 +698,14 @@ fn image_tokens(
let image_height = config.image_size(); let image_height = config.image_size();
let patch_size = config.patch_size(); let patch_size = config.patch_size();
let pixel_shuffle_ratio = config.pixel_shuffle_ratio(); let pixel_shuffle_ratio = config.pixel_shuffle_ratio();
let max_patches = match preprocessor_config {
Some(HubPreprocessorConfig::Llama4Processor(cfg)) => cfg.max_patches,
_ => panic!("Expected Llama4Processor in preprocessor_config"),
};
let downsample_ratio = let downsample_ratio =
(1.0 / (pixel_shuffle_ratio * pixel_shuffle_ratio)).round() as usize; (1.0 / (pixel_shuffle_ratio * pixel_shuffle_ratio)).round() as usize;
let (ratio_h, ratio_w) = config.get_aspect_ratios(height, width); let (ratio_h, ratio_w) = config.get_aspect_ratios(height, width, max_patches);
let image_width = image_height; // Assuming pixel shape: [H][W][C] let image_width = image_height; // Assuming pixel shape: [H][W][C]
let num_patches_per_chunk = let num_patches_per_chunk =

View File

@ -1041,7 +1041,6 @@ def get_model(
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
processor_kwargs={ processor_kwargs={
"use_fast": True, "use_fast": True,
"size": {"height": 336, "width": 336},
}, },
) )
elif model_type == BAICHUAN: elif model_type == BAICHUAN: