mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
Fixing select_best_resolution.
This commit is contained in:
parent
61821f410a
commit
8c114e5fc4
@ -4,7 +4,7 @@ use serde::{Deserialize, Serialize};
|
||||
#[serde(tag = "model_type")]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub struct LlavaNext {
|
||||
text_config: Box<Config>,
|
||||
text_config: TextConfig,
|
||||
vision_config: VisionConfig,
|
||||
image_grid_pinpoints: Vec<(usize, usize)>,
|
||||
}
|
||||
@ -32,9 +32,12 @@ fn select_best_resolution(
|
||||
let mut min_wasted_resolution = f32::NEG_INFINITY;
|
||||
|
||||
for (height, width) in possible_resolutions {
|
||||
// let scale = std::cmp::min(width / original_width, height / original_height);
|
||||
let downscaled_width = width / original_width * original_width;
|
||||
let downscaled_height = height / original_height * original_height;
|
||||
let wscale = *width as f32 / original_width as f32;
|
||||
let hscale = *height as f32 / original_height as f32;
|
||||
// f32 partial ord.
|
||||
let scale = if wscale > hscale { hscale } else { wscale };
|
||||
let downscaled_width = (*width as f32 * scale) as usize;
|
||||
let downscaled_height = (*height as f32 * scale) as usize;
|
||||
let effective_resolution = std::cmp::min(
|
||||
downscaled_width * downscaled_height,
|
||||
original_width * original_height,
|
||||
@ -51,7 +54,7 @@ fn select_best_resolution(
|
||||
}
|
||||
}
|
||||
|
||||
best_fit.expect("Expect a resolution to exist")
|
||||
best_fit.unwrap_or((original_height, original_width))
|
||||
}
|
||||
|
||||
impl LlavaNext {
|
||||
@ -89,11 +92,67 @@ pub enum Config {
|
||||
ClipVisionModel(ClipVisionModel),
|
||||
Mistral,
|
||||
Idefics,
|
||||
Ssm,
|
||||
GptBigcode,
|
||||
Santacoder,
|
||||
Bloom,
|
||||
Mpt,
|
||||
GptNeox,
|
||||
Phi,
|
||||
#[serde(rename = "phi-msft")]
|
||||
PhiMsft,
|
||||
Llama,
|
||||
Baichuan,
|
||||
Gemma,
|
||||
Cohere,
|
||||
Drbx,
|
||||
Falcon,
|
||||
Mixtral,
|
||||
Starcoder2,
|
||||
Qwen2,
|
||||
Opt,
|
||||
T5,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub struct TextConfig {}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub struct VisionConfig {
|
||||
image_size: usize,
|
||||
patch_size: usize,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_llava_next_features() {
|
||||
let config = LlavaNext {
|
||||
text_config: TextConfig {},
|
||||
vision_config: VisionConfig {
|
||||
image_size: 336,
|
||||
patch_size: 14,
|
||||
},
|
||||
image_grid_pinpoints: vec![
|
||||
(336, 672),
|
||||
(672, 336),
|
||||
(672, 672),
|
||||
(1008, 336),
|
||||
(336, 1008),
|
||||
],
|
||||
};
|
||||
|
||||
let slots = config.get_number_of_features(640, 640);
|
||||
assert_eq!(slots, 2928);
|
||||
let slots = config.get_number_of_features(480, 640);
|
||||
assert_eq!(slots, 2340);
|
||||
let slots = config.get_number_of_features(899, 1024);
|
||||
assert_eq!(slots, 2732);
|
||||
let slots = config.get_number_of_features(1024, 899);
|
||||
assert_eq!(slots, 3320);
|
||||
}
|
||||
}
|
||||
|
@ -218,13 +218,14 @@ async fn main() -> Result<(), RouterError> {
|
||||
};
|
||||
|
||||
let config: Option<Config> = api_repo.get("config.json").await.ok().and_then(|filename| {
|
||||
tracing::info!("Config filename {filename:?}");
|
||||
std::fs::read_to_string(filename)
|
||||
.ok()
|
||||
.as_ref()
|
||||
.and_then(|c| {
|
||||
let config: Result<Config, _> = serde_json::from_str(c);
|
||||
tracing::info!("Config parse {config:?}");
|
||||
if let Err(err) = &config {
|
||||
tracing::warn!("Could not parse config {err:?}");
|
||||
}
|
||||
config.ok()
|
||||
})
|
||||
});
|
||||
|
Loading…
Reference in New Issue
Block a user