diff --git a/launcher/src/env_runtime.rs b/launcher/src/env_runtime.rs index d7ae11d5..d9056e41 100644 --- a/launcher/src/env_runtime.rs +++ b/launcher/src/env_runtime.rs @@ -28,8 +28,8 @@ impl Env { } } - pub fn is_hpu_device(&self) -> bool { - self.hpu_env != "N/A" + pub fn should_start_a_single_hpu_shard(&self) -> bool { + self.hpu_env != "N/A" && std::env::var("ATTENTION").as_deref() != Ok("paged") } } diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 23288b20..86d8714a 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -1559,10 +1559,7 @@ fn spawn_shards( ) -> Result<(), LauncherError> { // Start shard processes for rank in 0..num_shard { - if rank != 0 - && env_runtime::Env::new().is_hpu_device() - && std::env::var("ATTENTION").as_deref() != Ok("paged") - { + if rank != 0 && env_runtime::Env::new().should_start_a_single_hpu_shard() { tracing::info!("Running on HPU, the launcher will not do any sharding as actual sharding is done in the server"); break; } @@ -1642,7 +1639,7 @@ fn spawn_shards( if shard_ready == num_shard { break; } - if env_runtime::Env::new().is_hpu_device() { + if env_runtime::Env::new().should_start_a_single_hpu_shard() { tracing::info!("HPU detected, shard is ready"); break; }