diff --git a/router/src/config.rs b/router/src/config.rs index fcda26122..93b6f4fa4 100644 --- a/router/src/config.rs +++ b/router/src/config.rs @@ -229,10 +229,13 @@ impl Llama4 { pub fn pixel_shuffle_ratio(&self) -> f64 { self.vision_config.pixel_shuffle_ratio } - pub fn get_aspect_ratios(&self, height: usize, width: usize) -> (usize, usize) { + pub fn get_aspect_ratios( + &self, + height: usize, + width: usize, + max_chunks: usize, + ) -> (usize, usize) { let patch_size = self.vision_config.image_size; - // How to avoid hardcoding this? - let max_chunks = 15; let supported = find_supported_resolutions(max_chunks, patch_size); let (target_h, target_w) = get_best_fit(height, width, &supported, false); (target_h / patch_size, target_w / patch_size) diff --git a/router/src/lib.rs b/router/src/lib.rs index 50adb5cf6..3c1a01b3c 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -204,7 +204,7 @@ pub struct Gemma3Processor { #[derive(Clone, Debug, Serialize, Deserialize)] pub struct Llama4Processor { #[serde(default)] - do_image_splitting: bool, + max_patches: usize, } #[derive(Debug, Clone, Deserialize, Default)] diff --git a/router/src/validation.rs b/router/src/validation.rs index dfe9dd4d2..b29391b77 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -698,10 +698,14 @@ fn image_tokens( let image_height = config.image_size(); let patch_size = config.patch_size(); let pixel_shuffle_ratio = config.pixel_shuffle_ratio(); + let max_patches = match preprocessor_config { + Some(HubPreprocessorConfig::Llama4Processor(cfg)) => cfg.max_patches, + _ => panic!("Expected Llama4Processor in preprocessor_config"), + }; let downsample_ratio = (1.0 / (pixel_shuffle_ratio * pixel_shuffle_ratio)).round() as usize; - let (ratio_h, ratio_w) = config.get_aspect_ratios(height, width); + let (ratio_h, ratio_w) = config.get_aspect_ratios(height, width, max_patches); let image_width = image_height; // Assuming pixel shape: [H][W][C] let num_patches_per_chunk = diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 84437bf32..291ee5fba 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -1041,7 +1041,6 @@ def get_model( trust_remote_code=trust_remote_code, processor_kwargs={ "use_fast": True, - "size": {"height": 336, "width": 336}, }, ) elif model_type == BAICHUAN: