launcher: correctly get the head dimension for VLMs

For most (if not all) VLMs, the head dimension lives in the
`text_config` section of the model configuration. However, since we
only queried the top-level `head_dim` (which typically doesn't exist
for VLMs), we would never select flashinfer. This change adds a method
that reads the head dimension from the top-level `Config` struct and
falls back to `text_config` when the top-level value is absent.
Daniël de Kok 2025-03-17 10:07:39 +00:00
parent f91434e99b
commit 4727a3af67
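
As a minimal self-contained sketch of the resulting lookup order (not
part of the commit; it assumes the serde and serde_json crates, and the
example configurations are hypothetical):

    // Mirrors the `Config`/`TextConfig` fields from the diff below.
    use serde::Deserialize;

    #[derive(Debug, Deserialize)]
    struct TextConfig {
        head_dim: Option<usize>,
    }

    #[derive(Debug, Deserialize)]
    struct Config {
        head_dim: Option<usize>,
        text_config: Option<TextConfig>,
    }

    impl Config {
        fn get_head_dim(&self) -> Option<usize> {
            // Prefer the top-level `head_dim`; fall back to the value
            // nested under `text_config` (the usual layout for VLMs).
            self.head_dim.or_else(|| {
                self.text_config
                    .as_ref()
                    .and_then(|text_config| text_config.head_dim)
            })
        }
    }

    fn main() {
        // Text-only model: `head_dim` at the top level.
        let text_only: Config = serde_json::from_str(r#"{"head_dim": 128}"#).unwrap();
        assert_eq!(text_only.get_head_dim(), Some(128));

        // VLM: `head_dim` only inside `text_config`, so the fallback fires.
        let vlm: Config =
            serde_json::from_str(r#"{"text_config": {"head_dim": 128}}"#).unwrap();
        assert_eq!(vlm.get_head_dim(), Some(128));
    }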


@@ -152,7 +152,7 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
             "flashdecoding"
         };
 
-        match config.head_dim {
+        match config.get_head_dim() {
             Some(h) if h == 64 || h == 128 || h == 256 => {
                 if lora_adapters.is_some() && prefix_caching.is_none() {
                     tracing::info!("Disabling prefix caching because of lora adapters");
@@ -214,6 +214,7 @@ struct RawConfig {
     num_key_value_heads: Option<usize>,
     num_hidden_layers: Option<usize>,
     head_dim: Option<usize>,
+    text_config: Option<TextConfig>,
     vision_config: Option<VisionConfig>,
     is_encoder_decoder: Option<bool>,
     #[serde(rename = "num_experts_per_tok")]
@@ -233,6 +234,11 @@ struct QuantizationConfig {
 #[derive(Debug, Deserialize)]
 struct VisionConfig {}
 
+#[derive(Debug, Deserialize)]
+struct TextConfig {
+    head_dim: Option<usize>,
+}
+
 #[derive(Debug, Deserialize)]
 struct Config {
     max_position_embeddings: Option<usize>,
@@ -244,6 +250,7 @@ struct Config {
     intermediate_size: Option<usize>,
     hidden_size: Option<usize>,
     model_type: Option<String>,
+    text_config: Option<TextConfig>,
     vision_config: Option<VisionConfig>,
     is_encoder_decoder: bool,
     num_experts_per_token: usize,
@@ -253,6 +260,14 @@ struct Config {
 }
 
 impl Config {
+    fn get_head_dim(&self) -> Option<usize> {
+        self.head_dim.or_else(|| {
+            self.text_config
+                .as_ref()
+                .and_then(|text_config| text_config.head_dim)
+        })
+    }
+
     fn flop(&self) -> Option<u64> {
         if self.vision_config.is_some() {
             // VLM are much harder to predict and VRAM requirements
@@ -261,7 +276,7 @@ impl Config {
         }
         let num_heads = self.num_heads? as u64;
         let num_kv_heads = self.num_kv_heads? as u64;
-        let head_dim = self.head_dim? as u64;
+        let head_dim = self.get_head_dim()? as u64;
         let hidden_size = self.hidden_size? as u64;
         let intermediate_size = (self.intermediate_size?
             * (self.num_experts_per_token + self.num_shared_experts))
@@ -289,7 +304,7 @@ impl Config {
         }
         // 2 for key and values
         // 2 for f16 dtype?
-        Some(self.num_kv_heads? * 2 * self.head_dim? * 2 * self.num_layers?)
+        Some(self.num_kv_heads? * 2 * self.get_head_dim()? * 2 * self.num_layers?)
     }
 
     fn mlp_vram_per_tok(&self) -> Option<usize> {
@@ -310,8 +325,8 @@ impl Config {
     }
 
     fn model_vram(&self) -> Option<usize> {
-        let attn_vram = (self.num_heads? + 2 * self.num_kv_heads?) * self.head_dim?;
-        let o_vram = self.num_heads? * self.head_dim? * self.hidden_size?;
+        let attn_vram = (self.num_heads? + 2 * self.num_kv_heads?) * self.get_head_dim()?;
+        let o_vram = self.num_heads? * self.get_head_dim()? * self.hidden_size?;
         // gate + up + down = 3
         let mlp_vram = 3 * self.intermediate_size? * self.num_experts * self.hidden_size?;
         let layer_vram = mlp_vram + attn_vram + o_vram;
@@ -349,6 +364,7 @@ impl From<RawConfig> for Config {
         let num_kv_heads = other.num_key_value_heads.or(other.num_attention_heads);
         let intermediate_size = other.intermediate_size;
         let model_type = other.model_type;
+        let text_config = other.text_config;
        let vision_config = other.vision_config;
         let is_encoder_decoder = other.is_encoder_decoder.unwrap_or(false);
         let num_experts_per_token = other.num_experts_per_token.unwrap_or(1);
@@ -360,6 +376,7 @@ impl From<RawConfig> for Config {
             quantize,
             head_dim,
             model_type,
+            text_config,
             vision_config,
             is_encoder_decoder,
             hidden_size,
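
As a worked check of the `kv_vram_per_tok` formula above (the model
shape is hypothetical, not taken from the commit): with num_kv_heads = 8,
num_layers = 32, and a head_dim of 128 now resolved through
get_head_dim(), the KV cache costs 8 * 2 (keys and values) * 128 * 2
(bytes per f16 element) * 32 = 131072 bytes, i.e. 128 KiB per token.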