From ee9178fb8b156ef81fc9831ee08e19c206f84ff2 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Wed, 29 Jan 2025 22:02:53 +0100
Subject: [PATCH] Small modifications.

---
Notes (text in this section is not part of the commit message):
  * launcher: adds a `Gpu::H200` variant, maps the string "nvidia-h200" to
    it, displays it as "nvidia-h200", and gives it a 989 * 10^12 entry in
    the `ComputeType` match. The A100 arm is moved above the H100 arm with
    its value unchanged.
  * router: adds `DeepseekV2` and `DeepseekV3` variants to `Config`.
  * server: builds `torch.device(f"cuda:{RANK}")` and passes it as
    `device_id` to `torch.distributed.init_process_group` in the `else`
    branch shown in the hunk.
    NOTE(review): this assumes RANK is a valid local CUDA index - confirm
    against the part of the function outside this hunk.

 launcher/src/main.rs                        | 9 +++++++--
 router/src/config.rs                        | 2 ++
 server/text_generation_server/utils/dist.py | 2 ++
 3 files changed, 11 insertions(+), 2 deletions(-)

diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index c6f6b6e9..05ed0202 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -1635,6 +1635,7 @@ enum Gpu {
     A40,
     H100,
     A100,
+    H200,
     Unknown(String),
 }
 
@@ -1661,6 +1662,7 @@ impl From<&str> for Gpu {
             "nvidia-a100-sxm4-40gb" => Gpu::A100,
             "nvidia-a100-80gb-pcie" => Gpu::A100,
             "nvidia-a100" => Gpu::A100,
+            "nvidia-h200" => Gpu::H200,
             card => Gpu::Unknown(card.to_string()),
         }
     }
@@ -1678,6 +1680,7 @@ impl std::fmt::Display for Gpu {
             Gpu::A40 => write!(f, "nvidia-a40"),
             Gpu::H100 => write!(f, "nvidia-h100-80fb-hbm3"),
             Gpu::A100 => write!(f, "nvida-a100-sxm4-80gb"),
+            Gpu::H200 => write!(f, "nvidia-h200"),
             Gpu::Unknown(card) => write!(f, "{}", card),
         }
     }
@@ -1702,11 +1705,13 @@ impl ComputeType {
             // https://www.nvidia.com/en-us/data-center/a40/
             // https://images.nvidia.com/content/Solutions/data-center/a40/nvidia-a40-datasheet.pdf
             Gpu::A40 => Some(149 * 10u64.pow(12)),
+            // https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf
+            Gpu::A100 => Some(312 * 10u64.pow(12)),
             // https://www.nvidia.com/en-us/data-center/h100/
             // https://www.techpowerup.com/gpu-specs/docs/nvidia-gh100-architecture.pdf
             Gpu::H100 => Some(900 * 10u64.pow(12)),
-            // https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf
-            Gpu::A100 => Some(312 * 10u64.pow(12)),
+            // https://www.nvidia.com/en-us/data-center/h200/
+            Gpu::H200 => Some(989 * 10u64.pow(12)),
             Gpu::Unknown(card) => {
                 tracing::warn!("Unkown compute for card {card}");
                 None
diff --git a/router/src/config.rs b/router/src/config.rs
index 4d5fcfa0..a1ac107a 100644
--- a/router/src/config.rs
+++ b/router/src/config.rs
@@ -224,6 +224,8 @@ pub enum Config {
     Qwen2,
     Opt,
     T5,
+    DeepseekV2,
+    DeepseekV3,
 }
 
 #[derive(Clone, Debug, Serialize, Deserialize)]
diff --git a/server/text_generation_server/utils/dist.py b/server/text_generation_server/utils/dist.py
index 1b766ddf..613c4784 100644
--- a/server/text_generation_server/utils/dist.py
+++ b/server/text_generation_server/utils/dist.py
@@ -81,12 +81,14 @@ def initialize_torch_distributed():
                     pg_options=options,
                 )
             else:
+                device = torch.device(f"cuda:{RANK}")
                 torch.distributed.init_process_group(
                     backend=backend,
                     world_size=WORLD_SIZE,
                     rank=RANK,
                     timeout=timedelta(seconds=120),
                     pg_options=options,
+                    device_id=device,
                 )
         else:
             logger.warning("torch.distributed is already initialized.")