From 3f3a1a6a66adaa88b5f63862cfdf69d4653744c0 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 29 Apr 2024 17:23:40 +0200 Subject: [PATCH] Better graceful shutdown. (#1827) Fixes # (issue) - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a GitHub issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. 
--- launcher/src/main.rs | 5 +---- server/text_generation_server/server.py | 23 ++++++++++++++++++----- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/launcher/src/main.rs b/launcher/src/main.rs index a27332a7..6220a468 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -687,8 +687,7 @@ fn shard_manager( // We received a shutdown signal if shutdown.load(Ordering::SeqCst) { - terminate("Shard", p, Duration::from_secs(30)).unwrap(); - tracing::info!("Shard terminated"); + terminate("shard", p, Duration::from_secs(90)).unwrap(); return; } @@ -1249,7 +1248,6 @@ fn terminate(process_name: &str, mut process: Child, timeout: Duration) -> io::R signal::kill(Pid::from_raw(process.id() as i32), Signal::SIGTERM).unwrap(); tracing::info!("Waiting for {process_name} to gracefully shutdown"); - while terminate_time.elapsed() < timeout { if let Some(status) = process.try_wait()? { tracing::info!("{process_name} terminated"); @@ -1257,7 +1255,6 @@ fn terminate(process_name: &str, mut process: Child, timeout: Duration) -> io::R } sleep(Duration::from_millis(100)); } - tracing::info!("Killing {process_name}"); process.kill()?; diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index 4a07733a..f52d801c 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -5,6 +5,7 @@ import os import sys import torch import time +import signal from grpc import aio from loguru import logger @@ -20,6 +21,21 @@ from text_generation_server.pb import generate_pb2_grpc, generate_pb2 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor +class SignalHandler: + KEEP_PROCESSING = True + + def __init__(self): + signal.signal(signal.SIGINT, self.exit_gracefully) + signal.signal(signal.SIGTERM, self.exit_gracefully) + + def exit_gracefully(self, signum, frame): + print(f"Exiting gracefully: Signal {signum}") + self.KEEP_PROCESSING = False + + +signal_handler = 
SignalHandler() + + class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): def __init__( self, @@ -201,11 +217,8 @@ def serve( logger.info("Server started at {}".format(local_url)) - try: - await server.wait_for_termination() - except KeyboardInterrupt: - logger.info("Signal received. Shutting down") - await server.stop(0) + while signal_handler.KEEP_PROCESSING: + await asyncio.sleep(0.5) asyncio.run( serve_inner(