diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index 495c2c0c..158966e3 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -2,6 +2,7 @@ import asyncio import os import torch import time +import signal from grpc import aio from loguru import logger @@ -19,6 +20,21 @@ from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch +class SignalHandler: + KEEP_PROCESSING = True + + def __init__(self): + signal.signal(signal.SIGINT, self.exit_gracefully) + signal.signal(signal.SIGTERM, self.exit_gracefully) + + def exit_gracefully(self, signum, frame): + print(f"Exiting gracefully: Signal {signum}") + self.KEEP_PROCESSING = False + + +signal_handler = SignalHandler() + + class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): def __init__( self, @@ -231,11 +247,8 @@ def serve( logger.info("Server started at {}".format(local_url)) - try: - await server.wait_for_termination() - except KeyboardInterrupt: - logger.info("Signal received. Shutting down") - await server.stop(0) + while signal_handler.KEEP_PROCESSING: + await asyncio.sleep(0.5) asyncio.run( serve_inner(