diff --git a/router/src/config.rs b/router/src/config.rs index 7c595e2d..a1a3c7f6 100644 --- a/router/src/config.rs +++ b/router/src/config.rs @@ -265,6 +265,10 @@ impl Idefics3 { pub fn get_max_longest_edge_for_image_resize(&self) -> usize { 1456 } + + pub fn get_max_image_size(&self) -> usize { + 4096 + } } #[derive(Clone, Debug, Serialize, Deserialize)] diff --git a/router/src/validation.rs b/router/src/validation.rs index 28c7f2f8..42e43897 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -646,11 +646,10 @@ fn image_tokens( const GLOBAL_IMG: &str = ""; let max_longest_edge_for_image_resize = config.get_max_longest_edge_for_image_resize(); + let max_image_size = config.get_max_image_size(); - // resize image if it is larger than max_longest_edge_for_image_resize keeping aspect ratio - let (height, width) = if height > max_longest_edge_for_image_resize - || width > max_longest_edge_for_image_resize - { + // resize image to max_longest_edge_for_image_resize and keep aspect ratio + let (height, width) = { let aspect_ratio = height as f32 / width as f32; if height > width { ( @@ -663,8 +662,23 @@ fn image_tokens( max_longest_edge_for_image_resize, ) } - } else { - (height, width) + }; + + let (height, width) = { + let aspect_ratio = height as f32 / width as f32; + if height >= width && height > max_image_size { + ( + max_image_size, + (max_image_size as f32 / aspect_ratio) as usize, + ) + } else if width > height && width > max_image_size { + ( + (max_image_size as f32 * aspect_ratio) as usize, + max_image_size, + ) + } else { + (height, width) + } }; let image_seq_len = config.get_number_of_features();