diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 28226fb4..ca6aa8dd 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -683,9 +683,7 @@ fn shard_manager( // We received a shutdown signal if shutdown.load(Ordering::SeqCst) { - p.kill().unwrap(); - let _ = p.wait(); - tracing::info!("Shard terminated"); + terminate("shard", p, Duration::from_secs(90)).unwrap(); return; } @@ -1245,7 +1243,6 @@ fn terminate(process_name: &str, mut process: Child, timeout: Duration) -> io::R signal::kill(Pid::from_raw(process.id() as i32), Signal::SIGTERM).unwrap(); tracing::info!("Waiting for {process_name} to gracefully shutdown"); - while terminate_time.elapsed() < timeout { if let Some(status) = process.try_wait()? { tracing::info!("{process_name} terminated"); @@ -1253,7 +1250,6 @@ fn terminate(process_name: &str, mut process: Child, timeout: Duration) -> io::R } sleep(Duration::from_millis(100)); } - tracing::info!("Killing {process_name}"); process.kill()?; diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index 495c2c0c..158966e3 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -2,6 +2,7 @@ import asyncio import os import torch import time +import signal from grpc import aio from loguru import logger @@ -19,6 +20,21 @@ from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch +class SignalHandler: + KEEP_PROCESSING = True + + def __init__(self): + signal.signal(signal.SIGINT, self.exit_gracefully) + signal.signal(signal.SIGTERM, self.exit_gracefully) + + def exit_gracefully(self, signum, frame): + print(f"Exiting gracefully: Signal {signum}") + self.KEEP_PROCESSING = False + + +signal_handler = SignalHandler() + + class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): def __init__( self, @@ -231,11 +247,8 @@ def serve( logger.info("Server started at {}".format(local_url)) - try: - await server.wait_for_termination() - except KeyboardInterrupt: - logger.info("Signal received. Shutting down") - await server.stop(0) + while signal_handler.KEEP_PROCESSING: + await asyncio.sleep(0.5) asyncio.run( serve_inner(