Fixing select_best_resolution.

This commit is contained in:
Nicolas Patry 2024-04-09 17:16:15 +00:00
parent 61821f410a
commit 8c114e5fc4
2 changed files with 67 additions and 7 deletions

View File

@ -4,7 +4,7 @@ use serde::{Deserialize, Serialize};
#[serde(tag = "model_type")]
#[serde(rename_all = "snake_case")]
pub struct LlavaNext {
text_config: Box<Config>,
text_config: TextConfig,
vision_config: VisionConfig,
image_grid_pinpoints: Vec<(usize, usize)>,
}
@ -32,9 +32,12 @@ fn select_best_resolution(
let mut min_wasted_resolution = f32::NEG_INFINITY;
for (height, width) in possible_resolutions {
// let scale = std::cmp::min(width / original_width, height / original_height);
let downscaled_width = width / original_width * original_width;
let downscaled_height = height / original_height * original_height;
let wscale = *width as f32 / original_width as f32;
let hscale = *height as f32 / original_height as f32;
// f32 partial ord.
let scale = if wscale > hscale { hscale } else { wscale };
let downscaled_width = (*width as f32 * scale) as usize;
let downscaled_height = (*height as f32 * scale) as usize;
let effective_resolution = std::cmp::min(
downscaled_width * downscaled_height,
original_width * original_height,
@ -51,7 +54,7 @@ fn select_best_resolution(
}
}
best_fit.expect("Expect a resolution to exist")
best_fit.unwrap_or((original_height, original_width))
}
impl LlavaNext {
@ -89,11 +92,67 @@ pub enum Config {
ClipVisionModel(ClipVisionModel),
Mistral,
Idefics,
Ssm,
GptBigcode,
Santacoder,
Bloom,
Mpt,
GptNeox,
Phi,
#[serde(rename = "phi-msft")]
PhiMsft,
Llama,
Baichuan,
Gemma,
Cohere,
Drbx,
Falcon,
Mixtral,
Starcoder2,
Qwen2,
Opt,
T5,
}
/// Type of the `text_config` field of [`LlavaNext`].
///
/// The struct declares no fields, so it stores none of the text model's
/// settings.
// NOTE(review): presumably a placeholder until text-model settings are needed
// here — confirm.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct TextConfig {}
/// Type of the `vision_config` field of [`LlavaNext`].
///
/// Holds the two vision-side sizes declared below; both are plain `usize`
/// values.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct VisionConfig {
/// NOTE(review): presumably the side length, in pixels, of the image fed to
/// the vision model (the test in this file uses 336) — confirm.
image_size: usize,
/// NOTE(review): presumably the side length, in pixels, of one image patch
/// (the test in this file uses 14) — confirm.
patch_size: usize,
}
#[cfg(test)]
mod test {
    use super::*;

    /// Pins the value returned by `LlavaNext::get_number_of_features` for a
    /// handful of input sizes against fixed expected counts.
    #[test]
    fn test_llava_next_features() {
        let vision_config = VisionConfig {
            image_size: 336,
            patch_size: 14,
        };
        let image_grid_pinpoints = vec![
            (336, 672),
            (672, 336),
            (672, 672),
            (1008, 336),
            (336, 1008),
        ];
        let config = LlavaNext {
            text_config: TextConfig {},
            vision_config,
            image_grid_pinpoints,
        };

        // Each row: the two size arguments in call order, then the expected count.
        let cases = [
            (640, 640, 2928),
            (480, 640, 2340),
            (899, 1024, 2732),
            (1024, 899, 3320),
        ];
        for (first, second, expected) in cases {
            let slots = config.get_number_of_features(first, second);
            assert_eq!(slots, expected);
        }
    }
}

View File

@ -218,13 +218,14 @@ async fn main() -> Result<(), RouterError> {
};
let config: Option<Config> = api_repo.get("config.json").await.ok().and_then(|filename| {
tracing::info!("Config filename {filename:?}");
std::fs::read_to_string(filename)
.ok()
.as_ref()
.and_then(|c| {
let config: Result<Config, _> = serde_json::from_str(c);
tracing::info!("Config parse {config:?}");
if let Err(err) = &config {
tracing::warn!("Could not parse config {err:?}");
}
config.ok()
})
});