add shutdown procedure

OlivierDehaene 2023-02-09 19:57:53 +01:00
parent b3cc379550
commit 7fa81a05b0
4 changed files with 5 additions and 4 deletions


@@ -8,7 +8,7 @@ use text_generation_client::{Batch, Request};
 use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
 use tokio::sync::{mpsc, oneshot, OwnedSemaphorePermit};
 use tokio::time::Instant;
-use tracing::{info_span, instrument, Span};
+use tracing::{info_span, Span};
 use tracing_opentelemetry::OpenTelemetrySpanExt;

 /// Queue entry
@@ -49,7 +49,6 @@ impl Queue {
     }

     /// Append an entry to the queue
-    #[instrument(skip(self))]
     pub(crate) fn append(&self, entry: Entry) {
         // Send append command to the background task managing the state
         // Unwrap is safe here
@@ -59,7 +58,6 @@ impl Queue {
     }

     // Get the next batch
-    #[instrument(skip(self))]
     pub(crate) async fn next_batch(
         &self,
         min_size: Option<usize>,
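The #[instrument] attributes dropped above sit on Queue::append and Queue::next_batch, which never touch shared state directly: they send a command over an unbounded tokio mpsc channel to a background task that owns the queue. A rough, self-contained sketch of that command-channel pattern follows; the QueueCommand and Entry types here are simplified stand-ins, not the router's real definitions.

use tokio::sync::{mpsc, oneshot};

// Hypothetical, simplified stand-ins for the router's real types.
struct Entry {
    input: String,
}

enum QueueCommand {
    Append(Entry),
    NextBatch {
        min_size: Option<usize>,
        response_sender: oneshot::Sender<Vec<Entry>>,
    },
}

#[derive(Clone)]
struct Queue {
    sender: mpsc::UnboundedSender<QueueCommand>,
}

impl Queue {
    fn new() -> Self {
        let (sender, receiver) = mpsc::unbounded_channel();
        // A single background task owns the queue state, so the hot path
        // never takes a lock; callers just clone the sender handle.
        tokio::spawn(queue_task(receiver));
        Self { sender }
    }

    fn append(&self, entry: Entry) {
        // Unwrap is safe as long as the background task is alive.
        self.sender.send(QueueCommand::Append(entry)).unwrap();
    }

    async fn next_batch(&self, min_size: Option<usize>) -> Vec<Entry> {
        let (response_sender, response_receiver) = oneshot::channel();
        self.sender
            .send(QueueCommand::NextBatch {
                min_size,
                response_sender,
            })
            .unwrap();
        response_receiver.await.unwrap()
    }
}

async fn queue_task(mut receiver: mpsc::UnboundedReceiver<QueueCommand>) {
    let mut entries: Vec<Entry> = Vec::new();
    while let Some(command) = receiver.recv().await {
        match command {
            QueueCommand::Append(entry) => entries.push(entry),
            QueueCommand::NextBatch {
                min_size,
                response_sender,
            } => {
                // Only hand back a batch once enough entries are waiting.
                let batch = if entries.len() >= min_size.unwrap_or(1) {
                    entries.drain(..).collect()
                } else {
                    Vec::new()
                };
                let _ = response_sender.send(batch);
            }
        }
    }
}

#[tokio::main]
async fn main() {
    let queue = Queue::new();
    queue.append(Entry {
        input: "hello".to_string(),
    });
    println!("batch size: {}", queue.next_batch(Some(1)).await.len());
}

Keeping the state inside one task means the request path never contends on a Mutex; callers only clone the cheap sender handle.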


@@ -391,6 +391,7 @@ async fn shutdown_signal() {
     }

     tracing::info!("signal received, starting graceful shutdown");
+    opentelemetry::global::shutdown_tracer_provider();
 }

 impl From<i32> for FinishReason {
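The single added line above flushes OpenTelemetry before the router exits. For context, here is a minimal sketch of a shutdown_signal handler built around that call, assuming the usual tokio signal-handling pattern and the tokio, tracing, and opentelemetry crates; apart from the two lines visible in the diff, the real body in the router may differ.

use tokio::signal;

/// Wait for SIGINT (Ctrl+C) or SIGTERM, then flush telemetry before exiting.
async fn shutdown_signal() {
    let ctrl_c = async {
        signal::ctrl_c()
            .await
            .expect("failed to install Ctrl+C handler");
    };

    #[cfg(unix)]
    let terminate = async {
        signal::unix::signal(signal::unix::SignalKind::terminate())
            .expect("failed to install SIGTERM handler")
            .recv()
            .await;
    };

    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();

    tokio::select! {
        _ = ctrl_c => {},
        _ = terminate => {},
    }

    tracing::info!("signal received, starting graceful shutdown");
    // Blocks until buffered spans have been exported, so the last traces
    // are not lost when the process exits.
    opentelemetry::global::shutdown_tracer_provider();
}

#[tokio::main]
async fn main() {
    // In the router this future is handed to the HTTP server's graceful
    // shutdown hook; here it is simply awaited.
    shutdown_signal().await;
}

shutdown_tracer_provider() waits for the span exporter to drain, which is why it comes after the log line and right before the function returns and the server is torn down.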


@@ -12,6 +12,7 @@ from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling
 tracer = trace.get_tracer(__name__)


 @dataclass
 class CausalLMBatch(Batch):
     batch_id: int


@@ -163,6 +163,7 @@ def initialize_torch_distributed():
     if torch.cuda.is_available():
         from torch.distributed import ProcessGroupNCCL

         # Set the device id.
         assert world_size <= torch.cuda.device_count(), "Each process is one gpu"
         device = rank % torch.cuda.device_count()
@@ -181,7 +182,7 @@ def initialize_torch_distributed():
         world_size=world_size,
         rank=rank,
         timeout=timedelta(seconds=60),
-        pg_options=options
+        pg_options=options,
     )

     return torch.distributed.group.WORLD, rank, world_size