diff --git a/router/src/config.rs b/router/src/config.rs index 64663a7a..9b5a2404 100644 --- a/router/src/config.rs +++ b/router/src/config.rs @@ -4,7 +4,7 @@ use serde::{Deserialize, Serialize}; #[serde(tag = "model_type")] #[serde(rename_all = "snake_case")] pub struct LlavaNext { - text_config: Box, + text_config: TextConfig, vision_config: VisionConfig, image_grid_pinpoints: Vec<(usize, usize)>, } @@ -32,9 +32,12 @@ fn select_best_resolution( let mut min_wasted_resolution = f32::NEG_INFINITY; for (height, width) in possible_resolutions { - // let scale = std::cmp::min(width / original_width, height / original_height); - let downscaled_width = width / original_width * original_width; - let downscaled_height = height / original_height * original_height; + let wscale = *width as f32 / original_width as f32; + let hscale = *height as f32 / original_height as f32; + // f32 partial ord. + let scale = if wscale > hscale { hscale } else { wscale }; + let downscaled_width = (*width as f32 * scale) as usize; + let downscaled_height = (*height as f32 * scale) as usize; let effective_resolution = std::cmp::min( downscaled_width * downscaled_height, original_width * original_height, @@ -51,7 +54,7 @@ fn select_best_resolution( } } - best_fit.expect("Expect a resolution to exist") + best_fit.unwrap_or((original_height, original_width)) } impl LlavaNext { @@ -89,11 +92,67 @@ pub enum Config { ClipVisionModel(ClipVisionModel), Mistral, Idefics, + Ssm, + GptBigcode, + Santacoder, + Bloom, + Mpt, + GptNeox, + Phi, + #[serde(rename = "phi-msft")] + PhiMsft, + Llama, + Baichuan, + Gemma, + Cohere, + Drbx, + Falcon, + Mixtral, + Starcoder2, + Qwen2, + Opt, + T5, } +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub struct TextConfig {} + #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub struct VisionConfig { image_size: usize, patch_size: usize, } + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_llava_next_features() { + let config = LlavaNext { + text_config: TextConfig {}, + vision_config: VisionConfig { + image_size: 336, + patch_size: 14, + }, + image_grid_pinpoints: vec![ + (336, 672), + (672, 336), + (672, 672), + (1008, 336), + (336, 1008), + ], + }; + + let slots = config.get_number_of_features(640, 640); + assert_eq!(slots, 2928); + let slots = config.get_number_of_features(480, 640); + assert_eq!(slots, 2340); + let slots = config.get_number_of_features(899, 1024); + assert_eq!(slots, 2732); + let slots = config.get_number_of_features(1024, 899); + assert_eq!(slots, 3320); + } +} diff --git a/router/src/main.rs b/router/src/main.rs index f180d65b..aace2ff9 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -218,13 +218,14 @@ async fn main() -> Result<(), RouterError> { }; let config: Option = api_repo.get("config.json").await.ok().and_then(|filename| { - tracing::info!("Config filename {filename:?}"); std::fs::read_to_string(filename) .ok() .as_ref() .and_then(|c| { let config: Result = serde_json::from_str(c); - tracing::info!("Config parse {config:?}"); + if let Err(err) = &config { + tracing::warn!("Could not parse config {err:?}"); + } config.ok() }) });