fix: limit vision flop calc to qwen2 vl models and update config typing

This commit is contained in:
drbh 2025-01-22 18:30:03 +00:00
parent d12e075966
commit a0ab962b6d
2 changed files with 56 additions and 38 deletions

View File

@ -231,12 +231,12 @@ struct QuantizationConfig {
#[derive(Debug, Deserialize)]
struct VisionConfig {
depth: usize,
embed_dim: usize,
mlp_ratio: usize,
in_chans: usize,
patch_size: usize,
temporal_patch_size: usize,
depth: Option<usize>,
embed_dim: Option<usize>,
mlp_ratio: Option<usize>,
in_chans: Option<usize>,
patch_size: Option<usize>,
temporal_patch_size: Option<usize>,
}
#[derive(Debug, Deserialize)]
@ -283,33 +283,45 @@ impl Config {
tracing::debug!("Text flops: {}", human_size(text_flops as usize, "flop"));
if let Some(vision_config) = self.vision_config.as_ref() {
let in_chans = vision_config.in_chans as u64;
let patch_size = vision_config.patch_size as u64;
let embed_dim = vision_config.embed_dim as u64;
let vision_depth = vision_config.depth as u64;
let mlp_ratio = vision_config.mlp_ratio as u64;
let temporal_patch_size = vision_config.temporal_patch_size as u64;
// 1. patch embedding:
// - conv3d operation: (t*h*w) * (k_t*k_h*k_w) * c_in * c_out * 2
// where the 2 accounts for multiply-add
let patch_flops = 2 * temporal_patch_size * patch_size.pow(2) * embed_dim * in_chans;
// 2. self-attention + mlp:
// - qkv projections: 3 * d_model * d_model * 2
// - attention: d_model * d_model * 2
// - mlp: 2 * d_model * (mlp_ratio * d_model) * 2
// simplified to: 2 * d_model * (4 + mlp_ratio * d_model)
let attn_flops = 2 * embed_dim * (4 + mlp_ratio * embed_dim);
// 3. add with layer norm flops for total vision layer flops
let layer_flops = patch_flops + attn_flops + 2 * embed_dim;
let vision_flops = layer_flops * vision_depth;
tracing::debug!(
"Vision flops: {}",
human_size(vision_flops as usize, "flop")
);
Some(text_flops + vision_flops)
} else {
Some(text_flops)
// text-only case
if self.vision_config.is_none() {
return Some(text_flops);
}
let vision_config = self.vision_config.as_ref().unwrap();
// estimate vision flops for specific model types
match self.model_type.as_deref() {
Some("qwen2_vl") => {
let in_chans = vision_config.in_chans? as u64;
let patch_size = vision_config.patch_size? as u64;
let embed_dim = vision_config.embed_dim? as u64;
let vision_depth = vision_config.depth? as u64;
let mlp_ratio = vision_config.mlp_ratio? as u64;
let temporal_patch_size = vision_config.temporal_patch_size? as u64;
// 1. patch embedding:
// - conv3d operation: (t*h*w) * (k_t*k_h*k_w) * c_in * c_out * 2
// where the 2 accounts for multiply-add
let patch_flops =
2 * temporal_patch_size * patch_size.pow(2) * embed_dim * in_chans;
// 2. self-attention + mlp:
// - qkv projections: 3 * d_model * d_model * 2
// - attention: d_model * d_model * 2
// - mlp: 2 * d_model * (mlp_ratio * d_model) * 2
// simplified to: 2 * d_model * (4 + mlp_ratio * d_model)
let attn_flops = 2 * embed_dim * (4 + mlp_ratio * embed_dim);
// 3. add with layer norm flops for total vision layer flops
let layer_flops = patch_flops + attn_flops + 2 * embed_dim;
let vision_flops = layer_flops * vision_depth;
tracing::debug!(
"Vision flops: {}",
human_size(vision_flops as usize, "flop")
);
Some(text_flops + vision_flops)
}
// model has a vision config but is not supported for flops calculation
// we return None to avoid overestimating the memory requirements
_ => return None,
}
}

View File

@ -86,15 +86,21 @@ class PositionRotaryEmbedding(nn.Module):
# `rope_type` is now standard in transformers, but some existing models
# have `type` instead.
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))
mrope_section = rope_scaling.get("mrope_section", None)
# only apply mrope if sections are provided and the rope type is mrope or default
if mrope_section is not None and (
rope_type == "mrope" or rope_type == "default"
):
mrope_section = rope_scaling.get("mrope_section")
return RotaryPositionEmbeddingMultimodalSections(
inv_freq, scaling_factor, mrope_section
)
if rope_type == "linear":
pass
elif rope_type == "default":
if rope_scaling.get("mrope_section", False):
mrope_section = rope_scaling.get("mrope_section")
return RotaryPositionEmbeddingMultimodalSections(
inv_freq, scaling_factor, mrope_section
)
pass
elif rope_type == "dynamic":
scaling_factor = rope_scaling["factor"]
return DynamicPositionRotaryEmbedding(