fix: limit vision flop calc to qwen2 vl models and update config typing

This commit is contained in:
drbh 2025-01-22 18:30:03 +00:00
parent d12e075966
commit a0ab962b6d
2 changed files with 56 additions and 38 deletions

View File

@ -231,12 +231,12 @@ struct QuantizationConfig {
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
struct VisionConfig { struct VisionConfig {
depth: usize, depth: Option<usize>,
embed_dim: usize, embed_dim: Option<usize>,
mlp_ratio: usize, mlp_ratio: Option<usize>,
in_chans: usize, in_chans: Option<usize>,
patch_size: usize, patch_size: Option<usize>,
temporal_patch_size: usize, temporal_patch_size: Option<usize>,
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
@ -283,17 +283,27 @@ impl Config {
tracing::debug!("Text flops: {}", human_size(text_flops as usize, "flop")); tracing::debug!("Text flops: {}", human_size(text_flops as usize, "flop"));
if let Some(vision_config) = self.vision_config.as_ref() { // text-only case
let in_chans = vision_config.in_chans as u64; if self.vision_config.is_none() {
let patch_size = vision_config.patch_size as u64; return Some(text_flops);
let embed_dim = vision_config.embed_dim as u64; }
let vision_depth = vision_config.depth as u64;
let mlp_ratio = vision_config.mlp_ratio as u64; let vision_config = self.vision_config.as_ref().unwrap();
let temporal_patch_size = vision_config.temporal_patch_size as u64;
// estimate vision flops for specific model types
match self.model_type.as_deref() {
Some("qwen2_vl") => {
let in_chans = vision_config.in_chans? as u64;
let patch_size = vision_config.patch_size? as u64;
let embed_dim = vision_config.embed_dim? as u64;
let vision_depth = vision_config.depth? as u64;
let mlp_ratio = vision_config.mlp_ratio? as u64;
let temporal_patch_size = vision_config.temporal_patch_size? as u64;
// 1. patch embedding: // 1. patch embedding:
// - conv3d operation: (t*h*w) * (k_t*k_h*k_w) * c_in * c_out * 2 // - conv3d operation: (t*h*w) * (k_t*k_h*k_w) * c_in * c_out * 2
// where the 2 accounts for multiply-add // where the 2 accounts for multiply-add
let patch_flops = 2 * temporal_patch_size * patch_size.pow(2) * embed_dim * in_chans; let patch_flops =
2 * temporal_patch_size * patch_size.pow(2) * embed_dim * in_chans;
// 2. self-attention + mlp: // 2. self-attention + mlp:
// - qkv projections: 3 * d_model * d_model * 2 // - qkv projections: 3 * d_model * d_model * 2
// - attention: d_model * d_model * 2 // - attention: d_model * d_model * 2
@ -308,8 +318,10 @@ impl Config {
human_size(vision_flops as usize, "flop") human_size(vision_flops as usize, "flop")
); );
Some(text_flops + vision_flops) Some(text_flops + vision_flops)
} else { }
Some(text_flops) // model has a vision config but is not supported for flops calculation
// we return None to avoid overestimating the memory requirements
_ => return None,
} }
} }

View File

@ -86,15 +86,21 @@ class PositionRotaryEmbedding(nn.Module):
# `rope_type` is now standard in transformers, but some existing models # `rope_type` is now standard in transformers, but some existing models
# have `type` instead. # have `type` instead.
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))
mrope_section = rope_scaling.get("mrope_section", None)
if rope_type == "linear": # only apply mrope if sections are provided and the rope type is mrope or default
pass if mrope_section is not None and (
elif rope_type == "default": rope_type == "mrope" or rope_type == "default"
if rope_scaling.get("mrope_section", False): ):
mrope_section = rope_scaling.get("mrope_section") mrope_section = rope_scaling.get("mrope_section")
return RotaryPositionEmbeddingMultimodalSections( return RotaryPositionEmbeddingMultimodalSections(
inv_freq, scaling_factor, mrope_section inv_freq, scaling_factor, mrope_section
) )
if rope_type == "linear":
pass
elif rope_type == "default":
pass
elif rope_type == "dynamic": elif rope_type == "dynamic":
scaling_factor = rope_scaling["factor"] scaling_factor = rope_scaling["factor"]
return DynamicPositionRotaryEmbedding( return DynamicPositionRotaryEmbedding(