diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index f4f2c533..01c60bb4 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -435,7 +435,6 @@ fn shard_manager(
     quantize: Option<Quantization>,
     speculate: Option<usize>,
     dtype: Option<Dtype>,
-    max_total_tokens: usize,
     trust_remote_code: bool,
     uds_path: String,
     rank: usize,
@@ -451,6 +450,8 @@ fn shard_manager(
     cuda_memory_fraction: f32,
     rope_scaling: Option<RopeScaling>,
     rope_factor: Option<f32>,
+    max_total_tokens: usize,
+    max_batch_size: Option<usize>,
     otlp_endpoint: Option<String>,
     status_sender: mpsc::Sender<ShardStatus>,
     shutdown: Arc<AtomicBool>,
@@ -515,6 +516,7 @@ fn shard_manager(
         (Some(scaling), Some(factor)) => Some((scaling, factor)),
         (None, Some(factor)) => Some((RopeScaling::Linear, factor)),
     };
+
     // OpenTelemetry
     if let Some(otlp_endpoint) = otlp_endpoint {
         shard_args.push("--otlp-endpoint".to_string());
@@ -527,9 +529,6 @@ fn shard_manager(
     // Remove LOG_LEVEL if present
     envs.retain(|(name, _)| name != "LOG_LEVEL");
 
-    // Max total tokens
-    envs.push(("MAX_TOTAL_TOKENS".into(), max_total_tokens.to_string().into()));
-
     // Torch Distributed Env vars
     if world_size == 1 {
         envs.push(("RANK".into(), rank.to_string().into()));
@@ -572,6 +571,14 @@ fn shard_manager(
         envs.push(("ROPE_FACTOR".into(), factor.to_string().into()));
     }
 
+    envs.push((
+        "MAX_TOTAL_TOKENS".into(),
+        max_total_tokens.to_string().into(),
+    ));
+    if let Some(max_batch_size) = max_batch_size {
+        envs.push(("MAX_BATCH_SIZE".into(), max_batch_size.to_string().into()));
+    }
+
     // If huggingface_hub_cache is some, pass it to the shard
     // Useful when running inside a docker container
     if let Some(huggingface_hub_cache) = huggingface_hub_cache {
@@ -975,13 +982,13 @@ fn spawn_shards(
     num_shard: usize,
     args: &Args,
     cuda_graphs: Vec<usize>,
+    max_total_tokens: usize,
     shutdown: Arc<AtomicBool>,
     shutdown_receiver: &mpsc::Receiver<()>,
     shutdown_sender: mpsc::Sender<()>,
     status_receiver: &mpsc::Receiver<ShardStatus>,
     status_sender: mpsc::Sender<ShardStatus>,
     running: Arc<AtomicBool>,
-    max_total_tokens: usize,
 ) -> Result<(), LauncherError> {
     // Start shard processes
     for rank in 0..1 {
@@ -1007,6 +1014,7 @@ fn spawn_shards(
         let cuda_memory_fraction = args.cuda_memory_fraction;
         let rope_scaling = args.rope_scaling;
         let rope_factor = args.rope_factor;
+        let max_batch_size = args.max_batch_size;
         thread::spawn(move || {
             shard_manager(
                 model_id,
@@ -1014,7 +1022,6 @@ fn spawn_shards(
                 quantize,
                 speculate,
                 dtype,
-                max_total_tokens,
                 trust_remote_code,
                 uds_path,
                 rank,
@@ -1030,6 +1037,8 @@ fn spawn_shards(
                 cuda_memory_fraction,
                 rope_scaling,
                 rope_factor,
+                max_total_tokens,
+                max_batch_size,
                 otlp_endpoint,
                 status_sender,
                 shutdown,
@@ -1494,13 +1503,13 @@ fn main() -> Result<(), LauncherError> {
         num_shard,
         &args,
         cuda_graphs,
+        max_total_tokens,
         shutdown.clone(),
         &shutdown_receiver,
         shutdown_sender,
         &status_receiver,
         status_sender,
         running.clone(),
-        max_total_tokens,
     )?;
 
     // We might have received a termination signal
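The change above moves `max_total_tokens` later in the `shard_manager` signature, adds the new `max_batch_size` parameter, and forwards both to the shard as environment variables, with `MAX_BATCH_SIZE` exported only when the option was set. As a rough illustration of that pattern (not the launcher's actual code), the sketch below serializes the two values into a child process's environment and reads them back. The placeholder values `4096` and `32`, and the `sh -c echo` child, are assumptions made for the example; in text-generation-inference the child is the Python shard server, which reads the variables on its side.

```rust
use std::process::Command;

fn main() {
    // Placeholder values standing in for the parsed launcher arguments.
    let max_total_tokens: usize = 4096;
    let max_batch_size: Option<usize> = Some(32);

    // Same pattern as the diff: MAX_TOTAL_TOKENS is always exported,
    // MAX_BATCH_SIZE only when the optional value is present.
    let mut envs: Vec<(String, String)> = vec![(
        "MAX_TOTAL_TOKENS".to_string(),
        max_total_tokens.to_string(),
    )];
    if let Some(max_batch_size) = max_batch_size {
        envs.push(("MAX_BATCH_SIZE".to_string(), max_batch_size.to_string()));
    }

    // A stand-in child process (requires a Unix shell) that echoes the
    // variables back; the real shard would read them from its environment.
    let output = Command::new("sh")
        .arg("-c")
        .arg("echo \"MAX_TOTAL_TOKENS=$MAX_TOTAL_TOKENS MAX_BATCH_SIZE=${MAX_BATCH_SIZE:-unset}\"")
        .envs(envs)
        .output()
        .expect("failed to run child process");

    print!("{}", String::from_utf8_lossy(&output.stdout));
}
```

Passing the optional value through the environment this way keeps the shard's startup interface unchanged: an unset `MAX_BATCH_SIZE` simply means no cap, mirroring the `Option` on the launcher side.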