From eade7377140a680a79bd2ce3f2d486314cf5a9b9 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 29 Apr 2024 17:23:40 +0200 Subject: [PATCH] Better graceful shutdown. (#1827) # What does this PR do? Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. --- launcher/src/main.rs | 6 +----- server/text_generation_server/server.py | 23 ++++++++++++++++++----- 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 28226fb4..ca6aa8dd 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -683,9 +683,7 @@ fn shard_manager( // We received a shutdown signal if shutdown.load(Ordering::SeqCst) { - p.kill().unwrap(); - let _ = p.wait(); - tracing::info!("Shard terminated"); + terminate("shard", p, Duration::from_secs(90)).unwrap(); return; } @@ -1245,7 +1243,6 @@ fn terminate(process_name: &str, mut process: Child, timeout: Duration) -> io::R signal::kill(Pid::from_raw(process.id() as i32), Signal::SIGTERM).unwrap(); tracing::info!("Waiting for {process_name} to gracefully shutdown"); - while terminate_time.elapsed() < timeout { if let Some(status) = process.try_wait()? { tracing::info!("{process_name} terminated"); @@ -1253,7 +1250,6 @@ fn terminate(process_name: &str, mut process: Child, timeout: Duration) -> io::R } sleep(Duration::from_millis(100)); } - tracing::info!("Killing {process_name}"); process.kill()?; diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index 495c2c0c..158966e3 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -2,6 +2,7 @@ import asyncio import os import torch import time +import signal from grpc import aio from loguru import logger @@ -19,6 +20,21 @@ from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch +class SignalHandler: + KEEP_PROCESSING = True + + def __init__(self): + signal.signal(signal.SIGINT, self.exit_gracefully) + signal.signal(signal.SIGTERM, self.exit_gracefully) + + def exit_gracefully(self, signum, frame): + print(f"Exiting gracefully: Signal {signum}") + self.KEEP_PROCESSING = False + + +signal_handler = SignalHandler() + + class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): def __init__( self, @@ -231,11 +247,8 @@ def serve( logger.info("Server started at {}".format(local_url)) - try: - await server.wait_for_termination() - except KeyboardInterrupt: - logger.info("Signal received. Shutting down") - await server.stop(0) + while signal_handler.KEEP_PROCESSING: + await asyncio.sleep(0.5) asyncio.run( serve_inner(