diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index d1041e26e..fde1472f0 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -152,7 +152,7 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
         "flashdecoding"
     };
 
-    match config.head_dim {
+    match config.get_head_dim() {
         Some(h) if h == 64 || h == 128 || h == 256 => {
             if lora_adapters.is_some() && prefix_caching.is_none() {
                 tracing::info!("Disabling prefix caching because of lora adapters");
@@ -214,6 +214,7 @@ struct RawConfig {
     num_key_value_heads: Option<usize>,
     num_hidden_layers: Option<usize>,
     head_dim: Option<usize>,
+    text_config: Option<TextConfig>,
     vision_config: Option<VisionConfig>,
     is_encoder_decoder: Option<bool>,
     #[serde(rename = "num_experts_per_tok")]
@@ -233,6 +234,11 @@ struct QuantizationConfig {
 #[derive(Debug, Deserialize)]
 struct VisionConfig {}
 
+#[derive(Debug, Deserialize)]
+struct TextConfig {
+    head_dim: Option<usize>,
+}
+
 #[derive(Debug, Deserialize)]
 struct Config {
     max_position_embeddings: Option<usize>,
@@ -244,6 +250,7 @@ struct Config {
     intermediate_size: Option<usize>,
     hidden_size: Option<usize>,
     model_type: Option<String>,
+    text_config: Option<TextConfig>,
     vision_config: Option<VisionConfig>,
     is_encoder_decoder: bool,
     num_experts_per_token: usize,
@@ -253,6 +260,14 @@ struct Config {
 }
 
 impl Config {
+    fn get_head_dim(&self) -> Option<usize> {
+        self.head_dim.or_else(|| {
+            self.text_config
+                .as_ref()
+                .and_then(|text_config| text_config.head_dim)
+        })
+    }
+
     fn flop(&self) -> Option<u64> {
         if self.vision_config.is_some() {
             // VLM are much harder to predict and VRAM requirements
@@ -261,7 +276,7 @@ impl Config {
         }
         let num_heads = self.num_heads? as u64;
         let num_kv_heads = self.num_kv_heads? as u64;
-        let head_dim = self.head_dim? as u64;
+        let head_dim = self.get_head_dim()? as u64;
         let hidden_size = self.hidden_size? as u64;
         let intermediate_size = (self.intermediate_size?
             * (self.num_experts_per_token + self.num_shared_experts))
@@ -289,7 +304,7 @@ impl Config {
         }
         // 2 for key and values
         // 2 for f16 dtype?
-        Some(self.num_kv_heads? * 2 * self.head_dim? * 2 * self.num_layers?)
+        Some(self.num_kv_heads? * 2 * self.get_head_dim()? * 2 * self.num_layers?)
     }
 
     fn mlp_vram_per_tok(&self) -> Option<usize> {
@@ -310,8 +325,8 @@ impl Config {
     }
 
     fn model_vram(&self) -> Option<usize> {
-        let attn_vram = (self.num_heads? + 2 * self.num_kv_heads?) * self.head_dim?;
-        let o_vram = self.num_heads? * self.head_dim? * self.hidden_size?;
+        let attn_vram = (self.num_heads? + 2 * self.num_kv_heads?) * self.get_head_dim()?;
+        let o_vram = self.num_heads? * self.get_head_dim()? * self.hidden_size?;
         // gate + up + down = 3
         let mlp_vram = 3 * self.intermediate_size? * self.num_experts * self.hidden_size?;
         let layer_vram = mlp_vram + attn_vram + o_vram;
@@ -349,6 +364,7 @@ impl From<RawConfig> for Config {
         let num_kv_heads = other.num_key_value_heads.or(other.num_attention_heads);
         let intermediate_size = other.intermediate_size;
         let model_type = other.model_type;
+        let text_config = other.text_config;
         let vision_config = other.vision_config;
         let is_encoder_decoder = other.is_encoder_decoder.unwrap_or(false);
         let num_experts_per_token = other.num_experts_per_token.unwrap_or(1);
@@ -360,6 +376,7 @@ impl From<RawConfig> for Config {
             quantize,
             head_dim,
             model_type,
+            text_config,
             vision_config,
             is_encoder_decoder,
             hidden_size,
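
For context, a minimal standalone sketch of the fallback that `get_head_dim()` introduces: prefer a top-level `head_dim`, otherwise read it from the nested `text_config` (as multimodal configs expose it). The structs below are simplified stand-ins for the launcher's `Config`/`TextConfig`, not the actual types from the diff.

```rust
// Simplified, assumed types illustrating the head_dim lookup order.
struct TextConfig {
    head_dim: Option<usize>,
}

struct Config {
    head_dim: Option<usize>,
    text_config: Option<TextConfig>,
}

impl Config {
    // Top-level head_dim wins; fall back to text_config.head_dim if present.
    fn get_head_dim(&self) -> Option<usize> {
        self.head_dim
            .or_else(|| self.text_config.as_ref().and_then(|tc| tc.head_dim))
    }
}

fn main() {
    // Text-only model: head_dim lives at the top level of the config.
    let llm = Config { head_dim: Some(128), text_config: None };
    assert_eq!(llm.get_head_dim(), Some(128));

    // VLM-style config: head_dim only appears under text_config.
    let vlm = Config {
        head_dim: None,
        text_config: Some(TextConfig { head_dim: Some(64) }),
    };
    assert_eq!(vlm.get_head_dim(), Some(64));

    // Neither present: callers such as resolve_attention() hit their fallback path.
    let unknown = Config { head_dim: None, text_config: None };
    assert_eq!(unknown.get_head_dim(), None);
}
```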