diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs
index 42e33ac9..db5f4943 100644
--- a/router/src/infer/mod.rs
+++ b/router/src/infer/mod.rs
@@ -16,38 +16,18 @@ use minijinja::{Environment, ErrorKind, Template};
 use serde_json::{json, Map, Value};
 use std::collections::HashMap;
 use std::sync::{
-    atomic::{AtomicBool},
     Arc,
 };
-use text_generation_client::v2::{ShardedClient};
 use thiserror::Error;
-use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
+use tokio::sync::{OwnedSemaphorePermit, Semaphore, TryAcquireError};
 use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tokio_stream::StreamExt;
-use tracing::{instrument, Span};
+use tracing::{instrument};
 
-/// Queue entry
-#[derive(Debug)]
-pub(crate) struct Entry {
-    /// Request
-    pub request: ValidGenerateRequest,
-    /// Response sender to communicate between the Infer struct and the batching_task
-    pub response_tx: mpsc::UnboundedSender<Result<InferStreamResponse, InferError>>,
-    /// Span that will live as long as entry
-    pub span: Span,
-    /// Temporary span used as a guard when logging inference, wait times...
-    pub temp_span: Option<Span>,
-    /// Instant when this entry was queued
-    pub queue_time: Instant,
-    /// Instant when this entry was added to a batch
-    pub batch_time: Option<Instant>,
-}
 
-pub(crate) trait InferQueue {
-    /// Append an entry to the queue
-    #[instrument(skip_all)]
-    fn append(&self, entry: Entry);
+pub(crate) trait Scheduler {
+    fn schedule(&self, request: ValidGenerateRequest, permit: OwnedSemaphorePermit) -> Result<GenerateStreamResponse, InferError>;
 }
@@ -56,10 +36,8 @@ pub(crate) trait InferQueue {
 pub struct Infer {
     /// Validation
     validation: Validation,
-    /// Request queue
-    queue: Arc<dyn InferQueue + Send + Sync>,
-    /// Notify batcher on queue appends
-    batching_task_notifier: Arc<Notify>,
+    /// Request scheduler
+    scheduler: Arc<dyn Scheduler + Send + Sync>,
     /// Chat template
     chat_template: Option<ChatTemplate>,
     /// Inference limit
@@ -71,37 +49,12 @@ impl Infer {
     #[allow(clippy::too_many_arguments)]
     pub(crate) fn new(
-        client: ShardedClient,
+        scheduler: Arc<dyn Scheduler + Send + Sync>,
         validation: Validation,
-        waiting_served_ratio: f32,
-        max_batch_prefill_tokens: u32,
-        max_batch_total_tokens: u32,
-        max_waiting_tokens: usize,
-        max_batch_size: Option<usize>,
         max_concurrent_requests: usize,
-        requires_padding: bool,
-        window_size: Option<u32>,
-        speculate: u32,
-        generation_health: Arc<AtomicBool>,
         tokenizer_config: HubTokenizerConfig,
         processor_config: HubProcessorConfig,
     ) -> Self {
-        let queue = v2::Queue::new(requires_padding, 16, window_size, speculate);
-        let batching_task_notifier = Arc::new(Notify::new());
-
-        // Spawn batching background task that contains all the inference logic
-        tokio::spawn(v2::batching_task(
-            client,
-            waiting_served_ratio,
-            max_batch_prefill_tokens,
-            max_batch_total_tokens,
-            max_waiting_tokens,
-            max_batch_size,
-            queue.clone(),
-            batching_task_notifier.clone(),
-            generation_health,
-        ));
-
         let chat_template = tokenizer_config
             .chat_template
             .or(processor_config.chat_template)
@@ -126,8 +79,7 @@ impl Infer {
 
         Self {
             validation,
-            queue: Arc::new(queue),
-            batching_task_notifier,
+            scheduler,
             chat_template,
             limit_concurrent_requests: semaphore,
         }
@@ -157,30 +109,7 @@ impl Infer {
             err
         })?;
 
-        // MPSC channel to communicate with the background batching task
-        let (response_tx, response_rx) = mpsc::unbounded_channel();
-        let input_length = valid_request.input_length;
-
-        // Append the request to the queue
-        self.queue.append(Entry {
-            request: valid_request,
-            response_tx,
-            span: Span::current(),
-            temp_span: None,
-            queue_time: Instant::now(),
-            batch_time: None,
-        });
-
-        // Notify the background task that we have a new entry in the queue that needs
-        // to be batched
-        self.batching_task_notifier.notify_one();
-
-        // Return stream
-        Ok((
-            permit,
-            input_length,
-            UnboundedReceiverStream::new(response_rx),
-        ))
+        self.scheduler.schedule(valid_request, permit)
     }
 
     /// Tokenizer the input
diff --git a/router/src/infer/v2/mod.rs b/router/src/infer/v2/mod.rs
index 0af0e5cb..f87d863c 100644
--- a/router/src/infer/v2/mod.rs
+++ b/router/src/infer/v2/mod.rs
@@ -1,5 +1,4 @@
-mod batcher;
+mod scheduler;
 mod queue;
 
-pub(crate) use batcher::batching_task;
-pub(crate) use queue::Queue;
+pub(crate) use scheduler::SchedulerV2;
diff --git a/router/src/infer/v2/queue.rs b/router/src/infer/v2/queue.rs
index d62c8cdc..057b1804 100644
--- a/router/src/infer/v2/queue.rs
+++ b/router/src/infer/v2/queue.rs
@@ -1,7 +1,5 @@
-use crate::infer::{Entry, InferQueue};
-use crate::validation::{
-    ValidGrammar, ValidParameters, ValidStoppingParameters,
-};
+use crate::infer::{InferError, InferStreamResponse};
+use crate::validation::{ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters};
 use nohash_hasher::{BuildNoHashHasher, IntMap};
 use std::cmp::min;
 use std::collections::VecDeque;
@@ -15,6 +13,23 @@ use tokio::sync::{mpsc, oneshot};
 use tokio::time::Instant;
 use tracing::{info_span, instrument, Span};
 
+/// Queue entry
+#[derive(Debug)]
+pub(crate) struct Entry {
+    /// Request
+    pub request: ValidGenerateRequest,
+    /// Response sender to communicate between the Infer struct and the batching_task
+    pub response_tx: mpsc::UnboundedSender<Result<InferStreamResponse, InferError>>,
+    /// Span that will live as long as entry
+    pub span: Span,
+    /// Temporary span used as a guard when logging inference, wait times...
+    pub temp_span: Option<Span>,
+    /// Instant when this entry was queued
+    pub queue_time: Instant,
+    /// Instant when this entry was added to a batch
+    pub batch_time: Option<Instant>,
+}
+
 /// Request Queue
 #[derive(Debug, Clone)]
 pub(crate) struct Queue {
@@ -22,19 +37,6 @@ pub(crate) struct Queue {
     queue_sender: mpsc::UnboundedSender<QueueCommand>,
 }
 
-
-impl InferQueue for Queue {
-    /// Append an entry to the queue
-    #[instrument(skip_all)]
-    fn append(&self, entry: Entry) {
-        // Send append command to the background task managing the state
-        // Unwrap is safe here
-        self.queue_sender
-            .send(QueueCommand::Append(Box::new(entry), Span::current()))
-            .unwrap();
-    }
-}
-
 impl Queue {
     pub(crate) fn new(
         requires_padding: bool,
@@ -57,6 +59,15 @@
         Self { queue_sender }
     }
 
+    #[instrument(skip_all)]
+    pub(crate) fn append(&self, entry: Entry) {
+        // Send append command to the background task managing the state
+        // Unwrap is safe here
+        self.queue_sender
+            .send(QueueCommand::Append(Box::new(entry), Span::current()))
+            .unwrap();
+    }
+
     // Get the next batch
     #[instrument(skip(self))]
     pub(crate) async fn next_batch(
diff --git a/router/src/infer/v2/batcher.rs b/router/src/infer/v2/scheduler.rs
similarity index 95%
rename from router/src/infer/v2/batcher.rs
rename to router/src/infer/v2/scheduler.rs
index 3cfc9118..f0815168 100644
--- a/router/src/infer/v2/batcher.rs
+++ b/router/src/infer/v2/scheduler.rs
@@ -1,7 +1,6 @@
 /// Batching and inference logic
-use crate::infer::Entry;
-use crate::infer::v2::{Queue};
+use crate::infer::v2::queue::{Queue, Entry};
 use crate::{FinishReason, PrefillToken, Token};
 use nohash_hasher::IntMap;
 use std::sync::{
@@ -11,10 +10,86 @@
 use text_generation_client::v2::{Batch, CachedBatch, Generation, ShardedClient};
 use text_generation_client::{ClientError};
 use tokio::sync::mpsc::error::SendError;
-use tokio::sync::Notify;
+use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit};
 use tokio::time::Instant;
-use tracing::{info_span, instrument, Instrument};
-use crate::infer::{GeneratedText, InferError, InferStreamResponse};
+use tokio_stream::wrappers::UnboundedReceiverStream;
+use tracing::{info_span, instrument, Instrument, Span};
+use crate::infer::{GeneratedText, GenerateStreamResponse, InferError, InferStreamResponse, Scheduler};
+use crate::validation::ValidGenerateRequest;
 
+pub(crate) struct SchedulerV2 {
+    /// Request queue
+    queue: Queue,
+    /// Notify batcher on queue appends
+    batching_task_notifier: Arc<Notify>,
+}
+
+impl SchedulerV2 {
+    pub(crate) fn new(
+        client: ShardedClient,
+        waiting_served_ratio: f32,
+        max_batch_prefill_tokens: u32,
+        max_batch_total_tokens: u32,
+        max_waiting_tokens: usize,
+        max_batch_size: Option<usize>,
+        requires_padding: bool,
+        window_size: Option<u32>,
+        speculate: u32,
+        generation_health: Arc<AtomicBool>,
+    ) -> Self {
+        let queue = Queue::new(requires_padding, 16, window_size, speculate);
+        let batching_task_notifier = Arc::new(Notify::new());
+
+        // Spawn batching background task that contains all the inference logic
+        tokio::spawn(batching_task(
+            client,
+            waiting_served_ratio,
+            max_batch_prefill_tokens,
+            max_batch_total_tokens,
+            max_waiting_tokens,
+            max_batch_size,
+            queue.clone(),
+            batching_task_notifier.clone(),
+            generation_health,
+        ));
+
+        Self {
+            queue,
+            batching_task_notifier
+        }
+    }
+}
+
+impl Scheduler for SchedulerV2 {
+    #[instrument(skip_all)]
+    fn schedule(&self, request: ValidGenerateRequest, permit: OwnedSemaphorePermit) -> Result<GenerateStreamResponse, InferError> {
+        // MPSC channel to communicate with the background batching task
+        let (response_tx, response_rx) = mpsc::unbounded_channel();
+        let input_length = request.input_length;
+
+        // Append the request to the queue
+        self.queue.append(Entry {
+            request,
+            response_tx,
+            span: Span::current(),
+            temp_span: None,
+            queue_time: Instant::now(),
+            batch_time: None,
+        });
+
+        // Notify the background task that we have a new entry in the queue that needs
+        // to be batched
+        self.batching_task_notifier.notify_one();
+
+        // Return stream
+        Ok((
+            permit,
+            input_length,
+            UnboundedReceiverStream::new(response_rx),
+        ))
+    }
+}
+
 /// Batching logic
 /// Will be launched in a background Tokio task
@@ -692,10 +767,10 @@ mod tests {
             content: "You are a friendly chatbot who always responds in the style of a pirate"
                 .to_string(),
         }]
-        .iter()
-        .chain(&example_chat)
-        .cloned()
-        .collect::<Vec<_>>();
+            .iter()
+            .chain(&example_chat)
+            .cloned()
+            .collect::<Vec<_>>();
 
         let test_default_templates = vec![
             ChatTemplateTestItem {
diff --git a/router/src/server.rs b/router/src/server.rs
index 2ca49bc7..2bfb9aa8 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -46,6 +46,7 @@ use tracing::{info_span, instrument, Instrument};
 use utoipa::OpenApi;
 use utoipa_swagger_ui::SwaggerUi;
 use thiserror::Error;
+use crate::infer::v2::SchedulerV2;
 
 /// Generate tokens if `stream == false` or a stream of token if `stream == true`
 #[utoipa::path(
@@ -1472,8 +1473,10 @@
     )]
     struct ApiDoc;
 
+    // Create state
+
     // Open connection, get model info and warmup
-    let (infer, health_ext, shard_info, max_batch_total_tokens) = {
+    let (scheduler, health_ext, shard_info, max_batch_total_tokens) = {
         // Helper function to check both v2 and v3
         let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option<u32>| {
             match max_supported_batch_total_tokens {
@@ -1505,18 +1508,7 @@
             }
         };
 
-        // Create state
-        let validation = Validation::new(
-            validation_workers,
-            tokenizer,
-            config,
-            max_best_of,
-            max_stop_sequences,
-            max_top_n_tokens,
-            max_input_tokens,
-            max_total_tokens,
-            grammar_support,
-        );
+
         let generation_health = Arc::new(AtomicBool::new(false));
 
         // Try to open a v3 client
@@ -1546,26 +1538,36 @@
         tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}");
 
         let health_ext = HealthCheck::new(Arc::new(sharded_client.clone()), generation_health.clone());
-        let infer = Infer::new(
+        let scheduler = SchedulerV2::new(
             sharded_client,
-            validation,
             waiting_served_ratio,
             max_batch_prefill_tokens,
             max_batch_total_tokens,
             max_waiting_tokens,
             max_batch_size,
-            max_concurrent_requests,
             shard_info.requires_padding,
             shard_info.window_size,
             shard_info.speculate,
             generation_health,
-            tokenizer_config,
-            processor_config,
         );
 
-        (infer, health_ext, shard_info, max_batch_total_tokens)
+        (scheduler, health_ext, shard_info, max_batch_total_tokens)
     };
 
+    let validation = Validation::new(
+        validation_workers,
+        tokenizer,
+        config,
+        max_best_of,
+        max_stop_sequences,
+        max_top_n_tokens,
+        max_input_tokens,
+        max_total_tokens,
+        grammar_support,
+    );
+
+    let infer = Infer::new(Arc::new(scheduler), validation, max_concurrent_requests, tokenizer_config, processor_config);
+
     // Duration buckets
     let duration_matcher = Matcher::Suffix(String::from("duration"));
     let n_duration_buckets = 35;