From b562680be44220c9ea9b4d1d99fe9295890be711 Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Wed, 26 Jun 2024 13:13:32 +0200
Subject: [PATCH] wip

---
 Cargo.lock                       |   1 +
 router/Cargo.toml                |   1 +
 router/src/infer/mod.rs          |  74 ++---
 router/src/infer/v2/mod.rs       |   2 +-
 router/src/infer/v2/scheduler.rs |   8 +-
 router/src/infer/v3/backend.rs   | 500 +++++++++++++++++++++++++++++++
 router/src/infer/v3/mod.rs       | 140 ++++++++-
 router/src/lib.rs                |  19 +-
 router/src/server.rs             | 166 ++--------
 9 files changed, 687 insertions(+), 224 deletions(-)
 create mode 100644 router/src/infer/v3/backend.rs

diff --git a/Cargo.lock b/Cargo.lock
index b9bd7363..0ec85025 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3602,6 +3602,7 @@ name = "text-generation-router"
 version = "2.0.5-dev0"
 dependencies = [
  "async-stream",
+ "async-trait",
  "axum 0.7.5",
  "axum-tracing-opentelemetry",
  "base64 0.22.1",

diff --git a/router/Cargo.toml b/router/Cargo.toml
index 5bf4c00c..bf96ab91 100644
--- a/router/Cargo.toml
+++ b/router/Cargo.toml
@@ -15,6 +15,7 @@ name = "text-generation-router"
 path = "src/main.rs"

 [dependencies]
+async-trait = "0.1.74"
 async-stream = "0.3.5"
 axum = { version = "0.7", features = ["json"] }
 axum-tracing-opentelemetry = "0.16"

diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs
index 8eadcabb..a1aedadf 100644
--- a/router/src/infer/mod.rs
+++ b/router/src/infer/mod.rs
@@ -1,24 +1,18 @@
-mod health;
-pub(crate) mod v2;
+// pub(crate) mod v2;
 pub(crate) mod v3;

 mod chat_template;
-mod tool_grammar;
-
-pub(crate) use health::HealthCheck;
+pub mod tool_grammar;

 use crate::validation::{ValidGenerateRequest, Validation, ValidationError};
 use crate::{
-    ChatTemplateInputs, ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig,
-    HubTokenizerConfig, Message, MessageChunk, PrefillToken, Text, TextMessage, Token,
+    ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig,
+    HubTokenizerConfig, Message, PrefillToken, Token,
 };
-use crate::{FunctionRef, FunctionsMap, GrammarType, Properties, Tool, ToolType, Tools};
+use crate::GrammarType;
 use futures::future::try_join_all;
-use minijinja::{Environment, ErrorKind, Template};
-use minijinja_contrib::pycompat;
-
-use serde_json::{json, Map, Value};
-use std::collections::HashMap;
+use minijinja::ErrorKind;
 use std::sync::Arc;
+use std::sync::atomic::{AtomicBool, Ordering};
 use thiserror::Error;
 use tokio::sync::{OwnedSemaphorePermit, Semaphore, TryAcquireError};
 use tokio::time::Instant;
@@ -26,13 +20,16 @@
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tokio_stream::StreamExt;
 use tracing::instrument;
 use chat_template::ChatTemplate;
+use async_trait::async_trait;

-pub(crate) trait Scheduler {
+#[async_trait]
+pub(crate) trait Backend {
     fn schedule(
         &self,
         request: ValidGenerateRequest,
-        permit: OwnedSemaphorePermit,
-    ) -> Result<GenerateStreamResponse, InferError>;
+    ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError>;
+
+    async fn health(&self, current_health: bool) -> bool;
 }

 /// Inference struct
@@ -90,15 +87,7 @@ impl Infer {
     #[instrument(skip_all)]
     pub(crate) async fn generate_stream<'a>(
         &'a self,
-        request: GenerateRequest,
-    ) -> Result<
-        (
-            OwnedSemaphorePermit,
-            u32, // input_length
-            impl Stream<Item = Result<InferStreamResponse, InferError>> + 'a,
-        ),
-        InferError,
-    > {
+        request: GenerateRequest,
+    ) -> Result<GenerateStreamResponse, InferError> {
         // Limit concurrent requests by acquiring a permit from the semaphore
         let permit = self
             .clone()
             .limit_concurrent_requests
             .try_acquire_owned()
             .map_err(|err| {
                 metrics::increment_counter!("tgi_request_failure", "err" => "overloaded");
                 tracing::error!("{err}");
                 err
             })?;

         // Validate request
         let valid_request = self.validation.validate(request).await.map_err(|err| {
             metrics::increment_counter!("tgi_request_failure", "err" => "validation");
             tracing::error!("{err}");
             err
         })?;

         let input_length = valid_request.input_length;
-        let mut generation_stream = self
+        let generation_stream = self
             .backend
-            .schedule(valid_request)
-            .map_err(InferError::Backend)?;
+            .schedule(valid_request)?;

-        let stream = stream! {
-            while let Some(generation) = generation_stream.next().await {
-                self.backend_health.store(generation.is_ok(), Ordering::SeqCst);
-
-                yield generation.map(|generation| match generation {
-                    types::TokenStreamResponse::Prefill(prefill_tokens) => InferStreamResponse::Prefill(
-                        prefill_tokens.into_iter().map(PrefillToken::from).collect()
-                    ),
-                    types::TokenStreamResponse::Intermediate { token, top_tokens } => InferStreamResponse::Intermediate {
-                        token: Token::from(token),
-                        top_tokens: top_tokens.into_iter().map(Token::from).collect(),
-                    },
-                    types::TokenStreamResponse::End { token, top_tokens, generated_text, start, queued } => InferStreamResponse::End {
-                        token: Token::from(token),
-                        top_tokens: top_tokens.into_iter().map(Token::from).collect(),
-                        generated_text,
-                        start,
-                        queued,
-                    }
-                }).map_err(InferError::GenerationError)
-            }
-        };
-
-        Ok((permit, input_length, stream))
+        Ok((permit, input_length, generation_stream))
     }

     /// Tokenize the input
@@ -363,10 +328,8 @@ pub(crate) struct InferResponse {

 #[derive(Debug, Error)]
 pub enum InferError {
-    #[error("Request failed during scheduling: {0}")]
-    Backend(BackendError),
     #[error("Request failed during generation: {0}")]
-    GenerationError(BackendError),
+    GenerationError(String),
     #[error("Model is overloaded")]
     Overloaded(#[from] TryAcquireError),
     #[error("Input validation error: {0}")]
@@ -382,7 +345,6 @@ pub enum InferError {
 impl InferError {
     pub(crate) fn error_type(&self) -> &str {
         match self {
-            InferError::Backend(_) => "backend",
             InferError::GenerationError(_) => "generation",
             InferError::Overloaded(_) => "overloaded",
             InferError::ValidationError(_) => "validation",
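Note on the new trait: `Backend` is now the single seam between the HTTP layer and the model server, so any scheduler only has to produce a channel-backed stream. A minimal sketch of an implementor, assuming the types from `router/src/infer/mod.rs` are in scope — the `EchoBackend` name and its canned error are illustrative, not part of this patch:

```rust
use async_trait::async_trait;
use tokio::sync::mpsc;
use tokio_stream::wrappers::UnboundedReceiverStream;

// Hypothetical implementor, for illustration only.
struct EchoBackend;

#[async_trait]
impl Backend for EchoBackend {
    fn schedule(
        &self,
        _request: ValidGenerateRequest,
    ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
        let (tx, rx) = mpsc::unbounded_channel();
        // A real backend hands `tx` to a batching task; this stub terminates
        // the stream immediately with a generation error.
        let _ = tx.send(Err(InferError::GenerationError("echo backend".to_string())));
        Ok(UnboundedReceiverStream::new(rx))
    }

    async fn health(&self, _current_health: bool) -> bool {
        true
    }
}
```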
diff --git a/router/src/infer/v2/mod.rs b/router/src/infer/v2/mod.rs
index 8b4f6bab..6a91a433 100644
--- a/router/src/infer/v2/mod.rs
+++ b/router/src/infer/v2/mod.rs
@@ -1,4 +1,4 @@
 mod queue;
 mod scheduler;

-pub(crate) use scheduler::SchedulerV2;
+pub(crate) use scheduler::BackendV2;

diff --git a/router/src/infer/v2/scheduler.rs b/router/src/infer/v2/scheduler.rs
index ba6f520d..6123c7ac 100644
--- a/router/src/infer/v2/scheduler.rs
+++ b/router/src/infer/v2/scheduler.rs
@@ -1,7 +1,7 @@
 /// Batching and inference logic
 use crate::infer::v2::queue::{Entry, Queue};
 use crate::infer::{
-    GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse, Scheduler,
+    GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse, Backend,
 };
 use crate::validation::ValidGenerateRequest;
 use crate::{FinishReason, PrefillToken, Token};
@@ -18,14 +18,14 @@ use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tracing::{info_span, instrument, Instrument, Span};

-pub(crate) struct SchedulerV2 {
+pub(crate) struct BackendV2 {
     /// Request queue
     queue: Queue,
     /// Notify batcher on queue appends
     batching_task_notifier: Arc<Notify>,
 }

-impl SchedulerV2 {
+impl BackendV2 {
     #[allow(clippy::too_many_arguments)]
     pub(crate) fn new(
         client: ShardedClient,
@@ -62,7 +62,7 @@ impl SchedulerV2 {
     }
 }

-impl Scheduler for SchedulerV2 {
+impl Backend for BackendV2 {
     #[instrument(skip_all)]
     fn schedule(
         &self,
diff --git a/router/src/infer/v3/backend.rs b/router/src/infer/v3/backend.rs
new file mode 100644
index 00000000..6cba6e52
--- /dev/null
+++ b/router/src/infer/v3/backend.rs
@@ -0,0 +1,500 @@
+/// Batching and inference logic
+use crate::infer::v3::queue::{Entry, Queue};
+use crate::infer::{GeneratedText, InferError, InferStreamResponse, Backend};
+use crate::validation::ValidGenerateRequest;
+use crate::{FinishReason, PrefillToken, Token};
+use nohash_hasher::IntMap;
+use std::sync::Arc;
+use text_generation_client::v3::{Batch, CachedBatch, Generation, ShardedClient};
+use text_generation_client::{ClientError, Health};
+use tokio::sync::mpsc::error::SendError;
+use tokio::sync::{mpsc, Notify};
+use tokio::time::Instant;
+use tokio_stream::wrappers::UnboundedReceiverStream;
+use tracing::{info_span, instrument, Instrument, Span};
+use async_trait::async_trait;
+
+pub(crate) struct BackendV3 {
+    /// Request queue
+    queue: Queue,
+    /// Notify batcher on queue appends
+    batching_task_notifier: Arc<Notify>,
+    /// Client clone, used for health checks to skip the queue
+    client: ShardedClient,
+}
+
+impl BackendV3 {
+    #[allow(clippy::too_many_arguments)]
+    pub(crate) fn new(
+        client: ShardedClient,
+        waiting_served_ratio: f32,
+        max_batch_prefill_tokens: u32,
+        max_batch_total_tokens: u32,
+        max_waiting_tokens: usize,
+        max_batch_size: Option<usize>,
+        requires_padding: bool,
+        window_size: Option<u32>,
+        speculate: u32,
+    ) -> Self {
+        let queue = Queue::new(
+            requires_padding,
+            16,
+            window_size,
+            speculate,
+            max_batch_total_tokens,
+        );
+        let batching_task_notifier = Arc::new(Notify::new());
+
+        // Spawn batching background task that contains all the inference logic
+        tokio::spawn(batching_task(
+            client.clone(),
+            waiting_served_ratio,
+            max_batch_prefill_tokens,
+            max_batch_total_tokens,
+            max_waiting_tokens,
+            max_batch_size,
+            queue.clone(),
+            batching_task_notifier.clone(),
+        ));
+
+        Self {
+            queue,
+            batching_task_notifier,
+            client,
+        }
+    }
+}
+
+#[async_trait]
+impl Backend for BackendV3 {
+    #[instrument(skip_all)]
+    fn schedule(
+        &self,
+        request: ValidGenerateRequest,
+    ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
+        // MPSC channel to communicate with the background batching task
+        let (response_tx, response_rx) = mpsc::unbounded_channel();
+        let input_length = request.input_length;
+
+        // Append the request to the queue
+        self.queue.append(Entry {
+            request,
+            response_tx,
+            span: Span::current(),
+            temp_span: None,
+            queue_time: Instant::now(),
+            batch_time: None,
+            block_allocation: None,
+        });
+
+        // Notify the background task that we have a new entry in the queue that needs
+        // to be batched
+        self.batching_task_notifier.notify_one();
+
+        // Return stream
+        Ok(UnboundedReceiverStream::new(response_rx))
+    }
+
+    async fn health(&self, current_health: bool) -> bool {
+        if current_health {
+            // Generation is healthy, we only check that the shards can allocate on device
+            self.client.device_health().await
+        } else {
+            self.client.model_health().await
+        }
+        .is_ok()
+    }
+}
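For callers, `schedule` now hands back a plain `UnboundedReceiverStream`. A sketch of draining it outside the HTTP layer — the `collect_text` helper is hypothetical, and it assumes the `Backend` trait and the response types from `crate::infer` are in scope:

```rust
use tokio_stream::StreamExt;

// Hypothetical helper: drain a scheduled request down to its final string.
async fn collect_text(
    backend: &BackendV3,
    request: ValidGenerateRequest,
) -> Result<String, InferError> {
    let mut stream = backend.schedule(request)?;
    let mut text = String::new();
    while let Some(response) = stream.next().await {
        match response? {
            // Prompt tokens: nothing to append.
            InferStreamResponse::Prefill(_) => {}
            InferStreamResponse::Intermediate { token, .. } => text.push_str(&token.text),
            // The final message also carries GeneratedText (finish reason, seed, ...).
            InferStreamResponse::End { token, .. } => text.push_str(&token.text),
        }
    }
    Ok(text)
}
```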
+
+/// Batching logic
+/// Will be launched in a background Tokio task
+///
+/// Batches requests and sends them to the inference server
+#[allow(clippy::too_many_arguments)]
+pub(crate) async fn batching_task(
+    mut client: ShardedClient,
+    waiting_served_ratio: f32,
+    max_batch_prefill_tokens: u32,
+    max_batch_total_tokens: u32,
+    max_waiting_tokens: usize,
+    max_batch_size: Option<usize>,
+    queue: Queue,
+    notifier: Arc<Notify>,
+) {
+    // Infinite loop
+    loop {
+        // Wait for a notification from the Infer struct
+        notifier.notified().await;
+
+        // Get the next batch from the queue
+        // This batch might be smaller than the maximum batch size if there are not enough requests
+        // waiting in the queue
+        while let Some((mut entries, batch, span)) = queue
+            .next_batch(
+                None,
+                max_batch_size,
+                max_batch_prefill_tokens,
+                max_batch_total_tokens,
+            )
+            .await
+        {
+            let mut cached_batch = prefill(&mut client, batch, &mut entries)
+                .instrument(span)
+                .await;
+            let mut waiting_tokens = 1;
+
+            // We loop until we do not receive any cached batch from the inference server (== until
+            // all requests have met their stopping criteria)
+            while let Some(batch) = cached_batch {
+                // Get current batch info
+                let batch_size = batch.size;
+                let batch_max_tokens = batch.max_tokens;
+                let mut batches = vec![batch];
+                metrics::gauge!("tgi_batch_current_size", batch_size as f64);
+                metrics::gauge!("tgi_batch_current_max_tokens", batch_max_tokens as f64);
+
+                let min_size = if waiting_tokens >= max_waiting_tokens {
+                    // If we didn't onboard any new requests since >= max_waiting_tokens, we try
+                    // to add a new batch even though its size might be small
+                    None
+                } else {
+                    // Minimum batch size
+                    Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
+                };
+
+                let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
+                let max_size = max_batch_size.map(|max_size| max_size - batch_size as usize);
+
+                // Try to get a new batch
+                if let Some((mut new_entries, new_batch, span)) = queue
+                    .next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget)
+                    .await
+                {
+                    // Tracking metrics
+                    if min_size.is_some() {
+                        metrics::increment_counter!("tgi_batch_concat", "reason" => "backpressure");
+                    } else {
+                        metrics::increment_counter!("tgi_batch_concat", "reason" => "wait_exceeded");
+                    }
+
+                    entries.iter_mut().for_each(|(_, entry)| {
+                        // Create a new span to add the info that this entry is waiting
+                        // because a new batch is being computed
+                        let entry_waiting_span = info_span!(parent: &entry.span, "waiting");
+                        // Add relationships
+                        span.follows_from(&entry_waiting_span);
+                        entry_waiting_span.follows_from(&span);
+                        // Update entry
+                        entry.temp_span = Some(entry_waiting_span);
+                    });
+
+                    // Generate one token for this new batch to have the attention past in cache
+                    let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
+                        .instrument(span)
+                        .await;
+                    // Reset waiting counter
+                    waiting_tokens = 1;
+                    // Extend current batch with the new batch
+                    if let Some(new_cached_batch) = new_cached_batch {
+                        entries.extend(new_entries);
+                        batches.push(new_cached_batch);
+                    }
+                }
+
+                // Create span for this batch to add context to inference calls
+                let next_batch_size = entries.len();
+                let next_batch_span =
+                    info_span!(parent: None, "batch", batch_size = next_batch_size);
+                entries.iter_mut().for_each(|(_, entry)| {
+                    // Create a new span to link the batch back to this entry
+                    let entry_batch_span = info_span!(parent: &entry.span, "infer");
+                    // Add relationships
+                    next_batch_span.follows_from(&entry_batch_span);
+                    entry_batch_span.follows_from(&next_batch_span);
+                    // Update entry
+                    entry.temp_span = Some(entry_batch_span);
+                });
+
+                cached_batch = decode(&mut client, batches, &mut entries)
+                    .instrument(next_batch_span)
+                    .await;
+                waiting_tokens += 1;
+            }
+            metrics::gauge!("tgi_batch_current_size", 0.0);
+            metrics::gauge!("tgi_batch_current_max_tokens", 0.0);
+        }
+    }
+}
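The concatenation arithmetic above is compact; with concrete numbers it is easier to see when a waiting batch actually gets prefilled. The values below are illustrative, not defaults from this patch:

```rust
fn main() {
    // Illustrative scenario: a running batch of 8 requests holding 12_000 of a
    // 32_000-token budget, with waiting_served_ratio = 1.2.
    let batch_size: u32 = 8;
    let batch_max_tokens: u32 = 12_000;
    let max_batch_total_tokens: u32 = 32_000;
    let waiting_served_ratio: f32 = 1.2;

    // Until max_waiting_tokens is exceeded, a new batch is only prefilled if
    // at least floor(8 * 1.2) = 9 requests are queued...
    let min_size = (batch_size as f32 * waiting_served_ratio).floor() as usize;
    // ...and it may only claim the tokens the running batch is not using.
    let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);

    assert_eq!(min_size, 9);
    assert_eq!(token_budget, 20_000);
}
```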
+
+#[instrument(skip_all)]
+async fn prefill(
+    client: &mut ShardedClient,
+    batch: Batch,
+    entries: &mut IntMap<u64, Entry>,
+) -> Option<CachedBatch> {
+    let start_time = Instant::now();
+    let batch_id = batch.id;
+    metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill");
+
+    match client.prefill(batch).await {
+        Ok((generations, next_batch, timings)) => {
+            let start_filtering_time = Instant::now();
+            // Send generated tokens and filter stopped entries
+            filter_send_generations(generations, entries);
+
+            // Filter next batch and remove requests that were stopped
+            let next_batch = filter_batch(client, next_batch, entries).await;
+
+            metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "prefill");
+            metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "prefill");
+            metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "prefill");
+            metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill");
+            metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill");
+            next_batch
+        }
+        // If we have an error, we discard the whole batch
+        Err(err) => {
+            let _ = client.clear_cache(Some(batch_id)).await;
+            send_errors(err, entries);
+            metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill");
+            None
+        }
+    }
+}
+
+#[instrument(skip_all)]
+async fn decode(
+    client: &mut ShardedClient,
+    batches: Vec<CachedBatch>,
+    entries: &mut IntMap<u64, Entry>,
+) -> Option<CachedBatch> {
+    let start_time = Instant::now();
+    let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
+    metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode");
+
+    match client.decode(batches).await {
+        Ok((generations, next_batch, timings)) => {
+            let start_filtering_time = Instant::now();
+            // Send generated tokens and filter stopped entries
+            filter_send_generations(generations, entries);
+
+            // Filter next batch and remove requests that were stopped
+            let next_batch = filter_batch(client, next_batch, entries).await;
+
+            if let Some(concat_duration) = timings.concat {
+                metrics::histogram!("tgi_batch_concat_duration", concat_duration.as_secs_f64(), "method" => "decode");
+            }
+            metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "decode");
+            metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "decode");
+            metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "decode");
+            metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode");
+            metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode");
+            next_batch
+        }
+        // If we have an error, we discard the whole batch
+        Err(err) => {
+            for id in batch_ids {
+                let _ = client.clear_cache(Some(id)).await;
+            }
+            send_errors(err, entries);
+            metrics::increment_counter!("tgi_batch_inference_failure", "method" => "decode");
+            None
+        }
+    }
+}
+
+/// Filter a `batch` and remove all requests not present in `entries`
+#[instrument(skip_all)]
+async fn filter_batch(
+    client: &mut ShardedClient,
+    next_batch: Option<CachedBatch>,
+    entries: &IntMap<u64, Entry>,
+) -> Option<CachedBatch> {
+    let mut batch = next_batch?;
+
+    // No need to filter
+    if batch.size as usize == entries.len() {
+        return Some(batch);
+    }
+
+    let id = batch.id;
+
+    // Retain only requests that are still in entries
+    batch.request_ids.retain(|id| entries.contains_key(id));
+
+    if batch.request_ids.is_empty() {
+        // All requests have been filtered out
+        // Next batch is now empty
+        // Clear it from the Python shards cache
+        // We unwrap here as we need to panic since we cannot recover if this method fails
+        client.clear_cache(Some(id)).await.unwrap();
+        None
+    } else {
+        // Filter Python shard cache
+        // We unwrap here as we need to panic since we cannot recover if this method fails
+        client.filter_batch(id, batch.request_ids).await.unwrap()
+    }
+}
+
+/// Send one or multiple `InferStreamResponse` to Infer for all `entries`
+/// and filter entries
+#[instrument(skip_all)]
+fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
+    generations.into_iter().for_each(|generation| {
+        let id = generation.request_id;
+        // Get entry
+        // We can `expect` here as the request id should always be in the entries
+        let entry = entries
+            .get(&id)
+            .expect("ID not found in entries. This is a bug.");
+
+        // Create and enter a span to link this function back to the entry
+        let _span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered();
+        // Send generation responses back to the infer task
+        // If we receive an error from the channel, it means that the client dropped the
+        // request and we need to stop generating, hence the unwrap_or(true)
+        let stopped = send_responses(generation, entry).map_err(|err| {
+            tracing::error!("Entry response channel error.");
+            metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
+            err
+        }).unwrap_or(true);
+        if stopped {
+            entries.remove(&id).expect("ID not found in entries. This is a bug.");
+        }
+    });
+}
+
+/// Send responses through the `entry` response channel
+fn send_responses(
+    generation: Generation,
+    entry: &Entry,
+) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> {
+    // Return directly if the channel is disconnected
+    if entry.response_tx.is_closed() {
+        metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
+        return Ok(true);
+    }
+
+    let mut stopped = false;
+
+    if let Some(prefill_tokens) = generation.prefill_tokens {
+        // Create Token objects
+        // We do that here instead of in the Python code as Rust for loops are faster
+        let prefill_tokens = prefill_tokens
+            .ids
+            .into_iter()
+            .zip(prefill_tokens.logprobs)
+            .zip(prefill_tokens.texts)
+            .map(|((id, logprob), text)| PrefillToken { id, text, logprob })
+            .collect();
+
+        // Send message
+        entry
+            .response_tx
+            .send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?;
+    }
+
+    // Create last Token
+    let tokens_ = generation.tokens.expect("Non empty tokens in generation");
+    let n = tokens_.ids.len();
+    metrics::histogram!("tgi_request_skipped_tokens", (n - 1) as f64);
+    let mut iterator = tokens_
+        .ids
+        .into_iter()
+        .zip(tokens_.logprobs)
+        .zip(tokens_.texts)
+        .zip(tokens_.is_special)
+        .enumerate()
+        .peekable();
+    while let Some((i, (((id, logprob), text), special))) = iterator.next() {
+        let token = Token {
+            id,
+            text,
+            logprob,
+            special,
+        };
+        let top_tokens = if let Some(top_tokens_) = generation.top_tokens.get(i) {
+            top_tokens_
+                .ids
+                .iter()
+                .zip(top_tokens_.logprobs.iter())
+                .zip(top_tokens_.texts.iter())
+                .zip(top_tokens_.is_special.iter())
+                .map(|(((&id, &logprob), text), &special)| Token {
+                    id,
+                    text: text.to_string(),
+                    logprob,
+                    special,
+                })
+                .collect()
+        } else {
+            vec![]
+        };
+        match (&generation.generated_text, iterator.peek()) {
+            (Some(generated_text), None) => {
+                // Generation has ended
+                stopped = true;
+                // Send message
+                entry.response_tx.send(Ok(InferStreamResponse::End {
+                    token,
+                    top_tokens,
+                    generated_text: GeneratedText::from(generated_text.clone()),
+                    queued: entry.queue_time,
+                    start: entry.batch_time.unwrap(),
+                }))?;
+            }
+            _ => {
+                // Send message
+                entry
+                    .response_tx
+                    .send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?;
+            }
+        }
+    }
+
+    Ok(stopped)
+}
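The channel-closed contract above is subtle: a failed `send` is the only signal that the HTTP client went away, which is why the error path maps to "stopped". The same pattern in isolation, with hypothetical names rather than the patch's own types:

```rust
use tokio::sync::mpsc;

// `send` only fails when the receiver was dropped; mapping that failure to
// `true` ("stopped") mirrors the `unwrap_or(true)` used in the router.
fn forward(tx: &mpsc::UnboundedSender<u32>, token: u32) -> bool {
    tx.send(token).map(|_| false).unwrap_or(true)
}

fn main() {
    let (tx, rx) = mpsc::unbounded_channel();
    assert!(!forward(&tx, 1)); // receiver alive: keep generating
    drop(rx);
    assert!(forward(&tx, 2)); // receiver gone: stop this entry
}
```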
+
+/// Send errors to Infer for all `entries`
+#[instrument(skip_all)]
+fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
+    entries.drain().for_each(|(_, entry)| {
+        // Create and enter a span to link this function back to the entry
+        let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
+        let err = InferError::GenerationError(error.to_string());
+        metrics::increment_counter!("tgi_request_failure", "err" => "generation");
+        tracing::error!("{err}");
+
+        // unwrap_or is valid here as we don't care if the receiver is gone.
+        entry.response_tx.send(Err(err)).unwrap_or(());
+    });
+}
+
+impl From<text_generation_client::v3::GeneratedText> for GeneratedText {
+    fn from(value: text_generation_client::v3::GeneratedText) -> Self {
+        let v3_finish_reason =
+            text_generation_client::v3::FinishReason::try_from(value.finish_reason).unwrap();
+        let finish_reason = match v3_finish_reason {
+            text_generation_client::v3::FinishReason::Length => FinishReason::Length,
+            text_generation_client::v3::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
+            text_generation_client::v3::FinishReason::StopSequence => FinishReason::StopSequence,
+        };
+
+        Self {
+            text: value.text,
+            generated_tokens: value.generated_tokens,
+            finish_reason,
+            seed: value.seed,
+        }
+    }
+}

diff --git a/router/src/infer/v3/mod.rs b/router/src/infer/v3/mod.rs
index f9effab8..0f7f4fdb 100644
--- a/router/src/infer/v3/mod.rs
+++ b/router/src/infer/v3/mod.rs
@@ -1,5 +1,141 @@
 mod block_allocator;
 mod queue;
-mod scheduler;
+mod backend;

-pub(crate) use scheduler::SchedulerV3;
+use futures_util::TryFutureExt;
+use serde::Serialize;
+use thiserror::Error;
+use utoipa::ToSchema;
+pub(crate) use backend::BackendV3;
+use text_generation_client::ClientError;
+use text_generation_client::v3::ShardedClient;
+
+#[derive(Clone, Debug, Serialize, ToSchema)]
+pub struct BackendInfo {
+    /// Mandatory
+    #[schema(example = "cuda")]
+    pub model_device_type: String,
+    #[schema(example = "torch.float16")]
+    pub model_dtype: String,
+
+    /// Backend parameters
+    #[schema(example = "1")]
+    pub speculate: usize,
+    #[schema(example = "1.2")]
+    pub waiting_served_ratio: f32,
+    #[schema(example = "32000")]
+    pub max_batch_total_tokens: u32,
+    #[schema(example = "20")]
+    pub max_waiting_tokens: usize,
+    #[schema(nullable = true, example = "null")]
+    pub max_batch_size: Option<usize>,
+}
+
+pub async fn connect_backend(
+    max_input_tokens: usize,
+    max_total_tokens: usize,
+    master_shard_uds_path: String,
+    waiting_served_ratio: f32,
+    max_batch_prefill_tokens: u32,
+    max_batch_total_tokens: Option<u32>,
+    max_waiting_tokens: usize,
+    max_batch_size: Option<usize>,
+) -> Result<(BackendV3, BackendInfo), V3Error> {
+    // Helper function
+    let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option<u32>| {
+        match max_supported_batch_total_tokens {
+            // Older models do not support automatic max-batch-total-tokens
+            None => {
+                let max_batch_total_tokens = max_batch_total_tokens
+                    .unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)));
+                tracing::warn!("Model does not support automatic max batch total tokens");
+                Ok(max_batch_total_tokens)
+            }
+            // Flash attention models return their max supported total tokens
+            Some(max_supported_batch_total_tokens) => {
+                // Warn if the user set their own max-batch-total-tokens, as we will ignore it
+                if max_batch_total_tokens.is_some() {
+                    tracing::warn!(
+                        "`--max-batch-total-tokens` is deprecated for Flash \
+                        Attention models."
+                    );
+                    tracing::warn!(
+                        "Inferred max batch total tokens: {max_supported_batch_total_tokens}"
+                    );
+                }
+                if max_total_tokens as u32 > max_supported_batch_total_tokens {
+                    return Err(V3Error::NotEnoughMemory(max_total_tokens));
+                }
+
+                Ok(max_supported_batch_total_tokens)
+            }
+        }
+    };
+
+    let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
+        .await
+        .map_err(V3Error::Connection)?;
+
+    // server is running on v3
+    // Clear the cache; useful if the webserver rebooted
+    sharded_client
+        .clear_cache(None)
+        .await
+        .map_err(V3Error::Cache)?;
+    // Get info from the shard
+    let shard_info = sharded_client.info().await.map_err(V3Error::Info)?;
+
+    // Warmup model
+    tracing::info!("Warming up model");
+    let max_batch_total_tokens = check_max_batch_total_tokens(
+        sharded_client
+            .warmup(
+                max_input_tokens as u32,
+                max_batch_prefill_tokens,
+                max_total_tokens as u32,
+                max_batch_size,
+            )
+            .await
+            .map_err(V3Error::Warmup)?,
+    )?;
+    tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}");
+
+    let backend_info = BackendInfo {
+        waiting_served_ratio,
+        max_batch_total_tokens,
+        max_waiting_tokens,
+        max_batch_size,
+        model_device_type: shard_info.device_type.clone(),
+        model_dtype: shard_info.dtype.clone(),
+        speculate: shard_info.speculate as usize,
+    };
+
+    let backend = BackendV3::new(
+        sharded_client,
+        waiting_served_ratio,
+        max_batch_prefill_tokens,
+        max_batch_total_tokens,
+        max_waiting_tokens,
+        max_batch_size,
+        shard_info.requires_padding,
+        shard_info.window_size,
+        shard_info.speculate,
+    );
+
+    tracing::info!("Using backend V3");
+
+    Ok((backend, backend_info))
+}
+
+#[derive(Debug, Error)]
+pub(crate) enum V3Error {
+    #[error("Unable to clear the Python model shards cache: {0}")]
+    Cache(ClientError),
+    #[error("Unable to connect to the Python model shards: {0}")]
+    Connection(ClientError),
+    #[error("Unable to get the Python model shards info: {0}")]
+    Info(ClientError),
+    #[error("Unable to warmup the Python model shards: {0}")]
+    Warmup(ClientError),
+    #[error("Not enough memory to handle `max_total_tokens={0}`")]
+    NotEnoughMemory(usize),
+}
\ No newline at end of file
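`connect_backend` folds the connect/clear-cache/warmup sequence that `server.rs` used to carry inline (removed below) into one entry point. A sketch of calling it from a startup path — every argument value here is a placeholder, and error handling is collapsed:

```rust
// Hypothetical startup wiring; the literal values are illustrative only.
async fn start() -> Result<(), V3Error> {
    let (backend, info) = connect_backend(
        1024,                                         // max_input_tokens
        2048,                                         // max_total_tokens
        "/tmp/text-generation-server-0".to_string(),  // master_shard_uds_path
        1.2,                                          // waiting_served_ratio
        4096,                                         // max_batch_prefill_tokens
        None,                                         // max_batch_total_tokens (inferred at warmup)
        20,                                           // max_waiting_tokens
        None,                                         // max_batch_size
    )
    .await?;
    tracing::info!("dtype: {}, device: {}", info.model_dtype, info.model_device_type);
    // `backend` implements `Backend` and can be handed to `Infer::new`.
    let _ = backend;
    Ok(())
}
```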
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 126726c6..80ff23d5 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -135,12 +135,13 @@ pub struct Info {
     pub model_id: String,
     #[schema(nullable = true, example = "e985a63cdc139290c5f700ff1929f0b5942cced2")]
     pub model_sha: Option<String>,
-    #[schema(example = "torch.float16")]
-    pub model_dtype: String,
-    #[schema(example = "cuda")]
-    pub model_device_type: String,
+    // #[schema(example = "torch.float16")]
+    // pub model_dtype: String,
+    // #[schema(example = "cuda")]
+    // pub model_device_type: String,
     #[schema(nullable = true, example = "text-generation")]
     pub model_pipeline_tag: Option<String>,
+
     /// Router Parameters
     #[schema(example = "128")]
     pub max_concurrent_requests: usize,
@@ -152,18 +153,12 @@ pub struct Info {
     pub max_input_tokens: usize,
     #[schema(example = "2048")]
     pub max_total_tokens: usize,
-    #[schema(example = "1.2")]
-    pub waiting_served_ratio: f32,
-    #[schema(example = "32000")]
-    pub max_batch_total_tokens: u32,
-    #[schema(example = "20")]
-    pub max_waiting_tokens: usize,
-    #[schema(nullable = true, example = "null")]
-    pub max_batch_size: Option<usize>,
     #[schema(example = "2")]
     pub validation_workers: usize,
     #[schema(example = "32")]
     pub max_client_batch_size: usize,
+
+    /// Router Info
     #[schema(example = "text-generation-router")]
     pub router: &'static str,

diff --git a/router/src/server.rs b/router/src/server.rs
index 8e48f586..3b53fff4 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -1,9 +1,8 @@
 /// HTTP Server logic
 use crate::config::Config;
-use crate::infer::v2::SchedulerV2;
-use crate::infer::v3::SchedulerV3;
-use crate::infer::{HealthCheck, Scheduler};
-use crate::infer::{Infer, InferError, InferResponse, InferStreamResponse, ToolGrammar};
+use crate::infer::v3::{connect_backend, V3Error};
+use crate::infer::{Infer, InferError, InferResponse, InferStreamResponse, Backend};
+use crate::infer::tool_grammar::ToolGrammar;
 #[cfg(feature = "kserve")]
 use crate::kserve::{
     kerve_server_metadata, kserve_health_live, kserve_health_ready, kserve_model_infer,
@@ -1498,138 +1497,8 @@ pub async fn run(

     // Create state

     // Open connection, get model info and warmup
-    let (scheduler, health_ext, shard_info, max_batch_total_tokens): (
-        Arc<dyn Scheduler + Send + Sync>,
-        HealthCheck,
-        ShardInfo,
-        u32,
-    ) = {
-        // Helper function to check both v2 and v3
-        let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option<u32>| {
-            match max_supported_batch_total_tokens {
-                // Older models do not support automatic max-batch-total-tokens
-                None => {
-                    let max_batch_total_tokens = max_batch_total_tokens.unwrap_or(
-                        16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)),
-                    );
-                    tracing::warn!("Model does not support automatic max batch total tokens");
-                    Ok(max_batch_total_tokens)
-                }
-                // Flash attention models return their max supported total tokens
-                Some(max_supported_batch_total_tokens) => {
-                    // Warn if user added his own max-batch-total-tokens as we will ignore it
-                    if max_batch_total_tokens.is_some() {
-                        tracing::warn!(
-                            "`--max-batch-total-tokens` is deprecated for Flash \
-                            Attention models."
-                        );
-                        tracing::warn!(
-                            "Inferred max batch total tokens: {max_supported_batch_total_tokens}"
-                        );
-                    }
-                    if max_total_tokens as u32 > max_supported_batch_total_tokens {
-                        return Err(WebServerError::NotEnoughMemory(max_total_tokens));
-                    }
-
-                    Ok(max_supported_batch_total_tokens)
-                }
-            }
-        };
-
-        let generation_health = Arc::new(AtomicBool::new(false));
-
-        match v3::ShardedClient::connect_uds(master_shard_uds_path.clone()).await {
-            Ok(mut sharded_client) => {
-                // server is running on v3
-                // Clear the cache; useful if the webserver rebooted
-                sharded_client
-                    .clear_cache(None)
-                    .await
-                    .map_err(WebServerError::Cache)?;
-                // Get info from the shard
-                let shard_info = sharded_client.info().await.map_err(WebServerError::Info)?;
-
-                // Warmup model
-                tracing::info!("Warming up model");
-                let max_batch_total_tokens = check_max_batch_total_tokens(
-                    sharded_client
-                        .warmup(
-                            max_input_tokens as u32,
-                            max_batch_prefill_tokens,
-                            max_total_tokens as u32,
-                            max_batch_size,
-                        )
-                        .await
-                        .map_err(WebServerError::Warmup)?,
-                )?;
-
-                let health_ext =
-                    HealthCheck::new(Arc::new(sharded_client.clone()), generation_health.clone());
-                let scheduler = Arc::new(SchedulerV3::new(
-                    sharded_client,
-                    waiting_served_ratio,
-                    max_batch_prefill_tokens,
-                    max_batch_total_tokens,
-                    max_waiting_tokens,
-                    max_batch_size,
-                    shard_info.requires_padding,
-                    shard_info.window_size,
-                    shard_info.speculate,
-                    generation_health,
-                ));
-                tracing::info!("Using scheduler V3");
-
-                (scheduler, health_ext, shard_info, max_batch_total_tokens)
-            }
-            Err(_) => {
-                let mut sharded_client = v2::ShardedClient::connect_uds(master_shard_uds_path)
-                    .await
-                    .map_err(WebServerError::Connection)?;
-
-                // server is running on v2
-                // Clear the cache; useful if the webserver rebooted
-                sharded_client
-                    .clear_cache(None)
-                    .await
-                    .map_err(WebServerError::Cache)?;
-                // Get info from the shard
-                let shard_info = sharded_client.info().await.map_err(WebServerError::Info)?;
-
-                // Warmup model
-                tracing::info!("Warming up model");
-                let max_batch_total_tokens = check_max_batch_total_tokens(
-                    sharded_client
-                        .warmup(
-                            max_input_tokens as u32,
-                            max_batch_prefill_tokens,
-                            max_total_tokens as u32,
-                            max_batch_size,
-                        )
-                        .await
-                        .map_err(WebServerError::Warmup)?,
-                )?;
-
-                let health_ext =
-                    HealthCheck::new(Arc::new(sharded_client.clone()), generation_health.clone());
-                let scheduler = Arc::new(SchedulerV2::new(
-                    sharded_client,
-                    waiting_served_ratio,
-                    max_batch_prefill_tokens,
-                    max_batch_total_tokens,
-                    max_waiting_tokens,
-                    max_batch_size,
-                    shard_info.requires_padding,
-                    shard_info.window_size,
-                    shard_info.speculate,
-                    generation_health,
-                ));
-                tracing::info!("Using scheduler V2");
-
-                (scheduler, health_ext, shard_info, max_batch_total_tokens)
-            }
-        }
-    };
-    tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}");
+    let (backend, backend_info) = connect_backend(
+        max_input_tokens,
+        max_total_tokens,
+        master_shard_uds_path,
+        waiting_served_ratio,
+        max_batch_prefill_tokens,
+        max_batch_total_tokens,
+        max_waiting_tokens,
+        max_batch_size,
+    )
+    .await?;
+    // tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}");

     let validation = Validation::new(
         validation_workers,
@@ -1644,7 +1513,7 @@ pub async fn run(
     );

     let infer = Infer::new(
-        scheduler,
+        backend,
         validation,
         max_concurrent_requests,
         tokenizer_config,
@@ -1681,8 +1550,8 @@ pub async fn run(
     let batch_size_matcher = Matcher::Full(String::from("tgi_batch_next_size"));
     let batch_size_buckets: Vec<f64> = (0..1024).map(|x| (x + 1) as f64).collect();
     // Speculated tokens buckets
-    let skipped_matcher = Matcher::Full(String::from("tgi_request_skipped_tokens"));
-    let skipped_buckets: Vec<f64> = (0..shard_info.speculate + 1).map(|x| x as f64).collect();
+    // let skipped_matcher = Matcher::Full(String::from("tgi_request_skipped_tokens"));
+    // let skipped_buckets: Vec<f64> = (0..shard_info.speculate + 1).map(|x| x as f64).collect();

     // Prometheus handler
     let builder = PrometheusBuilder::new()
@@ -1695,9 +1564,9 @@ pub async fn run(
         .set_buckets_for_metric(max_new_tokens_matcher, &max_new_tokens_buckets)
         .unwrap()
         .set_buckets_for_metric(batch_size_matcher, &batch_size_buckets)
-        .unwrap()
-        .set_buckets_for_metric(skipped_matcher, &skipped_buckets)
-        .unwrap();
+        .unwrap();
+    // .set_buckets_for_metric(skipped_matcher, &skipped_buckets)
+    // .unwrap();
     let prom_handle = builder
         .install_recorder()
         .expect("failed to install metrics recorder");
@@ -1713,18 +1582,18 @@ pub async fn run(
     let info = Info {
         model_id: model_info.model_id,
         model_sha: model_info.sha,
-        model_dtype: shard_info.dtype,
-        model_device_type: shard_info.device_type,
+        // model_dtype: shard_info.dtype,
+        // model_device_type: shard_info.device_type,
         model_pipeline_tag: model_info.pipeline_tag,
         max_concurrent_requests,
         max_best_of,
         max_stop_sequences,
         max_input_tokens,
         max_total_tokens,
-        waiting_served_ratio,
-        max_batch_total_tokens,
-        max_waiting_tokens,
-        max_batch_size,
+        // waiting_served_ratio,
+        // max_batch_total_tokens,
+        // max_waiting_tokens,
+        // max_batch_size,
         validation_workers,
         max_client_batch_size,
         router: env!("CARGO_PKG_NAME"),
@@ -1858,7 +1727,6 @@ pub async fn run(
     // add layers after routes
     app = app
         .layer(Extension(info))
-        .layer(Extension(health_ext.clone()))
         .layer(Extension(compat_return_full_text))
         .layer(Extension(infer))
        .layer(Extension(compute_type))
@@ -1960,7 +1828,7 @@ impl From<InferError> for Event
 #[derive(Debug, Error)]
 pub enum WebServerError {
     #[error("Backend error: {0}")]
-    Backend(#[from] BackendError),
+    Backend(#[from] V3Error),
     #[error("Axum error: {0}")]
     Axum(#[from] axum::BoxError),
 }