Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-06-09 19:02:09 +00:00)
Deepspeed terminate (#11)
parent c459c86f88
commit 8523f7ef64
@@ -591,8 +591,7 @@ fn shard_manager(

         // We received a shutdown signal
         if shutdown.load(Ordering::SeqCst) {
-            p.kill().unwrap();
-            let _ = p.wait();
+            terminate("Shard", p, Duration::from_secs(30)).unwrap();
             tracing::info!("Shard terminated");
             return;
         }
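Note on the shard_manager hunk: the launcher no longer kills the shard outright; it asks it to exit and allows up to 30 seconds before giving up (the body of the terminate helper is not part of this diff). As a rough sketch of that shutdown pattern only — illustrative Python, not the launcher's Rust code, with a made-up long-running child:

    import subprocess

    # Bounded graceful shutdown: request exit, wait, then force-kill.
    proc = subprocess.Popen(["sleep", "600"])  # hypothetical child process
    proc.terminate()                           # polite request (SIGTERM)
    try:
        proc.wait(timeout=30)                  # grace period, like Duration::from_secs(30)
    except subprocess.TimeoutExpired:
        proc.kill()                            # escalate (SIGKILL)
        proc.wait()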
@@ -923,7 +922,7 @@ fn spawn_shards(
     drop(shutdown_sender);

     // Wait for shard to start
-    let mut shard_ready = 0;
+    let mut shard_ready = 0;
     while running.load(Ordering::SeqCst) {
         match status_receiver.try_recv() {
             Ok(ShardStatus::Ready) => {
@@ -1,4 +1,6 @@
 import os
+import psutil
+import signal
 import sys
 import typer

@@ -76,7 +78,39 @@ def serve(
     sys.stdout.flush()
     sys.stderr.flush()
     with subprocess.Popen(cmd, shell=True, executable="/bin/bash") as proc:
-        proc.wait()
+        do_terminate = False
+        current_handler = signal.getsignal(signal.SIGTERM)
+        def terminate_handler(sig, frame):
+            nonlocal do_terminate
+            do_terminate = True
+            if callable(current_handler):
+                current_handler(sig, frame)
+
+        signal.signal(signal.SIGTERM, terminate_handler)
+
+        finished = False
+        while not finished:
+            try:
+                if do_terminate:
+                    parent = psutil.Process(proc.pid)
+                    all_procs = parent.children(recursive=True) + [parent]
+                    for p in all_procs:
+                        try:
+                            p.terminate()
+                        except psutil.NoSuchProcess:
+                            pass
+                    _, alive = psutil.wait_procs(all_procs, timeout=30)
+                    for p in alive:
+                        p.kill()
+
+                    do_terminate = False
+
+                proc.wait(timeout=3)
+            except subprocess.TimeoutExpired:
+                pass
+            else:
+                finished = True
+
         sys.stdout.flush()
         sys.stderr.flush()
         if proc.returncode != 0:
@@ -408,7 +408,7 @@ class CausalLM(Model):
         }

         world_size = int(os.getenv("WORLD_SIZE", "1"))
-        rank = int(os.getenv("RANK"), 0)
+        rank = int(os.getenv("RANK", "0"))
         self.enable_hpu_graph = os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true"
         self.limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "true").lower() == "true"

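The rank fix above is more than cosmetic: in int(os.getenv("RANK"), 0) the 0 lands in int()'s base parameter instead of acting as a default for the variable, so the call raises TypeError whenever RANK is unset (int(None, 0)) and parses with an auto-detected base when it is set. A quick standalone illustration, not repository code:

    import os

    os.environ.pop("RANK", None)        # simulate RANK being unset
    rank = int(os.getenv("RANK", "0"))  # fixed form: missing variable falls back to "0"
    # int(os.getenv("RANK"), 0) would raise TypeError here: int(None, base=0)

    os.environ["RANK"] = "3"
    assert int(os.getenv("RANK", "0")) == 3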