mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
Fixing select_best_resolution.
This commit is contained in:
parent
61821f410a
commit
8c114e5fc4
@ -4,7 +4,7 @@ use serde::{Deserialize, Serialize};
|
|||||||
#[serde(tag = "model_type")]
|
#[serde(tag = "model_type")]
|
||||||
#[serde(rename_all = "snake_case")]
|
#[serde(rename_all = "snake_case")]
|
||||||
pub struct LlavaNext {
|
pub struct LlavaNext {
|
||||||
text_config: Box<Config>,
|
text_config: TextConfig,
|
||||||
vision_config: VisionConfig,
|
vision_config: VisionConfig,
|
||||||
image_grid_pinpoints: Vec<(usize, usize)>,
|
image_grid_pinpoints: Vec<(usize, usize)>,
|
||||||
}
|
}
|
||||||
@ -32,9 +32,12 @@ fn select_best_resolution(
|
|||||||
let mut min_wasted_resolution = f32::NEG_INFINITY;
|
let mut min_wasted_resolution = f32::NEG_INFINITY;
|
||||||
|
|
||||||
for (height, width) in possible_resolutions {
|
for (height, width) in possible_resolutions {
|
||||||
// let scale = std::cmp::min(width / original_width, height / original_height);
|
let wscale = *width as f32 / original_width as f32;
|
||||||
let downscaled_width = width / original_width * original_width;
|
let hscale = *height as f32 / original_height as f32;
|
||||||
let downscaled_height = height / original_height * original_height;
|
// f32 partial ord.
|
||||||
|
let scale = if wscale > hscale { hscale } else { wscale };
|
||||||
|
let downscaled_width = (*width as f32 * scale) as usize;
|
||||||
|
let downscaled_height = (*height as f32 * scale) as usize;
|
||||||
let effective_resolution = std::cmp::min(
|
let effective_resolution = std::cmp::min(
|
||||||
downscaled_width * downscaled_height,
|
downscaled_width * downscaled_height,
|
||||||
original_width * original_height,
|
original_width * original_height,
|
||||||
@ -51,7 +54,7 @@ fn select_best_resolution(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
best_fit.expect("Expect a resolution to exist")
|
best_fit.unwrap_or((original_height, original_width))
|
||||||
}
|
}
|
||||||
|
|
||||||
impl LlavaNext {
|
impl LlavaNext {
|
||||||
@ -89,11 +92,67 @@ pub enum Config {
|
|||||||
ClipVisionModel(ClipVisionModel),
|
ClipVisionModel(ClipVisionModel),
|
||||||
Mistral,
|
Mistral,
|
||||||
Idefics,
|
Idefics,
|
||||||
|
Ssm,
|
||||||
|
GptBigcode,
|
||||||
|
Santacoder,
|
||||||
|
Bloom,
|
||||||
|
Mpt,
|
||||||
|
GptNeox,
|
||||||
|
Phi,
|
||||||
|
#[serde(rename = "phi-msft")]
|
||||||
|
PhiMsft,
|
||||||
|
Llama,
|
||||||
|
Baichuan,
|
||||||
|
Gemma,
|
||||||
|
Cohere,
|
||||||
|
Drbx,
|
||||||
|
Falcon,
|
||||||
|
Mixtral,
|
||||||
|
Starcoder2,
|
||||||
|
Qwen2,
|
||||||
|
Opt,
|
||||||
|
T5,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "snake_case")]
|
||||||
|
pub struct TextConfig {}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
#[serde(rename_all = "snake_case")]
|
#[serde(rename_all = "snake_case")]
|
||||||
pub struct VisionConfig {
|
pub struct VisionConfig {
|
||||||
image_size: usize,
|
image_size: usize,
|
||||||
patch_size: usize,
|
patch_size: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod test {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_llava_next_features() {
|
||||||
|
let config = LlavaNext {
|
||||||
|
text_config: TextConfig {},
|
||||||
|
vision_config: VisionConfig {
|
||||||
|
image_size: 336,
|
||||||
|
patch_size: 14,
|
||||||
|
},
|
||||||
|
image_grid_pinpoints: vec![
|
||||||
|
(336, 672),
|
||||||
|
(672, 336),
|
||||||
|
(672, 672),
|
||||||
|
(1008, 336),
|
||||||
|
(336, 1008),
|
||||||
|
],
|
||||||
|
};
|
||||||
|
|
||||||
|
let slots = config.get_number_of_features(640, 640);
|
||||||
|
assert_eq!(slots, 2928);
|
||||||
|
let slots = config.get_number_of_features(480, 640);
|
||||||
|
assert_eq!(slots, 2340);
|
||||||
|
let slots = config.get_number_of_features(899, 1024);
|
||||||
|
assert_eq!(slots, 2732);
|
||||||
|
let slots = config.get_number_of_features(1024, 899);
|
||||||
|
assert_eq!(slots, 3320);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -218,13 +218,14 @@ async fn main() -> Result<(), RouterError> {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let config: Option<Config> = api_repo.get("config.json").await.ok().and_then(|filename| {
|
let config: Option<Config> = api_repo.get("config.json").await.ok().and_then(|filename| {
|
||||||
tracing::info!("Config filename {filename:?}");
|
|
||||||
std::fs::read_to_string(filename)
|
std::fs::read_to_string(filename)
|
||||||
.ok()
|
.ok()
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.and_then(|c| {
|
.and_then(|c| {
|
||||||
let config: Result<Config, _> = serde_json::from_str(c);
|
let config: Result<Config, _> = serde_json::from_str(c);
|
||||||
tracing::info!("Config parse {config:?}");
|
if let Err(err) = &config {
|
||||||
|
tracing::warn!("Could not parse config {err:?}");
|
||||||
|
}
|
||||||
config.ok()
|
config.ok()
|
||||||
})
|
})
|
||||||
});
|
});
|
||||||
|
Loading…
Reference in New Issue
Block a user