diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 0848dd9a..a698ebe3 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -47,6 +47,8 @@ struct Args { #[clap(long, env)] weights_cache_override: Option, #[clap(long, env)] + disable_custom_kernels: bool, + #[clap(long, env)] json_output: bool, #[clap(long, env)] otlp_endpoint: Option, @@ -69,6 +71,7 @@ fn main() -> ExitCode { master_port, huggingface_hub_cache, weights_cache_override, + disable_custom_kernels, json_output, otlp_endpoint, } = Args::parse(); @@ -241,6 +244,7 @@ fn main() -> ExitCode { master_port, huggingface_hub_cache, weights_cache_override, + disable_custom_kernels, otlp_endpoint, status_sender, shutdown, @@ -405,6 +409,7 @@ fn shard_manager( master_port: usize, huggingface_hub_cache: Option, weights_cache_override: Option, + disable_custom_kernels: bool, otlp_endpoint: Option, status_sender: mpsc::Sender, shutdown: Arc>, @@ -473,6 +478,11 @@ fn shard_manager( )); }; + // If disable_custom_kernels is true, pass it to the shard as an env var + if disable_custom_kernels { + env.push(("DISABLE_CUSTOM_KERNELS".into(), "True".into())) + } + // If the NCCL_SHM_DISABLE env var is set, pass it to the shard // needed when running NCCL inside a docker container and when you can't increase shm size if let Ok(nccl_shm_disalbe) = env::var("NCCL_SHM_DISABLE") {