Adding A100 compute. (#2806)

This commit is contained in:
Nicolas Patry 2024-12-06 22:49:15 +05:30 committed by GitHub
parent 5df8059037
commit d96dcb1797
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -172,7 +172,9 @@ struct RawConfig {
vision_config: Option<VisionConfig>, vision_config: Option<VisionConfig>,
is_encoder_decoder: Option<bool>, is_encoder_decoder: Option<bool>,
#[serde(rename = "num_experts_per_tok")] #[serde(rename = "num_experts_per_tok")]
experts: Option<usize>, num_experts_per_token: Option<usize>,
#[serde(rename = "n_shared_experts")]
num_shared_experts: Option<usize>,
} }
#[derive(Deserialize)] #[derive(Deserialize)]
@ -196,7 +198,8 @@ struct Config {
model_type: Option<String>, model_type: Option<String>,
vision_config: Option<VisionConfig>, vision_config: Option<VisionConfig>,
is_encoder_decoder: bool, is_encoder_decoder: bool,
experts: Option<usize>, num_experts_per_token: usize,
num_shared_experts: usize,
} }
impl Config { impl Config {
@ -210,11 +213,9 @@ impl Config {
let num_kv_heads = self.num_kv_heads? as u64; let num_kv_heads = self.num_kv_heads? as u64;
let head_dim = self.head_dim? as u64; let head_dim = self.head_dim? as u64;
let hidden_size = self.hidden_size? as u64; let hidden_size = self.hidden_size? as u64;
let intermediate_size = if let Some(experts) = self.experts { let intermediate_size = (self.intermediate_size?
(self.intermediate_size? * experts) as u64 * (self.num_experts_per_token + self.num_shared_experts))
} else { as u64;
self.intermediate_size? as u64
};
let num_layers = self.num_layers? as u64; let num_layers = self.num_layers? as u64;
let q_flops = 2 * num_heads * head_dim * hidden_size; let q_flops = 2 * num_heads * head_dim * hidden_size;
@ -257,7 +258,8 @@ impl From<RawConfig> for Config {
let model_type = other.model_type; let model_type = other.model_type;
let vision_config = other.vision_config; let vision_config = other.vision_config;
let is_encoder_decoder = other.is_encoder_decoder.unwrap_or(false); let is_encoder_decoder = other.is_encoder_decoder.unwrap_or(false);
let experts = other.experts; let num_experts_per_token = other.num_experts_per_token.unwrap_or(1);
let num_shared_experts = other.num_shared_experts.unwrap_or(0);
Config { Config {
max_position_embeddings, max_position_embeddings,
quantize, quantize,
@ -270,7 +272,8 @@ impl From<RawConfig> for Config {
num_kv_heads, num_kv_heads,
intermediate_size, intermediate_size,
num_layers, num_layers,
experts, num_experts_per_token,
num_shared_experts,
} }
} }
} }
@ -1547,6 +1550,7 @@ impl ComputeType {
// https://www.techpowerup.com/gpu-specs/docs/nvidia-gh100-architecture.pdf // https://www.techpowerup.com/gpu-specs/docs/nvidia-gh100-architecture.pdf
"nvidia-h100-80gb-hbm3" => Some(900 * 10u64.pow(12)), "nvidia-h100-80gb-hbm3" => Some(900 * 10u64.pow(12)),
// https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf // https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf
"nvidia-a100-sxm4-80gb" => Some(312 * 10u64.pow(12)),
"nvidia-a100" => Some(312 * 10u64.pow(12)), "nvidia-a100" => Some(312 * 10u64.pow(12)),
card => { card => {
tracing::warn!("Unkown compute for card {card}"); tracing::warn!("Unkown compute for card {card}");