mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-06-09 19:02:09 +00:00

Deepspeed terminate (#11)

parent c459c86f88
commit 8523f7ef64
@@ -591,8 +591,7 @@ fn shard_manager(
 
         // We received a shutdown signal
         if shutdown.load(Ordering::SeqCst) {
-            p.kill().unwrap();
-            let _ = p.wait();
+            terminate("Shard", p, Duration::from_secs(30)).unwrap();
             tracing::info!("Shard terminated");
             return;
         }
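The `terminate` helper called on the new line is not shown in this diff; from the call site it appears to give the shard a 30-second grace period to exit before force-killing it, in place of the immediate `p.kill()` it replaces. A minimal sketch of that terminate-then-kill pattern, written in Python for brevity (the name `terminate_gracefully` and its signature are illustrative, not the launcher's actual Rust implementation):

import subprocess

def terminate_gracefully(proc: subprocess.Popen, timeout: float = 30.0) -> int:
    # Ask the child to exit cleanly first (SIGTERM on POSIX)...
    proc.terminate()
    try:
        return proc.wait(timeout=timeout)
    except subprocess.TimeoutExpired:
        # ...and only force-kill (SIGKILL) once the grace period expires.
        proc.kill()
        return proc.wait()

Giving the shard time to exit cleanly is presumably the point of this commit, given its title: a hard kill can leave DeepSpeed worker processes orphaned.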
@@ -923,7 +922,7 @@ fn spawn_shards(
     drop(shutdown_sender);
 
     // Wait for shard to start
     let mut shard_ready = 0;
     while running.load(Ordering::SeqCst) {
         match status_receiver.try_recv() {
             Ok(ShardStatus::Ready) => {
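For context, the loop in this hunk polls a status channel without blocking until every shard reports ready. The same pattern in Python, with `queue.Queue` standing in for the Rust mpsc channel (entirely illustrative, not code from this commit):

import queue
import threading

status_receiver: queue.Queue = queue.Queue()  # stands in for the launcher's mpsc channel
num_shards = 2                                # hypothetical shard count

# Simulate shards reporting ready after some startup delay.
for i in range(num_shards):
    threading.Timer(0.5 * (i + 1), status_receiver.put, args=("ready",)).start()

shard_ready = 0
while shard_ready < num_shards:
    try:
        status = status_receiver.get(timeout=0.1)  # analogous to try_recv() plus a short sleep
    except queue.Empty:
        continue
    if status == "ready":
        shard_ready += 1
        print(f"{shard_ready}/{num_shards} shards ready")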
@@ -1,4 +1,6 @@
 import os
+import psutil
+import signal
 import sys
 import typer
 
@@ -76,7 +78,39 @@ def serve(
     sys.stdout.flush()
     sys.stderr.flush()
     with subprocess.Popen(cmd, shell=True, executable="/bin/bash") as proc:
-        proc.wait()
+        do_terminate = False
+        current_handler = signal.getsignal(signal.SIGTERM)
+
+        def terminate_handler(sig, frame):
+            nonlocal do_terminate
+            do_terminate = True
+            if callable(current_handler):
+                current_handler(sig, frame)
+
+        signal.signal(signal.SIGTERM, terminate_handler)
+
+        finished = False
+        while not finished:
+            try:
+                if do_terminate:
+                    parent = psutil.Process(proc.pid)
+                    all_procs = parent.children(recursive=True) + [parent]
+                    for p in all_procs:
+                        try:
+                            p.terminate()
+                        except psutil.NoSuchProcess:
+                            pass
+                    _, alive = psutil.wait_procs(all_procs, timeout=30)
+                    for p in alive:
+                        p.kill()
+
+                    do_terminate = False
+
+                proc.wait(timeout=3)
+            except subprocess.TimeoutExpired:
+                pass
+            else:
+                finished = True
+
     sys.stdout.flush()
     sys.stderr.flush()
     if proc.returncode != 0:
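Two details of this new shutdown path are easy to miss. First, the handler does no real work itself: it only flips `do_terminate`, then delegates to whatever SIGTERM handler was installed before it, so the shutdown logic composes with existing handlers (the `callable` check is needed because `signal.getsignal` can return `SIG_DFL` or `SIG_IGN`, which are not callable). Second, `proc.wait(timeout=3)` polls instead of blocking forever: after a signal handler returns, a plain `proc.wait()` would simply resume blocking, and the flag would never be rechecked. A self-contained sketch of the chaining pattern (the names here are illustrative, not from the commit):

import signal

def chain_sigterm(callback):
    # Preserve whatever SIGTERM handler was installed before us.
    previous = signal.getsignal(signal.SIGTERM)

    def handler(sig, frame):
        callback()
        # SIG_DFL / SIG_IGN are enum values, not callables, so guard.
        if callable(previous):
            previous(sig, frame)

    signal.signal(signal.SIGTERM, handler)
    return handler

# Usage: set a flag that a polling loop checks, as serve() does above.
stopping = False

def request_stop():
    global stopping
    stopping = True

chain_sigterm(request_stop)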
@@ -408,7 +408,7 @@ class CausalLM(Model):
         }
 
         world_size = int(os.getenv("WORLD_SIZE", "1"))
-        rank = int(os.getenv("RANK"), 0)
+        rank = int(os.getenv("RANK", "0"))
         self.enable_hpu_graph = os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true"
         self.limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "true").lower() == "true"
 
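This last change is a real bug fix despite looking cosmetic: in the old form the `0` is parsed as the *base* argument of `int()`, so the call raises `TypeError` whenever `RANK` is unset, rather than defaulting the rank to 0. A minimal reproduction:

import os

os.environ.pop("RANK", None)  # simulate a run where RANK is not set

try:
    rank = int(os.getenv("RANK"), 0)   # old form: 0 is int()'s base argument
except TypeError as err:
    print(f"old form: {err}")          # int() can't convert non-string with explicit base

rank = int(os.getenv("RANK", "0"))     # fixed form: "0" is getenv's default value
print(f"fixed form: rank={rank}")      # -> rank=0

When RANK is set, the old form happens to work because base 0 auto-detects the literal's format, which is presumably why the bug went unnoticed in multi-rank runs.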