diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index eb47f65e..db92d57c 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -591,8 +591,7 @@ fn shard_manager(
 
         // We received a shutdown signal
         if shutdown.load(Ordering::SeqCst) {
-            p.kill().unwrap();
-            let _ = p.wait();
+            terminate("Shard", p, Duration::from_secs(30)).unwrap();
             tracing::info!("Shard terminated");
             return;
         }
@@ -923,7 +922,7 @@ fn spawn_shards(
     drop(shutdown_sender);
 
     // Wait for shard to start
-    let mut shard_ready = 0;
+    let mut shard_ready = 0;
     while running.load(Ordering::SeqCst) {
         match status_receiver.try_recv() {
             Ok(ShardStatus::Ready) => {
diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py
index ca293051..8e7a6382 100644
--- a/server/text_generation_server/cli.py
+++ b/server/text_generation_server/cli.py
@@ -1,4 +1,6 @@
 import os
+import psutil
+import signal
 import sys
 
 import typer
@@ -76,7 +78,39 @@ def serve(
         sys.stdout.flush()
         sys.stderr.flush()
         with subprocess.Popen(cmd, shell=True, executable="/bin/bash") as proc:
-            proc.wait()
+            do_terminate = False
+            current_handler = signal.getsignal(signal.SIGTERM)
+            def terminate_handler(sig, frame):
+                nonlocal do_terminate
+                do_terminate = True
+                if callable(current_handler):
+                    current_handler(sig, frame)
+
+            signal.signal(signal.SIGTERM, terminate_handler)
+
+            finished = False
+            while not finished:
+                try:
+                    if do_terminate:
+                        parent = psutil.Process(proc.pid)
+                        all_procs = parent.children(recursive=True) + [parent]
+                        for p in all_procs:
+                            try:
+                                p.terminate()
+                            except psutil.NoSuchProcess:
+                                pass
+                        _, alive = psutil.wait_procs(all_procs, timeout=30)
+                        for p in alive:
+                            p.kill()
+
+                        do_terminate = False
+
+                    proc.wait(timeout=3)
+                except subprocess.TimeoutExpired:
+                    pass
+                else:
+                    finished = True
+
             sys.stdout.flush()
             sys.stderr.flush()
             if proc.returncode != 0:
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 533d35f0..61376640 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -408,7 +408,7 @@ class CausalLM(Model):
         }
 
         world_size = int(os.getenv("WORLD_SIZE", "1"))
-        rank = int(os.getenv("RANK"), 0)
+        rank = int(os.getenv("RANK", "0"))
         self.enable_hpu_graph = os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true"
         self.limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "true").lower() == "true"
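
Note: the launcher hunk calls a `terminate` helper whose body is not part of this diff. The following is only a minimal sketch of the kind of helper the call signature implies (process name, child handle, 30-second grace period), assuming the launcher already depends on the `nix` and `tracing` crates; the actual implementation may differ.

use std::io;
use std::process::{Child, ExitStatus};
use std::thread::sleep;
use std::time::{Duration, Instant};

use nix::sys::signal::{self, Signal};
use nix::unistd::Pid;

// Hypothetical sketch only; not the patch's actual `terminate` implementation.
fn terminate(name: &str, mut process: Child, timeout: Duration) -> io::Result<ExitStatus> {
    // Ask the child to shut down gracefully first.
    let _ = signal::kill(Pid::from_raw(process.id() as i32), Signal::SIGTERM);

    let start = Instant::now();
    while start.elapsed() < timeout {
        // try_wait() returns Some(status) once the child has exited.
        if let Some(status) = process.try_wait()? {
            tracing::info!("{name} terminated gracefully");
            return Ok(status);
        }
        sleep(Duration::from_millis(100));
    }

    // Grace period expired: force-kill and reap the child.
    tracing::warn!("{name} did not exit in time, killing it");
    process.kill()?;
    process.wait()
}

This mirrors the Python side of the patch, which likewise sends SIGTERM to the process tree, waits up to 30 seconds via psutil.wait_procs, and only then falls back to kill().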