From 7fa81a05b01fe8e32029e4911a1675c21591b96b Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Thu, 9 Feb 2023 19:57:53 +0100
Subject: [PATCH] add shutdown procedure

---
 router/src/queue.rs                        | 4 +---
 router/src/server.rs                       | 1 +
 server/text_generation/models/causal_lm.py | 1 +
 server/text_generation/utils.py            | 3 ++-
 4 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/router/src/queue.rs b/router/src/queue.rs
index cd9bb450..8de76668 100644
--- a/router/src/queue.rs
+++ b/router/src/queue.rs
@@ -8,7 +8,7 @@ use text_generation_client::{Batch, Request};
 use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
 use tokio::sync::{mpsc, oneshot, OwnedSemaphorePermit};
 use tokio::time::Instant;
-use tracing::{info_span, instrument, Span};
+use tracing::{info_span, Span};
 use tracing_opentelemetry::OpenTelemetrySpanExt;
 
 /// Queue entry
@@ -49,7 +49,6 @@ impl Queue {
     }
 
     /// Append an entry to the queue
-    #[instrument(skip(self))]
     pub(crate) fn append(&self, entry: Entry) {
         // Send append command to the background task managing the state
         // Unwrap is safe here
@@ -59,7 +58,6 @@ impl Queue {
     }
 
     // Get the next batch
-    #[instrument(skip(self))]
     pub(crate) async fn next_batch(
         &self,
         min_size: Option<usize>,
diff --git a/router/src/server.rs b/router/src/server.rs
index 628911ca..389ad48e 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -391,6 +391,7 @@ async fn shutdown_signal() {
     }
 
     tracing::info!("signal received, starting graceful shutdown");
+    opentelemetry::global::shutdown_tracer_provider();
 }
 
 impl From<i32> for FinishReason {
diff --git a/server/text_generation/models/causal_lm.py b/server/text_generation/models/causal_lm.py
index 943a6726..0cf28705 100644
--- a/server/text_generation/models/causal_lm.py
+++ b/server/text_generation/models/causal_lm.py
@@ -12,6 +12,7 @@ from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling
 
 tracer = trace.get_tracer(__name__)
 
+
 @dataclass
 class CausalLMBatch(Batch):
     batch_id: int
diff --git a/server/text_generation/utils.py b/server/text_generation/utils.py
index 5f0a2119..abdee9dc 100644
--- a/server/text_generation/utils.py
+++ b/server/text_generation/utils.py
@@ -163,6 +163,7 @@ def initialize_torch_distributed():
 
     if torch.cuda.is_available():
         from torch.distributed import ProcessGroupNCCL
 
+        # Set the device id.
         assert world_size <= torch.cuda.device_count(), "Each process is one gpu"
         device = rank % torch.cuda.device_count()
@@ -181,7 +182,7 @@
             world_size=world_size,
             rank=rank,
             timeout=timedelta(seconds=60),
-            pg_options=options
+            pg_options=options,
         )
 
     return torch.distributed.group.WORLD, rank, world_size
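
Note on the router/src/server.rs hunk: spans recorded through tracing-opentelemetry are handed to a batch span processor that exports them asynchronously, so a process that exits right after receiving a signal silently drops whatever is still buffered. `opentelemetry::global::shutdown_tracer_provider()` blocks until that buffer is flushed to the exporter. A minimal sketch of the same pattern in isolation (axum 0.6-era `Server` API and the opentelemetry 0.x global API are assumed; the `/health` route and bind address are illustrative, not taken from the patch):

    use axum::{routing::get, Router};

    async fn shutdown_signal() {
        // Wait for CTRL+C; the router's real handler also listens for SIGTERM.
        tokio::signal::ctrl_c()
            .await
            .expect("failed to install CTRL+C signal handler");

        tracing::info!("signal received, starting graceful shutdown");
        // Flush any spans still buffered by the batch span processor
        // before the process exits.
        opentelemetry::global::shutdown_tracer_provider();
    }

    #[tokio::main]
    async fn main() {
        let app = Router::new().route("/health", get(|| async { "ok" }));

        axum::Server::bind(&"0.0.0.0:3000".parse().unwrap())
            .serve(app.into_make_service())
            // Once shutdown_signal resolves, the server stops accepting new
            // connections and drains in-flight requests before returning.
            .with_graceful_shutdown(shutdown_signal())
            .await
            .unwrap();
    }

One subtlety of this placement: the flush runs as soon as the signal arrives, before the drain completes, so spans emitted by requests that finish during the drain would need a second flush after `serve` returns.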