mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 19:34:53 +00:00)
add shutdown procedure
commit 7fa81a05b0 (parent b3cc379550)
@@ -8,7 +8,7 @@ use text_generation_client::{Batch, Request};
 use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
 use tokio::sync::{mpsc, oneshot, OwnedSemaphorePermit};
 use tokio::time::Instant;
-use tracing::{info_span, instrument, Span};
+use tracing::{info_span, Span};
 use tracing_opentelemetry::OpenTelemetrySpanExt;

 /// Queue entry
@@ -49,7 +49,6 @@ impl Queue {
     }

     /// Append an entry to the queue
-    #[instrument(skip(self))]
     pub(crate) fn append(&self, entry: Entry) {
         // Send append command to the background task managing the state
         // Unwrap is safe here
@@ -59,7 +58,6 @@ impl Queue {
     }

     // Get the next batch
-    #[instrument(skip(self))]
     pub(crate) async fn next_batch(
         &self,
         min_size: Option<usize>,
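The two hunks above remove the #[instrument(skip(self))] attribute from append and next_batch, which is also why the first hunk drops instrument from the tracing import while keeping info_span and Span. A minimal sketch of what the removal means, using a stand-in Queue type (entry_id and the method body are illustrative, not the actual router code):

use tracing::info_span;

struct Queue;

impl Queue {
    // Previously, #[instrument(skip(self))] opened a span named "append"
    // around every call automatically. After its removal, a span is only
    // created where one is explicitly wanted:
    pub fn append(&self, entry_id: u64) {
        let span = info_span!("append", entry_id);
        let _enter = span.enter();
        tracing::info!("entry appended");
        // ... send the entry to the background state task ...
    }
}

fn main() {
    // Requires the tracing-subscriber crate to actually print events.
    tracing_subscriber::fmt::init();
    Queue.append(42);
}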
@@ -391,6 +391,7 @@ async fn shutdown_signal() {
     }

     tracing::info!("signal received, starting graceful shutdown");
+    opentelemetry::global::shutdown_tracer_provider();
 }

 impl From<i32> for FinishReason {
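The added opentelemetry::global::shutdown_tracer_provider() call flushes spans still buffered by the OpenTelemetry pipeline before the process exits. Only the tail of shutdown_signal is visible in the hunk; a sketch of the usual shape of such a handler under the common tokio graceful-shutdown pattern (not necessarily the exact router code):

use tokio::signal;

// Resolves once CTRL+C (or SIGTERM on unix) is received.
async fn shutdown_signal() {
    let ctrl_c = async {
        signal::ctrl_c()
            .await
            .expect("failed to install CTRL+C handler");
    };

    #[cfg(unix)]
    let terminate = async {
        signal::unix::signal(signal::unix::SignalKind::terminate())
            .expect("failed to install SIGTERM handler")
            .recv()
            .await;
    };

    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();

    // Whichever signal arrives first wins.
    tokio::select! {
        _ = ctrl_c => {},
        _ = terminate => {},
    }

    tracing::info!("signal received, starting graceful shutdown");
    // The line this commit adds: flush and shut down the OpenTelemetry
    // tracer provider so buffered spans are exported before exit.
    opentelemetry::global::shutdown_tracer_provider();
}

#[tokio::main]
async fn main() {
    shutdown_signal().await;
}

A future like this is typically handed to axum's Server::with_graceful_shutdown, so in-flight requests drain before the tracer provider is torn down.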
@@ -12,6 +12,7 @@ from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling

 tracer = trace.get_tracer(__name__)

+
 @dataclass
 class CausalLMBatch(Batch):
     batch_id: int
@@ -163,6 +163,7 @@ def initialize_torch_distributed():

     if torch.cuda.is_available():
         from torch.distributed import ProcessGroupNCCL
+
         # Set the device id.
         assert world_size <= torch.cuda.device_count(), "Each process is one gpu"
         device = rank % torch.cuda.device_count()
@@ -181,7 +182,7 @@ def initialize_torch_distributed():
         world_size=world_size,
         rank=rank,
         timeout=timedelta(seconds=60),
-        pg_options=options
+        pg_options=options,
     )

     return torch.distributed.group.WORLD, rank, world_size
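The three hunks above sit inside the Python server: the visible edits are formatting (added blank lines and a trailing comma), but the surrounding context shows how initialize_torch_distributed works: each process is pinned to one GPU (rank % device_count) and the NCCL process group is created with explicit ProcessGroupNCCL options and a 60-second timeout. A rough sketch of the overall shape of such a function, with env-based rank/world-size handling as an assumption rather than the actual server code:

import os
from datetime import timedelta

import torch


def initialize_torch_distributed():
    # Assumption: the launcher passes rank and world size via the environment.
    rank = int(os.getenv("RANK", "0"))
    world_size = int(os.getenv("WORLD_SIZE", "1"))

    if torch.cuda.is_available():
        from torch.distributed import ProcessGroupNCCL

        # One process per GPU: map the global rank onto a local device.
        assert world_size <= torch.cuda.device_count(), "Each process is one gpu"
        device = rank % torch.cuda.device_count()
        torch.cuda.set_device(device)
        backend = "nccl"
        options = ProcessGroupNCCL.Options()
    else:
        backend = "gloo"
        options = None

    torch.distributed.init_process_group(
        backend=backend,
        world_size=world_size,
        rank=rank,
        timeout=timedelta(seconds=60),
        pg_options=options,
    )

    return torch.distributed.group.WORLD, rank, world_size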