Deepspeed terminate (#11)

This commit is contained in:
mrs303 2024-01-17 09:57:03 +01:00 committed by GitHub
parent c459c86f88
commit 8523f7ef64
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 38 additions and 5 deletions

View File

@ -591,8 +591,7 @@ fn shard_manager(
// We received a shutdown signal
if shutdown.load(Ordering::SeqCst) {
p.kill().unwrap();
let _ = p.wait();
terminate("Shard", p, Duration::from_secs(30)).unwrap();
tracing::info!("Shard terminated");
return;
}
@ -923,7 +922,7 @@ fn spawn_shards(
drop(shutdown_sender);
// Wait for shard to start
let mut shard_ready = 0;
let mut shard_ready = 0;
while running.load(Ordering::SeqCst) {
match status_receiver.try_recv() {
Ok(ShardStatus::Ready) => {

View File

@ -1,4 +1,6 @@
import os
import psutil
import signal
import sys
import typer
@ -76,7 +78,39 @@ def serve(
sys.stdout.flush()
sys.stderr.flush()
with subprocess.Popen(cmd, shell=True, executable="/bin/bash") as proc:
proc.wait()
do_terminate = False
current_handler = signal.getsignal(signal.SIGTERM)
def terminate_handler(sig, frame):
nonlocal do_terminate
do_terminate = True
if callable(current_handler):
current_handler(sig, frame)
signal.signal(signal.SIGTERM, terminate_handler)
finished = False
while not finished:
try:
if do_terminate:
parent = psutil.Process(proc.pid)
all_procs = parent.children(recursive=True) + [parent]
for p in all_procs:
try:
p.terminate()
except psutil.NoSuchProcess:
pass
_, alive = psutil.wait_procs(all_procs, timeout=30)
for p in alive:
p.kill()
do_terminate = False
proc.wait(timeout=3)
except subprocess.TimeoutExpired:
pass
else:
finished = True
sys.stdout.flush()
sys.stderr.flush()
if proc.returncode != 0:

View File

@ -408,7 +408,7 @@ class CausalLM(Model):
}
world_size = int(os.getenv("WORLD_SIZE", "1"))
rank = int(os.getenv("RANK"), 0)
rank = int(os.getenv("RANK", "0"))
self.enable_hpu_graph = os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true"
self.limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "true").lower() == "true"