diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 9001a4d5b..671ec2ee5 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -172,7 +172,9 @@ struct RawConfig { vision_config: Option, is_encoder_decoder: Option, #[serde(rename = "num_experts_per_tok")] - experts: Option, + num_experts_per_token: Option, + #[serde(rename = "n_shared_experts")] + num_shared_experts: Option, } #[derive(Deserialize)] @@ -196,7 +198,8 @@ struct Config { model_type: Option, vision_config: Option, is_encoder_decoder: bool, - experts: Option, + num_experts_per_token: usize, + num_shared_experts: usize, } impl Config { @@ -210,11 +213,9 @@ impl Config { let num_kv_heads = self.num_kv_heads? as u64; let head_dim = self.head_dim? as u64; let hidden_size = self.hidden_size? as u64; - let intermediate_size = if let Some(experts) = self.experts { - (self.intermediate_size? * experts) as u64 - } else { - self.intermediate_size? as u64 - }; + let intermediate_size = (self.intermediate_size? + * (self.num_experts_per_token + self.num_shared_experts)) + as u64; let num_layers = self.num_layers? as u64; let q_flops = 2 * num_heads * head_dim * hidden_size; @@ -257,7 +258,8 @@ impl From for Config { let model_type = other.model_type; let vision_config = other.vision_config; let is_encoder_decoder = other.is_encoder_decoder.unwrap_or(false); - let experts = other.experts; + let num_experts_per_token = other.num_experts_per_token.unwrap_or(1); + let num_shared_experts = other.num_shared_experts.unwrap_or(0); Config { max_position_embeddings, quantize, @@ -270,7 +272,8 @@ impl From for Config { num_kv_heads, intermediate_size, num_layers, - experts, + num_experts_per_token, + num_shared_experts, } } } @@ -1547,6 +1550,7 @@ impl ComputeType { // https://www.techpowerup.com/gpu-specs/docs/nvidia-gh100-architecture.pdf "nvidia-h100-80gb-hbm3" => Some(900 * 10u64.pow(12)), // https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf + "nvidia-a100-sxm4-80gb" => Some(312 * 10u64.pow(12)), "nvidia-a100" => Some(312 * 10u64.pow(12)), card => { tracing::warn!("Unkown compute for card {card}");