Small modifications.

This commit is contained in:
Nicolas Patry 2025-01-29 22:02:53 +01:00
parent f190bc1d7a
commit ee9178fb8b
No known key found for this signature in database
GPG Key ID: D2920555C90F704C
3 changed files with 11 additions and 2 deletions

View File

@ -1635,6 +1635,7 @@ enum Gpu {
A40, A40,
H100, H100,
A100, A100,
H200,
Unknown(String), Unknown(String),
} }
@ -1661,6 +1662,7 @@ impl From<&str> for Gpu {
"nvidia-a100-sxm4-40gb" => Gpu::A100, "nvidia-a100-sxm4-40gb" => Gpu::A100,
"nvidia-a100-80gb-pcie" => Gpu::A100, "nvidia-a100-80gb-pcie" => Gpu::A100,
"nvidia-a100" => Gpu::A100, "nvidia-a100" => Gpu::A100,
"nvidia-h200" => Gpu::H200,
card => Gpu::Unknown(card.to_string()), card => Gpu::Unknown(card.to_string()),
} }
} }
@ -1678,6 +1680,7 @@ impl std::fmt::Display for Gpu {
Gpu::A40 => write!(f, "nvidia-a40"), Gpu::A40 => write!(f, "nvidia-a40"),
Gpu::H100 => write!(f, "nvidia-h100-80fb-hbm3"), Gpu::H100 => write!(f, "nvidia-h100-80fb-hbm3"),
Gpu::A100 => write!(f, "nvida-a100-sxm4-80gb"), Gpu::A100 => write!(f, "nvida-a100-sxm4-80gb"),
Gpu::H200 => write!(f, "nvida-h200"),
Gpu::Unknown(card) => write!(f, "{}", card), Gpu::Unknown(card) => write!(f, "{}", card),
} }
} }
@ -1702,11 +1705,13 @@ impl ComputeType {
// https://www.nvidia.com/en-us/data-center/a40/ // https://www.nvidia.com/en-us/data-center/a40/
// https://images.nvidia.com/content/Solutions/data-center/a40/nvidia-a40-datasheet.pdf // https://images.nvidia.com/content/Solutions/data-center/a40/nvidia-a40-datasheet.pdf
Gpu::A40 => Some(149 * 10u64.pow(12)), Gpu::A40 => Some(149 * 10u64.pow(12)),
// https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf
Gpu::A100 => Some(312 * 10u64.pow(12)),
// https://www.nvidia.com/en-us/data-center/h100/ // https://www.nvidia.com/en-us/data-center/h100/
// https://www.techpowerup.com/gpu-specs/docs/nvidia-gh100-architecture.pdf // https://www.techpowerup.com/gpu-specs/docs/nvidia-gh100-architecture.pdf
Gpu::H100 => Some(900 * 10u64.pow(12)), Gpu::H100 => Some(900 * 10u64.pow(12)),
// https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf // https://www.nvidia.com/en-us/data-center/h200/
Gpu::A100 => Some(312 * 10u64.pow(12)), Gpu::H200 => Some(989 * 10u64.pow(12)),
Gpu::Unknown(card) => { Gpu::Unknown(card) => {
tracing::warn!("Unkown compute for card {card}"); tracing::warn!("Unkown compute for card {card}");
None None

View File

@ -224,6 +224,8 @@ pub enum Config {
Qwen2, Qwen2,
Opt, Opt,
T5, T5,
DeepseekV2,
DeepseekV3,
} }
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]

View File

@ -81,12 +81,14 @@ def initialize_torch_distributed():
pg_options=options, pg_options=options,
) )
else: else:
device = torch.device(f"cuda:{RANK}")
torch.distributed.init_process_group( torch.distributed.init_process_group(
backend=backend, backend=backend,
world_size=WORLD_SIZE, world_size=WORLD_SIZE,
rank=RANK, rank=RANK,
timeout=timedelta(seconds=120), timeout=timedelta(seconds=120),
pg_options=options, pg_options=options,
device_id=device,
) )
else: else:
logger.warning("torch.distributed is already initialized.") logger.warning("torch.distributed is already initialized.")