diff --git a/backends/v2/src/backend.rs b/backends/v2/src/backend.rs
index f8a10ca2..086fc6dc 100644
--- a/backends/v2/src/backend.rs
+++ b/backends/v2/src/backend.rs
@@ -13,7 +13,7 @@ use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tracing::{info_span, instrument, Instrument, Span};
 
-pub struct BackendV3 {
+pub struct BackendV2 {
     /// Request queue
     queue: Queue,
     /// Notify batcher on queue appends
@@ -22,7 +22,7 @@ pub struct BackendV3 {
     client: ShardedClient,
 }
 
-impl BackendV3 {
+impl BackendV2 {
     #[allow(clippy::too_many_arguments)]
     pub(crate) fn new(
         client: ShardedClient,
@@ -35,24 +35,20 @@ impl BackendV3 {
         window_size: Option<u32>,
         speculate: u32,
     ) -> Self {
-        let prefix_caching =
-            std::env::var("USE_PREFIX_CACHING").expect("Expect prefix caching env var");
-        let prefix_caching = matches!(prefix_caching.as_str(), "true" | "1");
-        let attention: String = std::env::var("ATTENTION").expect("attention env var");
-
-        let attention: Attention = attention
-            .parse()
-            .unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"));
-        let block_size = attention.block_size();
-
-        let queue = Queue::new(
-            requires_padding,
-            block_size,
-            prefix_caching,
-            window_size,
-            speculate,
-            max_batch_total_tokens,
-        );
+        // Infer shared state
+        let attention = if let Ok(attention) = std::env::var("ATTENTION") {
+            attention
+                .parse()
+                .unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"))
+        } else {
+            Attention::Paged
+        };
+        let block_size = if attention == Attention::FlashDecoding {
+            256
+        } else {
+            16
+        };
+        let queue = Queue::new(requires_padding, block_size, window_size, speculate);
         let batching_task_notifier = Arc::new(Notify::new());
 
         // Spawn batching background task that contains all the inference logic
@@ -76,7 +72,7 @@ impl BackendV3 {
 }
 
 #[async_trait]
-impl Backend for BackendV3 {
+impl Backend for BackendV2 {
     #[instrument(skip_all)]
     fn schedule(
         &self,
@@ -93,7 +89,6 @@ impl Backend for BackendV3 {
             temp_span: None,
             queue_time: Instant::now(),
             batch_time: None,
-            block_allocation: None,
         });
 
         // Notify the background task that we have a new entry in the queue that needs
@@ -168,15 +163,12 @@ pub(crate) async fn batching_task(
                 None
             } else {
                 // Minimum batch size
-                // TODO: temporarily disable to avoid incorrect deallocation +
-                // reallocation when using prefix caching.
                 Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
             };
 
             let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
             let max_size =
                 max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize));
-
             // Try to get a new batch
             if let Some((mut new_entries, new_batch, span)) = queue
                 .next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget)
@@ -259,13 +251,13 @@ async fn prefill(
             // Filter next batch and remove requests that were stopped
             let next_batch = filter_batch(client, next_batch, entries).await;
 
-            metrics::histogram!("tgi_batch_forward_duration", "method" => "prefill")
+            metrics::histogram!("tgi_batch_forward_duration","method" => "prefill")
                 .record(timings.forward.as_secs_f64());
             metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill")
                 .record(timings.decode.as_secs_f64());
             metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill")
                 .record(start_filtering_time.elapsed().as_secs_f64());
-            metrics::histogram!("tgi_batch_inference_duration", "method" => "prefill")
+            metrics::histogram!("tgi_batch_inference_duration","method" => "prefill")
                 .record(start_time.elapsed().as_secs_f64());
             metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1);
             next_batch
@@ -497,8 +489,8 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
 
 impl From<crate::client::GeneratedText> for GeneratedText {
     fn from(value: crate::client::GeneratedText) -> Self {
-        let v3_finish_reason = crate::client::FinishReason::try_from(value.finish_reason).unwrap();
-        let finish_reason = match v3_finish_reason {
+        let v2_finish_reason = crate::client::FinishReason::try_from(value.finish_reason).unwrap();
+        let finish_reason = match v2_finish_reason {
             crate::client::FinishReason::Length => FinishReason::Length,
             crate::client::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
             crate::client::FinishReason::StopSequence => FinishReason::StopSequence,
diff --git a/backends/v2/src/block_allocator.rs b/backends/v2/src/block_allocator.rs
deleted file mode 100644
index 4fea172b..00000000
--- a/backends/v2/src/block_allocator.rs
+++ /dev/null
@@ -1,209 +0,0 @@
-use std::sync::Arc;
-use tokio::sync::{mpsc, oneshot};
-
-use crate::radix::RadixAllocator;
-
-#[derive(Debug, Clone)]
-pub struct BlockAllocation {
-    pub allocation_id: u64,
-    pub blocks: Vec<u32>,
-    pub slots: Vec<u32>,
-
-    /// Prefix that was cached and for which the KV does not have to
-    /// be recomputed.
- pub prefix_len: u32, - - pub(crate) block_allocator: Option, -} - -impl Drop for BlockAllocation { - fn drop(&mut self) { - if let Some(block_allocator) = self.block_allocator.as_mut() { - block_allocator.free(self.blocks.clone(), self.allocation_id) - } - } -} - -#[derive(Debug, Clone)] -pub struct BlockAllocator { - /// Channel to communicate with the background task - block_allocator: mpsc::UnboundedSender, -} - -impl BlockAllocator { - pub(crate) fn new( - max_batch_total_tokens: u32, - block_size: u32, - prefix_caching: bool, - window_size: Option, - ) -> Self { - // Create channel - let (sender, receiver) = mpsc::unbounded_channel(); - - // Launch background queue task - tokio::spawn(block_allocator_task( - max_batch_total_tokens / block_size, - block_size, - prefix_caching, - window_size, - receiver, - )); - - Self { - block_allocator: sender, - } - } - - pub(crate) async fn allocate( - &self, - tokens: u32, - prefill_tokens: Option>>, - ) -> Option { - let (response_sender, response_receiver) = oneshot::channel(); - self.block_allocator - .send(BlockAllocatorCommand::Allocate { - tokens, - prefill_tokens, - response_sender, - }) - .unwrap(); - - response_receiver.await.unwrap().map(|mut allocation| { - allocation.block_allocator = Some(self.clone()); - allocation - }) - } - - pub(crate) fn free(&self, blocks: Vec, allocation_id: u64) { - self.block_allocator - .send(BlockAllocatorCommand::Free { - allocation_id, - blocks, - }) - .unwrap(); - } -} - -async fn block_allocator_task( - blocks: u32, - block_size: u32, - prefix_caching: bool, - window_size: Option, - mut receiver: mpsc::UnboundedReceiver, -) { - let mut allocator: Box = if prefix_caching { - Box::new(RadixAllocator::new(block_size, blocks, window_size)) - } else { - Box::new(SimpleAllocator::new(blocks, block_size, window_size)) - }; - while let Some(cmd) = receiver.recv().await { - match cmd { - BlockAllocatorCommand::Free { - blocks, - allocation_id, - } => allocator.free(blocks, allocation_id), - BlockAllocatorCommand::Allocate { - tokens, - prefill_tokens, - response_sender, - } => { - response_sender - .send(allocator.allocate(tokens, prefill_tokens)) - .unwrap(); - } - } - } -} - -#[derive(Debug)] -enum BlockAllocatorCommand { - Free { - blocks: Vec, - allocation_id: u64, - }, - Allocate { - tokens: u32, - prefill_tokens: Option>>, - response_sender: oneshot::Sender>, - }, -} - -pub trait Allocator { - fn allocate( - &mut self, - tokens: u32, - prefill_tokens: Option>>, - ) -> Option; - - fn free(&mut self, blocks: Vec, allocation_id: u64); -} -pub struct SimpleAllocator { - free_blocks: Vec, - block_size: u32, - window_size: Option, -} - -impl SimpleAllocator { - fn new(blocks: u32, block_size: u32, window_size: Option) -> Self { - SimpleAllocator { - block_size, - // Block 0 is reserved for health checks - free_blocks: (1..blocks).collect(), - window_size, - } - } -} - -impl Allocator for SimpleAllocator { - fn allocate( - &mut self, - tokens: u32, - _prefill_tokens: Option>>, - ) -> Option { - // Apply window size - let (required_blocks, repeats) = { - let (tokens, repeats) = match self.window_size { - None => (tokens, 1), - Some(window_size) => { - let repeats = (tokens + window_size - 1) / window_size; - let tokens = core::cmp::min(tokens, window_size); - (tokens, repeats as usize) - } - }; - // Pad to a multiple of block size - let required_blocks = (tokens + self.block_size - 1) / self.block_size; - (required_blocks, repeats) - }; - - let tokens = tokens as usize; - if required_blocks > 
self.free_blocks.len() as u32 { - None - } else { - let blocks = self - .free_blocks - .split_off(self.free_blocks.len() - required_blocks as usize); - let mut slots = - Vec::with_capacity((required_blocks * self.block_size * repeats as u32) as usize); - - 'slots: for block_id in blocks.repeat(repeats).iter() { - for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) { - slots.push(s); - if slots.len() == tokens { - break 'slots; - } - } - } - Some(BlockAllocation { - allocation_id: 0, - blocks, - slots, - prefix_len: 0, - block_allocator: None, - }) - } - } - - fn free(&mut self, blocks: Vec, _allocation_id: u64) { - self.free_blocks.extend(blocks) - } -} diff --git a/backends/v2/src/client/grpc_client.rs b/backends/v2/src/client/grpc_client.rs index 648662db..b4943521 100644 --- a/backends/v2/src/client/grpc_client.rs +++ b/backends/v2/src/client/grpc_client.rs @@ -1,11 +1,9 @@ /// Single shard Client -use crate::client::{pb, Chunk}; +use crate::client::pb; use crate::client::{ClientError, Result, WARMUP_IMAGE_BASE64}; -use base64::engine::general_purpose::STANDARD; -use base64::Engine; use grpc_metadata::InjectTelemetryContext; -use pb::generate::v3::text_generation_service_client::TextGenerationServiceClient; -use pb::generate::v3::*; +use pb::generate::v2::text_generation_service_client::TextGenerationServiceClient; +use pb::generate::v2::*; use std::cmp::min; use std::time::Duration; use tonic::transport::{Channel, Uri}; @@ -47,7 +45,7 @@ impl Client { pub async fn service_discovery(&mut self) -> Result> { let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context(); let response = self.stub.service_discovery(request).await.map_err(|_| { - ClientError::Connection("Server does not support v3 interface".to_string()) + ClientError::Connection("Server does not support v2 interface".to_string()) })?; let urls = response .into_inner() @@ -119,23 +117,6 @@ impl Client { while n_tokens < max_prefill_tokens { let truncate = min(max_input_length, max_prefill_tokens - n_tokens); - let mut input_chunks = Vec::new(); - input_chunks - .push(Chunk::Text("_test ".to_string().repeat(max_input_length as usize)).into()); - if n_tokens == 0 { - input_chunks.push( - Chunk::Image(Image { - // Safe unwrap, because we control the data. - data: STANDARD.decode(WARMUP_IMAGE_BASE64).unwrap(), - mimetype: "image/jpeg;base64".to_string(), - }) - .into(), - ); - } - - // Send stringly-typed inputs for compatibility for backends that haven't - // been updated to support chunks. 
- let mut inputs = String::new(); inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize)); if n_tokens == 0 { @@ -149,16 +130,8 @@ impl Client { requests.push(Request { id: 0, inputs, - add_special_tokens: true, - input_chunks: Some(Input { - chunks: input_chunks, - }), // We truncate the input on the server side to be sure that it has the correct size truncate, - // Blocks and slots will be set on the server side if we use paged attention - blocks: vec![], - slots: vec![], - prefix_len: 0, // Set sampling parameters to also take these ops into account in the max memory parameters: Some(NextTokenChooserParameters { temperature: 0.9, @@ -180,7 +153,6 @@ impl Client { }), prefill_logprobs: true, top_n_tokens: 20, - adapter_id: None, }); n_tokens += max_input_length; @@ -194,8 +166,7 @@ impl Client { id: 0, size: requests.len() as u32, requests, - max_tokens: max_input_length, - max_blocks: 0, + max_tokens: 0, }; let request = tonic::Request::new(WarmupRequest { diff --git a/backends/v2/src/client/mod.rs b/backends/v2/src/client/mod.rs index 755431f4..fa9d4406 100644 --- a/backends/v2/src/client/mod.rs +++ b/backends/v2/src/client/mod.rs @@ -12,10 +12,9 @@ mod grpc_client; mod sharded_client; pub use grpc_client::Client; -pub use pb::generate::v3::{ - input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, - HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request, - StoppingCriteriaParameters, +pub use pb::generate::v2::{ + Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, HealthResponse, + InfoResponse, NextTokenChooserParameters, Request, StoppingCriteriaParameters, }; pub use sharded_client::ShardedClient; @@ -64,13 +63,6 @@ impl From for ClientError { } } -// Small convenience re-wrapping of `Chunk`. 
-impl From for InputChunk { - fn from(chunk: Chunk) -> Self { - InputChunk { chunk: Some(chunk) } - } -} - static WARMUP_IMAGE_BASE64 :&str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII="; pub type Result = std::result::Result; diff --git a/backends/v2/src/client/sharded_client.rs b/backends/v2/src/client/sharded_client.rs index ea77a696..eccf76d5 100644 --- a/backends/v2/src/client/sharded_client.rs +++ b/backends/v2/src/client/sharded_client.rs @@ -1,13 +1,13 @@ -use crate::client::{ClientError, Result}; /// Multi shard Client +use crate::client::{ClientError, Result}; use crate::client::{Health, ShardInfo}; use crate::client::grpc_client::{DecodeTimings, PrefillTimings}; +use crate::client::InfoResponse; use crate::client::{ Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse, NextTokenChooserParameters, Request, StoppingCriteriaParameters, }; -use crate::client::{Chunk, InfoResponse, Input}; use async_trait::async_trait; use futures::future::join_all; use tonic::transport::Uri; @@ -218,11 +218,7 @@ impl Health for ShardedClient { let liveness_request = Request { id: u64::MAX, inputs: "liveness".to_string(), - input_chunks: Some(Input { - chunks: vec![Chunk::Text("liveness".into()).into()], - }), truncate: 10, - add_special_tokens: true, prefill_logprobs: false, parameters: Some(NextTokenChooserParameters { temperature: 1.0, @@ -243,18 +239,12 @@ impl Health for ShardedClient { ignore_eos_token: false, }), top_n_tokens: 0, - // Block 0 is reserved for health checks - blocks: vec![0], - slots: (0..16).collect(), - prefix_len: 0, - adapter_id: None, }; let batch = Batch { id: u64::MAX, requests: vec![liveness_request], size: 1, max_tokens: 2, - max_blocks: 1, }; self.clone().prefill(batch).await?; Ok(()) diff --git a/backends/v2/src/lib.rs b/backends/v2/src/lib.rs index 77a9a11a..85c36931 100644 --- a/backends/v2/src/lib.rs +++ b/backends/v2/src/lib.rs @@ -1,11 +1,9 @@ mod backend; -pub mod block_allocator; mod client; mod queue; -pub mod radix; use crate::client::{ClientError, ShardedClient}; -pub(crate) use backend::BackendV3; +pub(crate) use backend::BackendV2; use serde::Serialize; use thiserror::Error; use utoipa::ToSchema; @@ -41,7 +39,7 @@ pub async fn connect_backend( max_batch_total_tokens: Option, max_waiting_tokens: usize, max_batch_size: Option, -) -> Result<(BackendV3, BackendInfo), V3Error> { +) -> Result<(BackendV2, BackendInfo), V2Error> { // Helper function let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option| { match max_supported_batch_total_tokens { @@ -65,7 +63,7 @@ pub async fn connect_backend( ); } if max_total_tokens as u32 > max_supported_batch_total_tokens { - return Err(V3Error::NotEnoughMemory(max_total_tokens)); + return Err(V2Error::NotEnoughMemory(max_total_tokens)); } Ok(max_supported_batch_total_tokens) @@ -75,16 +73,16 @@ pub async fn 
connect_backend( let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path) .await - .map_err(V3Error::Connection)?; + .map_err(V2Error::Connection)?; - // server is running on v3 + // server is running on v2 // Clear the cache; useful if the webserver rebooted sharded_client .clear_cache(None) .await - .map_err(V3Error::Cache)?; + .map_err(V2Error::Cache)?; // Get info from the shard - let shard_info = sharded_client.info().await.map_err(V3Error::Info)?; + let shard_info = sharded_client.info().await.map_err(V2Error::Info)?; // Warmup model tracing::info!("Warming up model"); @@ -97,7 +95,7 @@ pub async fn connect_backend( max_batch_size, ) .await - .map_err(V3Error::Warmup)?, + .map_err(V2Error::Warmup)?, )?; tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}"); @@ -111,7 +109,7 @@ pub async fn connect_backend( speculate: shard_info.speculate as usize, }; - let backend = BackendV3::new( + let backend = BackendV2::new( sharded_client, waiting_served_ratio, max_batch_prefill_tokens, @@ -129,7 +127,7 @@ pub async fn connect_backend( } #[derive(Debug, Error)] -pub enum V3Error { +pub enum V2Error { #[error("Unable to clear the Python model shards cache: {0}")] Cache(ClientError), #[error("Unable to connect to the Python model shards: {0}")] diff --git a/backends/v2/src/main.rs b/backends/v2/src/main.rs index 471ddb5a..f53d898e 100644 --- a/backends/v2/src/main.rs +++ b/backends/v2/src/main.rs @@ -1,6 +1,6 @@ use clap::{Parser, Subcommand}; use text_generation_router::{server, usage_stats}; -use text_generation_router_v3::{connect_backend, V3Error}; +use text_generation_router_v2::{connect_backend, V2Error}; use thiserror::Error; /// App Configuration @@ -204,7 +204,7 @@ enum RouterError { #[error("Argument validation error: {0}")] ArgumentValidation(String), #[error("Backend failed: {0}")] - Backend(#[from] V3Error), + Backend(#[from] V2Error), #[error("WebServer error: {0}")] WebServer(#[from] server::WebServerError), #[error("Tokio runtime failed to start: {0}")] diff --git a/backends/v2/src/queue.rs b/backends/v2/src/queue.rs index f8123b57..bf52900f 100644 --- a/backends/v2/src/queue.rs +++ b/backends/v2/src/queue.rs @@ -1,20 +1,17 @@ -use crate::block_allocator::{BlockAllocation, BlockAllocator}; -use crate::client; use crate::client::{ Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters, }; use nohash_hasher::{BuildNoHashHasher, IntMap}; -use std::cmp::{max, min}; +use std::cmp::min; use std::collections::VecDeque; use text_generation_router::infer::InferError; use text_generation_router::infer::InferStreamResponse; use text_generation_router::validation::{ - Chunk, ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters, - ValidStoppingParameters, + ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters, }; use tokio::sync::{mpsc, oneshot}; use tokio::time::Instant; -use tracing::{info_span, instrument, Instrument, Span}; +use tracing::{info_span, instrument, Span}; /// Queue entry #[derive(Debug)] @@ -31,8 +28,6 @@ pub(crate) struct Entry { pub queue_time: Instant, /// Instant when this entry was added to a batch pub batch_time: Option, - /// Block Allocation - pub block_allocation: Option, } /// Request Queue @@ -46,10 +41,8 @@ impl Queue { pub(crate) fn new( requires_padding: bool, block_size: u32, - prefix_caching: bool, window_size: Option, speculate: u32, - max_batch_total_tokens: u32, ) -> Self { // Create channel let (queue_sender, queue_receiver) = 
mpsc::unbounded_channel(); @@ -58,17 +51,14 @@ impl Queue { tokio::spawn(queue_task( requires_padding, block_size, - prefix_caching, window_size, speculate, - max_batch_total_tokens, queue_receiver, )); Self { queue_sender } } - /// Append an entry to the queue #[instrument(skip_all)] pub(crate) fn append(&self, entry: Entry) { // Send append command to the background task managing the state @@ -111,20 +101,11 @@ impl Queue { async fn queue_task( requires_padding: bool, block_size: u32, - prefix_caching: bool, window_size: Option, speculate: u32, - max_batch_total_tokens: u32, mut receiver: mpsc::UnboundedReceiver, ) { - let mut state = State::new( - requires_padding, - block_size, - prefix_caching, - window_size, - speculate, - max_batch_total_tokens, - ); + let mut state = State::new(requires_padding, block_size, window_size, speculate); while let Some(cmd) = receiver.recv().await { match cmd { @@ -139,14 +120,12 @@ async fn queue_task( token_budget, response_sender, span, - } => { - let next_batch = state - .next_batch(min_size, max_size, prefill_token_budget, token_budget) - .instrument(span) - .await; + } => span.in_scope(|| { + let next_batch = + state.next_batch(min_size, max_size, prefill_token_budget, token_budget); response_sender.send(next_batch).unwrap(); metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64); - } + }), } } } @@ -163,6 +142,9 @@ struct State { /// Id of the next batch next_batch_id: u64, + /// Whether the model is using padding + requires_padding: bool, + /// Paged Attention block size block_size: u32, @@ -171,37 +153,23 @@ struct State { /// Speculation amount speculate: u32, - - /// Paged Attention Block Allocation - block_allocator: Option, } impl State { fn new( requires_padding: bool, block_size: u32, - prefix_caching: bool, window_size: Option, speculate: u32, - max_batch_total_tokens: u32, ) -> Self { - let block_allocator = (!requires_padding).then(|| { - BlockAllocator::new( - max_batch_total_tokens, - block_size, - prefix_caching, - window_size, - ) - }); - Self { entries: VecDeque::with_capacity(128), next_id: 0, next_batch_id: 0, + requires_padding, block_size, window_size, speculate, - block_allocator, } } @@ -217,7 +185,7 @@ impl State { } // Get the next batch - async fn next_batch( + fn next_batch( &mut self, min_size: Option, max_size: Option, @@ -252,14 +220,16 @@ impl State { let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty); next_batch_span.follows_from(Span::current()); - let mut batch = Vec::with_capacity(self.entries.len()); + let mut batch_requests = Vec::with_capacity(self.entries.len()); + let mut batch_entries = + IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default()); + let mut max_input_length = 0; let mut prefill_tokens: u32 = 0; let mut decode_tokens: u32 = 0; - let mut max_blocks = 0; // Pop entries starting from the front of the queue - 'entry_loop: while let Some((id, entry)) = self.entries.pop_front() { + while let Some((id, mut entry)) = self.entries.pop_front() { // Filter entries where the response receiver was dropped (== entries where the request // was dropped by the client) if entry.response_tx.is_closed() { @@ -268,113 +238,44 @@ impl State { continue; } - let block_allocation = match &self.block_allocator { - None => { - // We pad to max input length in the Python shards - // We need to take these padding tokens into the equation - max_input_length = max_input_length.max(entry.request.input_length); - prefill_tokens = (batch.len() + 1) as 
u32 * max_input_length; + if self.requires_padding { + // We pad to max input length in the Python shards + // We need to take these padding tokens into the equation + max_input_length = max_input_length.max(entry.request.input_length); + prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length + } else { + // pad to block size + prefill_tokens += ((entry.request.input_length + self.block_size - 1) + / self.block_size) + * self.block_size; + } - decode_tokens += entry.request.stopping_parameters.max_new_tokens; - let total_tokens = prefill_tokens + decode_tokens + self.speculate; + if self.requires_padding { + decode_tokens += entry.request.stopping_parameters.max_new_tokens; + } else { + let max_new_tokens = match self.window_size { + None => entry.request.stopping_parameters.max_new_tokens, + Some(window_size) => min( + window_size.saturating_sub(entry.request.input_length), + entry.request.stopping_parameters.max_new_tokens, + ), + }; - if prefill_tokens > prefill_token_budget || total_tokens > token_budget { - // Entry is over budget - // Add it back to the front - tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate); - self.entries.push_front((id, entry)); - break 'entry_loop; - } - None - } - Some(_block_allocator) => { - prefill_tokens += entry.request.input_length; - let max_new_tokens = match self.window_size { - None => entry.request.stopping_parameters.max_new_tokens, - Some(window_size) => min( - window_size.saturating_sub(entry.request.input_length), - entry.request.stopping_parameters.max_new_tokens, - ), - }; - decode_tokens += max_new_tokens; + // pad to block size + decode_tokens += + ((max_new_tokens + self.block_size - 1) / self.block_size) * self.block_size; + } - if prefill_tokens > prefill_token_budget - || (prefill_tokens + decode_tokens + self.speculate) > token_budget - { - // Entry is over budget - // Add it back to the front - tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate); - self.entries.push_front((id, entry)); - break; - } - - let tokens = entry.request.input_length - + entry.request.stopping_parameters.max_new_tokens - + self.speculate - - 1; - - // If users wants the prefill logprobs, we cannot reuse the cache. - // So no input_ids for the radix tree. - let input_ids = if entry.request.decoder_input_details { - None - } else { - entry.request.input_ids.clone() - }; - - Some((tokens, input_ids)) - } - }; - batch.push((id, entry, block_allocation)); - if Some(batch.len()) == max_size { + if prefill_tokens > prefill_token_budget + || (prefill_tokens + decode_tokens + self.speculate) > token_budget + { + // Entry is over budget + // Add it back to the front + tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate); + self.entries.push_front((id, entry)); break; } - } - // Empty batch - if batch.is_empty() { - tracing::debug!("Filterered out all entries"); - return None; - } - - // XXX We haven't allocated yet, so we're allowed to ditch the results. 
- // Check if our batch is big enough - if let Some(min_size) = min_size { - // Batch is too small - if batch.len() < min_size { - // Add back entries to the queue in the correct order - for (id, entry, _) in batch.into_iter().rev() { - self.entries.push_front((id, entry)); - } - return None; - } - } - - let mut batch_requests = Vec::with_capacity(self.entries.len()); - let mut batch_entries = - IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default()); - - for (id, mut entry, block_allocation) in batch { - let block_allocation = if let (Some((tokens, input_ids)), Some(block_allocator)) = - (block_allocation, &self.block_allocator) - { - tracing::debug!("Allocating {tokens} with {input_ids:?}"); - match block_allocator.allocate(tokens, input_ids).await { - None => { - // Entry is over budget - // Add it back to the front - tracing::debug!("Over budget: not enough free blocks"); - self.entries.push_front((id, entry)); - continue; - } - Some(block_allocation) => { - tracing::debug!("Allocation: {block_allocation:?}"); - max_blocks = max(max_blocks, block_allocation.blocks.len() as u32); - Some(block_allocation) - } - } - } else { - None - }; tracing::debug!("Accepting entry"); // Create a new span to link the batch back to this entry let entry_batch_span = info_span!(parent: &entry.span, "infer"); @@ -384,40 +285,11 @@ impl State { // Update entry entry.temp_span = Some(entry_batch_span); - let (blocks, slots, prefix_len) = match &block_allocation { - None => (Vec::new(), Vec::new(), 0), - Some(block_allocation) => ( - block_allocation.blocks.clone(), - block_allocation.slots.clone(), - block_allocation.prefix_len, - ), - }; - - entry.block_allocation = block_allocation; - batch_requests.push(Request { id, prefill_logprobs: entry.request.decoder_input_details, - input_chunks: Some(client::Input { - chunks: entry - .request - .inputs - .clone() - .into_iter() - .map(|c| client::InputChunk { - chunk: Some(match c { - Chunk::Text(text) => client::Chunk::Text(text), - Chunk::Image(image) => client::Chunk::Image(client::Image { - data: image.data, - mimetype: image.mimetype, - }), - }), - }) - .collect(), - }), inputs: entry.request.inputs.chunks_to_string(), truncate: entry.request.truncate, - add_special_tokens: entry.request.add_special_tokens, parameters: Some(NextTokenChooserParameters::from( entry.request.parameters.clone(), )), @@ -425,23 +297,39 @@ impl State { entry.request.stopping_parameters.clone(), )), top_n_tokens: entry.request.top_n_tokens, - blocks, - slots, - prefix_len, - adapter_id: entry.request.adapter_id.clone(), }); // Set batch_time entry.batch_time = Some(Instant::now()); // Insert in batch_entries IntMap batch_entries.insert(id, entry); + + // Check if max_size + if Some(batch_requests.len()) == max_size { + break; + } } // Empty batch if batch_requests.is_empty() { - tracing::debug!("Filterered out all entries"); + tracing::debug!("Filtered out all entries"); return None; } + // Check if our batch is big enough + if let Some(min_size) = min_size { + // Batch is too small + if batch_requests.len() < min_size { + // Add back entries to the queue in the correct order + for r in batch_requests.into_iter().rev() { + let id = r.id; + let entry = batch_entries.remove(&id).unwrap(); + self.entries.push_front((id, entry)); + } + + return None; + } + } + // Final batch size let size = batch_requests.len() as u32; next_batch_span.record("batch_size", size); @@ -451,7 +339,6 @@ impl State { requests: batch_requests, size, max_tokens: (prefill_tokens + 
decode_tokens), - max_blocks, }; // Increment batch id self.next_batch_id += 1; @@ -516,9 +403,8 @@ impl From for StoppingCriteriaParameters { #[cfg(test)] mod tests { - use std::sync::Arc; - use super::*; + use std::sync::Arc; use tracing::info_span; fn default_entry() -> ( @@ -560,14 +446,13 @@ mod tests { temp_span: None, queue_time: Instant::now(), batch_time: None, - block_allocation: None, }; (entry, receiver_tx) } - #[tokio::test] - async fn test_append() { - let mut state = State::new(false, 1, false, None, 0, 16); + #[test] + fn test_append() { + let mut state = State::new(false, 1, None, 0); let (entry, _guard) = default_entry(); assert_eq!(state.next_id, 0); @@ -581,23 +466,23 @@ mod tests { assert_eq!(id, 0); } - #[tokio::test] - async fn test_next_batch_empty() { - let mut state = State::new(false, 1, false, None, 0, 16); + #[test] + fn test_next_batch_empty() { + let mut state = State::new(false, 1, None, 0); - assert!(state.next_batch(None, None, 1, 1).await.is_none()); - assert!(state.next_batch(Some(1), None, 1, 1).await.is_none()); + assert!(state.next_batch(None, None, 1, 1).is_none()); + assert!(state.next_batch(Some(1), None, 1, 1).is_none()); } - #[tokio::test] - async fn test_next_batch_min_size() { - let mut state = State::new(false, 1, false, None, 0, 16); + #[test] + fn test_next_batch_min_size() { + let mut state = State::new(false, 1, None, 0); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); state.append(entry2); - let (entries, batch, _) = state.next_batch(None, None, 2, 2).await.unwrap(); + let (entries, batch, _) = state.next_batch(None, None, 2, 2).unwrap(); assert_eq!(entries.len(), 2); assert!(entries.contains_key(&0)); assert!(entries.contains_key(&1)); @@ -613,7 +498,7 @@ mod tests { let (entry3, _guard3) = default_entry(); state.append(entry3); - assert!(state.next_batch(Some(2), None, 2, 2).await.is_none()); + assert!(state.next_batch(Some(2), None, 2, 2).is_none()); assert_eq!(state.next_id, 3); assert_eq!(state.entries.len(), 1); @@ -621,15 +506,15 @@ mod tests { assert_eq!(id, 2); } - #[tokio::test] - async fn test_next_batch_max_size() { - let mut state = State::new(false, 1, false, None, 0, 16); + #[test] + fn test_next_batch_max_size() { + let mut state = State::new(false, 1, None, 0); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); state.append(entry2); - let (entries, batch, _) = state.next_batch(None, Some(1), 2, 2).await.unwrap(); + let (entries, batch, _) = state.next_batch(None, Some(1), 2, 2).unwrap(); assert_eq!(entries.len(), 1); assert!(entries.contains_key(&0)); assert!(entries.get(&0).unwrap().batch_time.is_some()); @@ -641,15 +526,15 @@ mod tests { assert_eq!(state.next_batch_id, 1); } - #[tokio::test] - async fn test_next_batch_token_budget() { - let mut state = State::new(false, 1, false, None, 0, 2); + #[test] + fn test_next_batch_token_budget() { + let mut state = State::new(false, 1, None, 0); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); state.append(entry2); - let (entries, batch, _) = state.next_batch(None, None, 1, 1).await.unwrap(); + let (entries, batch, _) = state.next_batch(None, None, 1, 1).unwrap(); assert_eq!(entries.len(), 1); assert!(entries.contains_key(&0)); assert_eq!(batch.id, 0); @@ -662,7 +547,7 @@ mod tests { let (entry3, _guard3) = default_entry(); state.append(entry3); - let (entries, batch, _) = state.next_batch(None, None, 3, 
3).await.unwrap(); + let (entries, batch, _) = state.next_batch(None, None, 3, 3).unwrap(); assert_eq!(entries.len(), 2); assert!(entries.contains_key(&1)); assert!(entries.contains_key(&2)); @@ -676,14 +561,14 @@ mod tests { #[tokio::test] async fn test_queue_append() { - let queue = Queue::new(false, 1, false, None, 0, 16); + let queue = Queue::new(false, 1, None, 0); let (entry, _guard) = default_entry(); queue.append(entry); } #[tokio::test] async fn test_queue_next_batch_empty() { - let queue = Queue::new(false, 1, false, None, 0, 16); + let queue = Queue::new(false, 1, None, 0); assert!(queue.next_batch(None, None, 1, 1).await.is_none()); assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none()); @@ -691,7 +576,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_min_size() { - let queue = Queue::new(false, 1, false, None, 0, 16); + let queue = Queue::new(false, 1, None, 0); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -724,7 +609,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_max_size() { - let queue = Queue::new(false, 1, false, None, 0, 16); + let queue = Queue::new(false, 1, None, 0); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -740,7 +625,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_token_budget() { - let queue = Queue::new(false, 1, false, None, 0, 16); + let queue = Queue::new(false, 1, None, 0); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -765,7 +650,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_token_speculate() { - let queue = Queue::new(false, 1, false, None, 2, 16); + let queue = Queue::new(false, 1, None, 2); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -784,7 +669,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_dropped_receiver() { - let queue = Queue::new(false, 1, false, None, 0, 16); + let queue = Queue::new(false, 1, None, 0); let (entry, _) = default_entry(); queue.append(entry); diff --git a/backends/v2/src/radix.rs b/backends/v2/src/radix.rs deleted file mode 100644 index 8a544891..00000000 --- a/backends/v2/src/radix.rs +++ /dev/null @@ -1,876 +0,0 @@ -use crate::block_allocator::{Allocator, BlockAllocation}; -use slotmap::{DefaultKey, SlotMap}; -use std::hash::{Hash, Hasher}; -use std::{ - collections::{BTreeSet, HashMap}, - sync::Arc, -}; - -fn hash(slice: &[u32]) -> u64 { - assert!(!slice.is_empty()); - if slice.len() == 1 { - slice[0] as u64 - } else { - let mut s = std::hash::DefaultHasher::new(); - slice.hash(&mut s); - s.finish() - } -} - -pub struct RadixAllocator { - allocation_id: u64, - - allocations: HashMap, - - cache_blocks: RadixTrie, - - /// Blocks that are immediately available for allocation. - free_blocks: Vec, - - #[allow(dead_code)] - // This isn't used because the prefix need to match without the windowing - // mecanism. This at worst is overallocating, not necessarily being wrong. - window_size: Option, - - block_size: u32, -} - -impl RadixAllocator { - pub fn new(block_size: u32, n_blocks: u32, window_size: Option) -> Self { - RadixAllocator { - allocation_id: 0, - allocations: HashMap::new(), - cache_blocks: RadixTrie::new(block_size as usize), - - // Block 0 is reserved for health checks. 
- free_blocks: (1..n_blocks).collect(), - window_size, - block_size, - } - } - - fn alloc_or_reclaim(&mut self, n_blocks_needed: usize) -> Option> { - if self.free_blocks.len() < n_blocks_needed { - // This is a bit annoying, we first extend the free list and then - // split it off again below. This is because we need to put it on - // the free list if we cannot allocate enough blocks. This is only - // temporary, the trie needs to be able to report whether it can - // allocate the requested amount. Just not implemented yet. - tracing::debug!( - "Free blocks {} need {n_blocks_needed}", - self.free_blocks.len() - ); - self.free_blocks.extend( - self.cache_blocks - .evict(n_blocks_needed - self.free_blocks.len()), - ); - } - - if self.free_blocks.len() >= n_blocks_needed { - Some( - self.free_blocks - .split_off(self.free_blocks.len() - n_blocks_needed), - ) - } else { - None - } - } -} - -// Allocator trait -impl Allocator for RadixAllocator { - fn allocate( - &mut self, - tokens: u32, - prefill_tokens: Option>>, - ) -> Option { - let mut blocks = vec![]; - let prefix_node = if let Some(prefill_tokens) = prefill_tokens.as_ref() { - let node_id = self - .cache_blocks - .find(prefill_tokens.as_slice(), &mut blocks); - node_id - } else { - self.cache_blocks.root_id() - }; - - // Even if this allocation fails below, we need to increase he - // refcount to ensure that the prefix that was found is not evicted. - self.cache_blocks - .incref(prefix_node) - .expect("Failed to increment refcount"); - - let prefix_len = blocks.len() * self.block_size as usize; - let suffix_len = tokens - prefix_len as u32; - - let suffix_blocks = (suffix_len + self.block_size - 1) / self.block_size; - - tracing::info!("Prefix {prefix_len} - Suffix {suffix_len}"); - - match self.alloc_or_reclaim(suffix_blocks as usize) { - Some(suffix_blocks) => blocks.extend(suffix_blocks), - None => { - tracing::debug!("Cannot allocate {:?}", self.cache_blocks); - tracing::debug!("Found {prefix_len} prefix tokens need {suffix_blocks} suffix blocks for {tokens} tokens"); - tracing::debug!("Block size {}", self.block_size); - self.cache_blocks - .decref(prefix_node) - .expect("Failed to decrement refcount"); - return None; - } - } - - // 1:1 mapping of blocks and slots. - let slots = if self.block_size == 1 { - blocks.clone() - } else { - let mut slots = Vec::with_capacity(blocks.len() * self.block_size as usize); - 'slots: for block_id in &blocks { - for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) { - slots.push(s); - if slots.len() as u32 == tokens { - break 'slots; - } - } - } - slots - }; - - let allocation = RadixAllocation { - prefix_node, - cached_prefix_len: prefix_len, - prefill_tokens: prefill_tokens.clone(), - }; - - self.allocation_id += 1; - self.allocations.insert(self.allocation_id, allocation); - - Some(BlockAllocation { - allocation_id: self.allocation_id, - block_allocator: None, - blocks, - slots, - prefix_len: prefix_len as u32, - }) - } - - fn free(&mut self, blocks: Vec, allocation_id: u64) { - let allocation = match self.allocations.remove(&allocation_id) { - Some(allocation) => allocation, - None => unreachable!("Tried to free an unknown allocation."), - }; - - self.cache_blocks - .decref(allocation.prefix_node) - .expect("Failed to decrement refcount"); - - if let Some(prefill_tokens) = allocation.prefill_tokens { - let prefill_tokens = prefill_tokens.as_slice(); - - // If there are prefill tokens that did not come from the cache, - // add them to the cache. 
- if prefill_tokens.len() > allocation.cached_prefix_len { - let aligned = - (prefill_tokens.len() / self.block_size as usize) * self.block_size as usize; - if aligned > 0 { - let prefix_len = self - .cache_blocks - .insert( - &prefill_tokens[..aligned], - &blocks[..aligned / self.block_size as usize], - ) - // Unwrap, failing is a programming error. - .expect("Failed to store prefill tokens"); - // We can have a prefill with the following structure: - // - // |---| From the prefix cache. - // A B C D E F G - //|--------| Found in the trie during insertion. - // - // This means that while processing this request there was a - // partially overlapping request that had A..=E in its - // prefill. In this case we need to free the blocks D E. - if prefix_len > allocation.cached_prefix_len { - self.free_blocks.extend( - &blocks[allocation.cached_prefix_len / self.block_size as usize - ..prefix_len / self.block_size as usize], - ); - } - } - } - - // Free non-prefill blocks. - self.free_blocks - .extend(&blocks[prefill_tokens.len() / self.block_size as usize..]); - } else { - self.free_blocks.extend(blocks); - } - } -} - -struct RadixAllocation { - prefix_node: NodeId, - cached_prefix_len: usize, - prefill_tokens: Option>>, -} - -// Radix trie that is heavily inspired by radix attention from sglang. -// -// The trie is optimized for prefix caching: -// -// - A normal radix trie stores discrete values. In this radix trie, -// inserting *abc* with value *xyz* will also enable lookup for -// *a* (*x*) and *ab* (*xy*). -// - As a result, every value is required to have the same length as -// the key. -// - We store additional information in each node, such as last access -// time and a reference count. - -#[derive(Debug)] -pub enum TrieError { - InvalidNodeId, - RefCountUnderflow, -} - -pub type NodeId = DefaultKey; - -#[derive(Debug)] -pub struct RadixTrie { - /// Identifier of the root nod. - root: DefaultKey, - - /// Leave node identifiers ordered by increasing recency. - leaves: BTreeSet<(u64, NodeId)>, - - /// All trie nodes. - nodes: SlotMap, - - /// Time as a monotonically increating counter to avoid the system - /// call that a real time lookup would require. - time: u64, - - /// All blocks need to be aligned with this - block_size: usize, -} - -impl RadixTrie { - /// Construct a new radix trie. - pub fn new(block_size: usize) -> Self { - let root = TrieNode::new(vec![], vec![], 0, None); - let mut nodes = SlotMap::new(); - let root = nodes.insert(root); - RadixTrie { - leaves: BTreeSet::new(), - nodes, - root, - time: 0, - block_size, - } - } - - /// Find the prefix of the given tokens. - /// - /// The blocks corresponding to the part of the prefix that could be found - /// are written to `blocks`. The number of blocks is in `0..=tokens.len()`. - /// Returns the identifier of the trie node that contains the longest - /// prefix. The node identifier can be used by callers to e.g. increase its - /// reference count. - /// - /// Using this method will update the access time of the traversed nodes. - pub fn find(&mut self, key: &[u32], blocks: &mut Vec) -> NodeId { - self.time += 1; - self.find_(self.root, key, blocks) - } - - /// Find worker. 
- fn find_(&mut self, mut node_id: NodeId, key: &[u32], blocks: &mut Vec) -> NodeId { - let node = &self.nodes[node_id]; - - if key.len() >= self.block_size { - let node_key = hash(&key[..self.block_size]); - if let Some(&child_id) = node.children.get(&node_key) { - self.update_access_time(child_id); - let child = self.nodes.get(child_id).expect("Invalid child identifier"); - let shared_prefix_len = shared_prefix(&child.key, key, self.block_size); - assert_eq!(shared_prefix_len % self.block_size, 0); - blocks.extend(&child.blocks[..shared_prefix_len / self.block_size]); - - let key = &key[shared_prefix_len..]; - if !key.is_empty() { - node_id = self.find_(child_id, key, blocks); - } - } - } - - node_id - } - - /// Decrease the reference count of a node. - pub fn decref(&mut self, node_id: NodeId) -> Result<(), TrieError> { - // We don't care about refcounting for root, since it will never - // be evicted. - if node_id == self.root { - return Ok(()); - } - - let node = self - .nodes - .get_mut(node_id) - .ok_or(TrieError::InvalidNodeId)?; - if node.ref_count == 0 { - return Err(TrieError::RefCountUnderflow); - } - - node.ref_count -= 1; - if node.ref_count == 0 { - assert!( - node.children.is_empty(), - "Nodes with children must have refcount > 0" - ); - - self.leaves.insert((node.last_accessed, node_id)); - } - - Ok(()) - } - - /// Increase the reference count of a node. - pub fn incref(&mut self, node_id: NodeId) -> Result<(), TrieError> { - if node_id == self.root { - return Ok(()); - } - - let node = self - .nodes - .get_mut(node_id) - .ok_or(TrieError::InvalidNodeId)?; - if node.ref_count == 0 { - self.leaves.remove(&(node.last_accessed, node_id)); - } - node.ref_count += 1; - - Ok(()) - } - - /// Evict `n_blocks` from the trie. - /// - /// Returns the evicted blocks. When the length is less than `n_blocks`, - /// not enough blocks could be evicted. - pub fn evict(&mut self, n_blocks: usize) -> Vec { - // NOTE: we don't return Result here. If any of the unwrapping fails, - // it's a programming error in the trie implementation, not a user - // error caused by e.g. an invalid argument. - - // TODO: add some bookkeeping in the future to check whether we can - // evict n_blocks and return `None` if we can't. We are now needlessly - // evicting prefixes from the cache in such a case. - let mut evicted = Vec::new(); - tracing::debug!("Evicting in search of {n_blocks}"); - - while let Some((last_access, node_id)) = self.leaves.pop_first() { - let blocks_needed = n_blocks.saturating_sub(evicted.len()); - tracing::debug!("Evicting node {node_id:?} "); - - let node = self.nodes.get(node_id).expect("Leave does not exist"); - assert_eq!( - node.ref_count, 0, - "Leaf must have refcount of 0, got {}", - node.ref_count - ); - - if blocks_needed >= node.blocks.len() { - // We need to evict the whole node if we need more blocks than it has. - let node = self.remove_node(node_id); - evicted.extend(node.blocks); - - if evicted.len() >= n_blocks { - break; - } - } else { - // The node has more blocks than needed, so we'll just remove - // the required number of blocks and leave the remaining blocks - // untouched. 
- let node = self.nodes.get_mut(node_id).expect("Leave does not exist"); - - let truncate_blocks = node.blocks.len() - blocks_needed; - let truncate_tokens = truncate_blocks * self.block_size; - node.key.truncate(truncate_tokens); - evicted.extend(node.blocks.split_off(truncate_blocks)); - self.leaves.insert((last_access, node_id)); - break; - } - } - - evicted - } - - /// Insert a prefill along with its blocks. - /// - /// This method returns the length of the prefix that was already - /// in the trie. E.g. if the length is 10, this means that for - /// the first 10 elements of the tree **the blocks are not updated**. - pub fn insert(&mut self, tokens: &[u32], blocks: &[u32]) -> Result { - self.time += 1; - let common = self.insert_(self.root, tokens, blocks)?; - Ok(common) - } - - /// Insertion worker. - fn insert_( - &mut self, - node_id: NodeId, - tokens: &[u32], - blocks: &[u32], - ) -> Result { - // TODO: in the future we may want to check that the blocks match for - // the part of the prefix that is already in the trie to detect - // mismatches. - - assert_eq!(tokens.len(), blocks.len() * self.block_size); - - let node_key = hash(&tokens[..self.block_size]); - if let Some(&child_id) = self.nodes[node_id].children.get(&node_key) { - self.update_access_time(child_id); - let child = self - .nodes - .get_mut(child_id) - // Unwrap here, since failure is a bug. - .expect("Child node does not exist"); - let shared_prefix_len = shared_prefix(&child.key, tokens, self.block_size); - - // We are done, the prefix is already in the trie. - if shared_prefix_len == tokens.len() || shared_prefix_len == 0 { - return Ok(shared_prefix_len); - } - - // The node's prefix is a prefix of the insertion prefix. - if shared_prefix_len == child.key.len() { - return Ok(shared_prefix_len - + self.insert_( - child_id, - &tokens[shared_prefix_len..], - &blocks[shared_prefix_len / self.block_size..], - )?); - } - - // The node's prefix and the insertion prefix only match partially, - // split the node to just contain the matching part. Then insert the - // remainder of the prefix into the node again - let child_id = self.split_node(child_id, shared_prefix_len); - let key = &tokens[shared_prefix_len..]; - let blocks = &blocks[shared_prefix_len / self.block_size..]; - Ok(shared_prefix_len + self.insert_(child_id, key, blocks)?) - } else { - self.add_node(node_id, tokens, blocks); - Ok(0) - } - } - - fn split_node(&mut self, node_id: NodeId, prefix_len: usize) -> NodeId { - // We have to make the current node a child to ensure that its - // properties and node id stay the same. - - // This funcion unwraps, an invalid node_id is a programming error. - - let node = self - .nodes - .get_mut(node_id) - .expect("Node to-be split does not exist"); - let mut parent_key = node.key.split_off(prefix_len); - let prefix_blocks = prefix_len / self.block_size; - let mut parent_blocks = node.blocks.split_off(prefix_blocks); - - // Move first part of the prefix to the parent. We swap to avoid - // an allocation + copy for both splits of the key/blocks. - std::mem::swap(&mut node.key, &mut parent_key); - std::mem::swap(&mut node.blocks, &mut parent_blocks); - - let node_key = hash(&node.key[..self.block_size]); - - let grandparent_id = node.parent.expect("Node does not have a parent"); - let parent_id = self.add_node(grandparent_id, parent_key, parent_blocks); - self.add_node_to_parent(parent_id, node_key, node_id); - - // Reborrow to make the borrow checker happy. 
- let node = self - .nodes - .get_mut(node_id) - .expect("Node to-be split does not exist"); - node.parent = Some(parent_id); - - parent_id - } - - /// Create a node and add it to the parent. - fn add_node( - &mut self, - parent_id: NodeId, - key: impl Into>, - blocks: impl Into>, - ) -> NodeId { - let key = key.into(); - let blocks = blocks.into(); - let first = hash(&key[..self.block_size]); - - let child = TrieNode::new(key, blocks, self.time, Some(parent_id)); - let child_id = self.nodes.insert(child); - - self.add_node_to_parent(parent_id, first, child_id); - self.leaves.insert((self.time, child_id)); - - child_id - } - - /// Add a node to the parent. - fn add_node_to_parent(&mut self, parent_id: NodeId, hash: u64, child_id: NodeId) { - // Unwrap here, passing in an unknown id is a programming error. - let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node"); - if parent.children.insert(hash, child_id).is_none() { - // Only increase reference count if child does not replace another child. - self.incref(parent_id) - .expect("Failed to increase parent refcount"); - } - } - - /// Remove a node from the trie. - fn remove_node(&mut self, node_id: NodeId) -> TrieNode { - // Unwrap here, passing in an unknown id is a programming error. - let node = self.nodes.remove(node_id).expect("Unknown node"); - assert!( - node.children.is_empty(), - "Tried to remove a node with {} children", - node.children.len() - ); - let parent_id = node.parent.expect("Attempted to remove root node"); - let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node"); - - let node_key = hash(&node.key[..self.block_size]); - parent.children.remove(&node_key); - self.decref(parent_id) - .expect("Failed to decrease parent refcount"); - node - } - - fn update_access_time(&mut self, node_id: NodeId) { - // Unwrap here, passing in an unknown id is a programming error. - let node = self.nodes.get_mut(node_id).expect("Unknown node"); - - // Update the ordered leaves set if the node is a leave. - if self.leaves.remove(&(node.last_accessed, node_id)) { - self.leaves.insert((self.time, node_id)); - } - - node.last_accessed = self.time; - } - - #[allow(dead_code)] - #[doc(hidden)] - /// Print debugging output for the trie. - /// - /// In contrast to `Debug` nicely formatted. - pub fn print_debug(&self) { - self.print_debug_(self.root, 0); - } - - fn print_debug_(&self, node_id: NodeId, indent: usize) { - let node = &self.nodes[node_id]; - eprintln!( - "{}{:?}, key: {:?}, blocks: {:?}, ref_count: {}, last_accessed: {}, parent: {:?}, children: {:?}", - " ".repeat(indent), - node_id, - node.key, - node.blocks, - node.ref_count, - node.last_accessed, - node.parent, - node.children - ); - for child_id in self.nodes[node_id].children.values() { - self.print_debug_(*child_id, indent + 2); - } - } - - pub(crate) fn root_id(&self) -> DefaultKey { - self.root - } -} - -/// Trie node. -#[derive(Debug)] -struct TrieNode { - blocks: Vec, - children: HashMap, - key: Vec, - last_accessed: u64, - parent: Option, - ref_count: usize, -} - -impl TrieNode { - fn new(key: Vec, blocks: Vec, last_accessed: u64, parent: Option) -> Self { - TrieNode { - children: HashMap::new(), - key, - blocks, - last_accessed, - parent, - ref_count: 0, - } - } -} - -fn shared_prefix(left: &[u32], right: &[u32], block_size: usize) -> usize { - let full = left.iter().zip(right).take_while(|(a, b)| a == b).count(); - // NOTE: this is the case because the child node was chosen based on - // matching the first character of the key/prefix. 
-    assert!(full > 0, "Prefixes must at least share 1 token");
-    (full / block_size) * block_size
-}
-
-#[cfg(test)]
-mod tests {
-    use std::sync::Arc;
-
-    use super::*;
-
-    #[test]
-    fn allocator_block_size() {
-        let mut cache = RadixAllocator::new(2, 12, None);
-        let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
-        assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
-        assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22, 23]);
-        assert_eq!(allocation.prefix_len, 0);
-        cache.free(allocation.blocks.clone(), allocation.allocation_id);
-
-        let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
-        assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
-        assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22, 23]);
-        assert_eq!(allocation.prefix_len, 4);
-    }
-
-    #[test]
-    fn allocator_block_size_non_aligned() {
-        let mut cache = RadixAllocator::new(2, 12, None);
-        let allocation = cache.allocate(7, Some(Arc::new(vec![0, 1, 2]))).unwrap();
-        assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
-        assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22]);
-        assert_eq!(allocation.prefix_len, 0);
-        cache.free(allocation.blocks.clone(), allocation.allocation_id);
-
-        let allocation = cache.allocate(7, Some(Arc::new(vec![0, 1, 2]))).unwrap();
-        assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
-        assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22]);
-        assert_eq!(allocation.prefix_len, 2);
-    }
-
-    #[test]
-    fn allocator_reuses_prefixes() {
-        let mut cache = RadixAllocator::new(1, 12, None);
-        let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
-        assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]);
-        assert_eq!(allocation.blocks, allocation.slots);
-        assert_eq!(allocation.prefix_len, 0);
-        cache.free(allocation.blocks.clone(), allocation.allocation_id);
-
-        let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
-        assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]);
-        assert_eq!(allocation.prefix_len, 4);
-    }
-
-    #[test]
-    fn allocator_collects_older_prefixes_first() {
-        let mut cache = RadixAllocator::new(1, 7, None);
-        let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
-        assert_eq!(allocation1.blocks, vec![3, 4, 5, 6]);
-        assert_eq!(allocation1.prefix_len, 0);
-
-        let allocation2 = cache.allocate(2, Some(Arc::new(vec![4, 5]))).unwrap();
-        assert_eq!(allocation2.blocks, vec![1, 2]);
-        assert_eq!(allocation2.prefix_len, 0);
-
-        cache.free(allocation1.blocks.clone(), allocation1.allocation_id);
-        cache.free(allocation2.blocks.clone(), allocation2.allocation_id);
-
-        // We should get the blocks of the first allocation, since they are more recent.
-        let allocation3 = cache.allocate(4, Some(Arc::new(vec![6, 7, 8, 9]))).unwrap();
-        assert_eq!(allocation3.blocks, vec![3, 4, 5, 6]);
-        assert_eq!(allocation3.prefix_len, 0);
-    }
-
-    #[test]
-    fn allocator_frees_fully_overlapping_prefills() {
-        let mut cache = RadixAllocator::new(1, 10, None);
-        let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
-        let allocation2 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
-
-        cache.free(allocation2.blocks.clone(), allocation2.allocation_id);
-        cache.free(allocation1.blocks.clone(), allocation1.allocation_id);
-
-        let allocation3 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
-        assert_eq!(allocation3.prefix_len, 4);
-
-        // 10 blocks, of which 1 reserved for health checks, 4 for the cached blocks.
-        assert_eq!(cache.free_blocks.len(), 5);
-    }
-
-    #[test]
-    fn allocator_frees_partially_overlapping_prefills() {
-        let mut cache = RadixAllocator::new(1, 20, None);
-        let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1]))).unwrap();
-        assert_eq!(allocation1.blocks, vec![16, 17, 18, 19]);
-        assert_eq!(allocation1.prefix_len, 0);
-
-        cache.free(allocation1.blocks.clone(), allocation1.allocation_id);
-
-        let allocation2 = cache
-            .allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 4, 5])))
-            .unwrap();
-        assert_eq!(allocation2.blocks, vec![16, 17, 12, 13, 14, 15, 18, 19]);
-        assert_eq!(allocation2.prefix_len, 2);
-
-        let allocation3 = cache
-            .allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 6, 7])))
-            .unwrap();
-        assert_eq!(allocation3.blocks, vec![16, 17, 6, 7, 8, 9, 10, 11]);
-        assert_eq!(allocation3.prefix_len, 2);
-
-        cache.free(allocation3.blocks.clone(), allocation3.allocation_id);
-        cache.free(allocation2.blocks.clone(), allocation2.allocation_id);
-
-        // 20 blocks, of which 1 reserved for health checks, 6 for allocation3, 2 for allocation2.
-        assert_eq!(cache.free_blocks.len(), 11);
-
-        let allocation4 = cache
-            .allocate(6, Some(Arc::new(vec![0, 1, 2, 3, 4, 5])))
-            .unwrap();
-        assert_eq!(allocation4.blocks, vec![16, 17, 6, 7, 14, 15]);
-        assert_eq!(allocation4.prefix_len, 6);
-        assert_eq!(cache.free_blocks.len(), 11);
-
-        let allocation5 = cache
-            .allocate(6, Some(Arc::new(vec![0, 1, 2, 3, 6, 7])))
-            .unwrap();
-        assert_eq!(allocation5.blocks, vec![16, 17, 6, 7, 8, 9]);
-        assert_eq!(allocation5.prefix_len, 6);
-        assert_eq!(cache.free_blocks.len(), 11);
-    }
-
-    #[test]
-    fn trie_insertions_have_correct_prefix_len() {
-        let mut trie = RadixTrie::new(1);
-
-        assert_eq!(trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(), 0);
-
-        // Already exists.
-        assert_eq!(trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(), 3);
-
-        // Completely new at root-level
-        assert_eq!(trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap(), 0);
-
-        // Contains full prefix, but longer.
-        assert_eq!(trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap(), 3);
-
-        // Shares partial prefix, we need a split.
-        assert_eq!(
-            trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7])
-                .unwrap(),
-            4
-        );
-    }
-
-    #[test]
-    fn trie_insertions_block_size() {
-        let mut trie = RadixTrie::new(2);
-
-        assert_eq!(trie.insert(&[0, 1, 2, 3], &[0, 1]).unwrap(), 0);
-
-        // Already exists.
-        // But needs to be block_size aligned
-        assert_eq!(trie.insert(&[0, 1, 2, 3], &[0, 1]).unwrap(), 4);
-
-        // Completely new at root-level
-        assert_eq!(trie.insert(&[1, 2, 3, 4], &[1, 2]).unwrap(), 0);
-
-        // Contains full prefix, but longer.
-        assert_eq!(trie.insert(&[0, 1, 2, 3, 4, 5], &[0, 1, 2]).unwrap(), 4);
-
-        // Shares partial prefix, we need a split.
-        assert_eq!(
-            trie.insert(&[0, 1, 3, 4, 5, 6, 7, 8], &[0, 1, 2, 3])
-                .unwrap(),
-            2
-        );
-    }
-
-    #[test]
-    fn trie_get_returns_correct_blocks() {
-        let mut trie = RadixTrie::new(1);
-        trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap();
-        trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap();
-        trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap();
-        trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7])
-            .unwrap();
-
-        let mut blocks = Vec::new();
-        trie.find(&[0], &mut blocks);
-        assert_eq!(blocks, vec![0]);
-
-        blocks.clear();
-        trie.find(&[0, 1, 2], &mut blocks);
-        assert_eq!(blocks, vec![0, 1, 2]);
-
-        blocks.clear();
-        trie.find(&[1, 2, 3], &mut blocks);
-        assert_eq!(blocks, vec![1, 2, 3]);
-
-        blocks.clear();
-        trie.find(&[0, 1, 2, 3], &mut blocks);
-        assert_eq!(blocks, vec![0, 1, 2, 3]);
-
-        blocks.clear();
-        trie.find(&[0, 1, 2, 3, 4], &mut blocks);
-        assert_eq!(blocks, vec![0, 1, 2, 3, 4]);
-
-        blocks.clear();
-        trie.find(&[0, 1, 2, 3, 5], &mut blocks);
-        assert_eq!(blocks, vec![0, 1, 2, 3, 5]);
-    }
-
-    #[test]
-    fn trie_evict_removes_correct_blocks() {
-        let mut trie = RadixTrie::new(1);
-        trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap();
-        trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7])
-            .unwrap();
-        trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap();
-        trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap();
-
-        let mut blocks = Vec::new();
-
-        // Remove less than the leave blocks.
-        assert_eq!(trie.evict(1), vec![7]);
-        trie.find(&[0, 1, 2, 3, 5, 6, 7], &mut blocks);
-        assert_eq!(blocks, vec![0, 1, 2, 3, 5, 6]);
-
-        // Refresh other leaf.
-        trie.find(&[0, 1, 2, 3, 4], &mut blocks);
-        trie.find(&[1, 2, 3], &mut blocks);
-
-        // Remove the leave blocks exactly.
-        assert_eq!(trie.evict(2), vec![5, 6]);
-        blocks.clear();
-        trie.find(&[0, 1, 2, 3, 5, 6, 7], &mut blocks);
-        assert_eq!(blocks, vec![0, 1, 2, 3]);
-
-        trie.find(&[1, 2, 3], &mut blocks);
-
-        // Remove more than the leave blocks.
-        assert_eq!(trie.evict(3), vec![4, 3, 2]);
-        blocks.clear();
-        trie.find(&[0, 1, 2, 3, 4], &mut blocks);
-        assert_eq!(blocks, vec![0, 1]);
-
-        // Clear out the whole trie.
-        assert_eq!(trie.evict(10), vec![1, 2, 3, 0, 1]);
-    }
-}
diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs
index 4a2341da..1c9d5620 100644
--- a/router/src/infer/mod.rs
+++ b/router/src/infer/mod.rs
@@ -8,9 +8,11 @@ use crate::{
     ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, HubTokenizerConfig,
     Message, PrefillToken, Token,
 };
+use async_stream::stream;
 use async_trait::async_trait;
 use chat_template::ChatTemplate;
 use futures::future::try_join_all;
+use futures::Stream;
 use minijinja::ErrorKind;
 use std::sync::atomic::{AtomicBool, Ordering};
 use std::sync::Arc;
@@ -87,7 +89,14 @@ impl Infer {
     pub(crate) async fn generate_stream<'a>(
         &'a self,
         request: GenerateRequest,
-    ) -> Result<GenerateStreamResponse, InferError> {
+    ) -> Result<
+        (
+            OwnedSemaphorePermit,
+            u32, // input_length
+            impl Stream<Item = Result<InferStreamResponse, InferError>> + 'a,
+        ),
+        InferError,
+    > {
         // Limit concurrent requests by acquiring a permit from the semaphore
         let permit = self
             .clone()
@@ -107,9 +116,18 @@ impl Infer {
             })?;
 
         let input_length = valid_request.input_length;
-        let generation_stream = self.backend.schedule(valid_request)?;
+        let mut generation_stream = self.backend.schedule(valid_request)?;
 
-        Ok((permit, input_length, generation_stream))
+        // Wrap generation stream to update the backend health if the stream contains an error
+        let final_stream = stream! {
+            while let Some(response) = generation_stream.next().await {
+                yield response.inspect_err(|_err| {
+                    self.backend_health.store(false, Ordering::SeqCst);
+                })
+            }
+        };
+
+        Ok((permit, input_length, final_stream))
     }
 
     /// Tokenizer the input
@@ -278,13 +296,6 @@ impl Infer {
     }
 }
 
-/// Type alias for generation responses
-pub(crate) type GenerateStreamResponse = (
-    OwnedSemaphorePermit,
-    u32, // input_length
-    UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
-);
-
 #[derive(Debug)]
 pub struct GeneratedText {
     pub text: String,
diff --git a/router/src/infer/v2/mod.rs b/router/src/infer/v2/mod.rs
deleted file mode 100644
index 6a91a433..00000000
--- a/router/src/infer/v2/mod.rs
+++ /dev/null
@@ -1,4 +0,0 @@
-mod queue;
-mod scheduler;
-
-pub(crate) use scheduler::BackendV2;
diff --git a/router/src/infer/v2/queue.rs b/router/src/infer/v2/queue.rs
deleted file mode 100644
index 696cbfc8..00000000
--- a/router/src/infer/v2/queue.rs
+++ /dev/null
@@ -1,675 +0,0 @@
-use crate::infer::{InferError, InferStreamResponse};
-use crate::validation::{
-    ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters,
-};
-use nohash_hasher::{BuildNoHashHasher, IntMap};
-use std::cmp::min;
-use std::collections::VecDeque;
-use text_generation_client::v2::{
-    Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
-};
-use text_generation_client::ChunksToString;
-use tokio::sync::{mpsc, oneshot};
-use tokio::time::Instant;
-use tracing::{info_span, instrument, Span};
-
-/// Queue entry
-#[derive(Debug)]
-pub(crate) struct Entry {
-    /// Request
-    pub request: ValidGenerateRequest,
-    /// Response sender to communicate between the Infer struct and the batching_task
-    pub response_tx: mpsc::UnboundedSender<Result<InferStreamResponse, InferError>>,
-    /// Span that will live as long as entry
-    pub span: Span,
-    /// Temporary span used as a guard when logging inference, wait times...
- pub temp_span: Option, - /// Instant when this entry was queued - pub queue_time: Instant, - /// Instant when this entry was added to a batch - pub batch_time: Option, -} - -/// Request Queue -#[derive(Debug, Clone)] -pub(crate) struct Queue { - /// Channel to communicate with the background queue task - queue_sender: mpsc::UnboundedSender, -} - -impl Queue { - pub(crate) fn new( - requires_padding: bool, - block_size: u32, - window_size: Option, - speculate: u32, - ) -> Self { - // Create channel - let (queue_sender, queue_receiver) = mpsc::unbounded_channel(); - - // Launch background queue task - tokio::spawn(queue_task( - requires_padding, - block_size, - window_size, - speculate, - queue_receiver, - )); - - Self { queue_sender } - } - - #[instrument(skip_all)] - pub(crate) fn append(&self, entry: Entry) { - // Send append command to the background task managing the state - // Unwrap is safe here - self.queue_sender - .send(QueueCommand::Append(Box::new(entry), Span::current())) - .unwrap(); - } - - // Get the next batch - #[instrument(skip(self))] - pub(crate) async fn next_batch( - &self, - min_size: Option, - max_size: Option, - prefill_token_budget: u32, - token_budget: u32, - ) -> Option { - // Create response channel - let (response_sender, response_receiver) = oneshot::channel(); - // Send next batch command to the background task managing the state - // Unwrap is safe here - self.queue_sender - .send(QueueCommand::NextBatch { - min_size, - max_size, - prefill_token_budget, - token_budget, - response_sender, - span: Span::current(), - }) - .unwrap(); - // Await on response channel - // Unwrap is safe here - response_receiver.await.unwrap() - } -} - -// Background task responsible of the queue state -async fn queue_task( - requires_padding: bool, - block_size: u32, - window_size: Option, - speculate: u32, - mut receiver: mpsc::UnboundedReceiver, -) { - let mut state = State::new(requires_padding, block_size, window_size, speculate); - - while let Some(cmd) = receiver.recv().await { - match cmd { - QueueCommand::Append(entry, span) => { - span.in_scope(|| state.append(*entry)); - metrics::gauge!("tgi_queue_size").increment(1.0); - } - QueueCommand::NextBatch { - min_size, - max_size, - prefill_token_budget, - token_budget, - response_sender, - span, - } => span.in_scope(|| { - let next_batch = - state.next_batch(min_size, max_size, prefill_token_budget, token_budget); - response_sender.send(next_batch).unwrap(); - metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64); - }), - } - } -} - -/// Queue State -#[derive(Debug)] -struct State { - /// Queue entries organized in a Vec - entries: VecDeque<(u64, Entry)>, - - /// Id of the next entry - next_id: u64, - - /// Id of the next batch - next_batch_id: u64, - - /// Whether the model is using padding - requires_padding: bool, - - /// Paged Attention block size - block_size: u32, - - /// Sliding window - window_size: Option, - - /// Speculation amount - speculate: u32, -} - -impl State { - fn new( - requires_padding: bool, - block_size: u32, - window_size: Option, - speculate: u32, - ) -> Self { - Self { - entries: VecDeque::with_capacity(128), - next_id: 0, - next_batch_id: 0, - requires_padding, - block_size, - window_size, - speculate, - } - } - - /// Append an entry to the queue - fn append(&mut self, mut entry: Entry) { - // Create a span that will live as long as the entry is in the queue waiting to be batched - let queue_span = info_span!(parent: &entry.span, "queued"); - entry.temp_span = Some(queue_span); - - // 
Push entry in the queue - self.entries.push_back((self.next_id, entry)); - self.next_id += 1; - } - - // Get the next batch - fn next_batch( - &mut self, - min_size: Option, - max_size: Option, - prefill_token_budget: u32, - token_budget: u32, - ) -> Option { - if self.entries.is_empty() { - tracing::debug!("No queue"); - return None; - } - - // Check if we have enough entries - if let Some(min_size) = min_size { - if self.entries.len() < min_size { - tracing::debug!("Not enough entries"); - return None; - } - } - - if let Some(max_size) = max_size { - if max_size == 0 { - tracing::debug!("No capacity"); - return None; - } - } - - // Pad prefill_token_budget to be a multiple of block size - let prefill_token_budget = - ((prefill_token_budget + self.block_size - 1) / self.block_size) * self.block_size; - - // Create span for this batch to add context to inference calls - let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty); - next_batch_span.follows_from(&Span::current()); - - let mut batch_requests = Vec::with_capacity(self.entries.len()); - let mut batch_entries = - IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default()); - - let mut max_input_length = 0; - let mut prefill_tokens: u32 = 0; - let mut decode_tokens: u32 = 0; - - // Pop entries starting from the front of the queue - while let Some((id, mut entry)) = self.entries.pop_front() { - // Filter entries where the response receiver was dropped (== entries where the request - // was dropped by the client) - if entry.response_tx.is_closed() { - metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1); - tracing::debug!("Dropping entry"); - continue; - } - - if self.requires_padding { - // We pad to max input length in the Python shards - // We need to take these padding tokens into the equation - max_input_length = max_input_length.max(entry.request.input_length); - prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length - } else { - // pad to block size - prefill_tokens += ((entry.request.input_length + self.block_size - 1) - / self.block_size) - * self.block_size; - } - - if self.requires_padding { - decode_tokens += entry.request.stopping_parameters.max_new_tokens; - } else { - let max_new_tokens = match self.window_size { - None => entry.request.stopping_parameters.max_new_tokens, - Some(window_size) => min( - window_size.saturating_sub(entry.request.input_length), - entry.request.stopping_parameters.max_new_tokens, - ), - }; - - // pad to block size - decode_tokens += - ((max_new_tokens + self.block_size - 1) / self.block_size) * self.block_size; - } - - if prefill_tokens > prefill_token_budget - || (prefill_tokens + decode_tokens + self.speculate) > token_budget - { - // Entry is over budget - // Add it back to the front - tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate); - self.entries.push_front((id, entry)); - break; - } - - tracing::debug!("Accepting entry"); - // Create a new span to link the batch back to this entry - let entry_batch_span = info_span!(parent: &entry.span, "infer"); - // Add relationships - next_batch_span.follows_from(&entry_batch_span); - entry_batch_span.follows_from(&next_batch_span); - // Update entry - entry.temp_span = Some(entry_batch_span); - - batch_requests.push(Request { - id, - prefill_logprobs: entry.request.decoder_input_details, - inputs: entry.request.inputs.chunks_to_string(), - 
truncate: entry.request.truncate, - parameters: Some(NextTokenChooserParameters::from( - entry.request.parameters.clone(), - )), - stopping_parameters: Some(StoppingCriteriaParameters::from( - entry.request.stopping_parameters.clone(), - )), - top_n_tokens: entry.request.top_n_tokens, - }); - // Set batch_time - entry.batch_time = Some(Instant::now()); - // Insert in batch_entries IntMap - batch_entries.insert(id, entry); - - // Check if max_size - if Some(batch_requests.len()) == max_size { - break; - } - } - - // Empty batch - if batch_requests.is_empty() { - tracing::debug!("Filtered out all entries"); - return None; - } - - // Check if our batch is big enough - if let Some(min_size) = min_size { - // Batch is too small - if batch_requests.len() < min_size { - // Add back entries to the queue in the correct order - for r in batch_requests.into_iter().rev() { - let id = r.id; - let entry = batch_entries.remove(&id).unwrap(); - self.entries.push_front((id, entry)); - } - - return None; - } - } - - // Final batch size - let size = batch_requests.len() as u32; - next_batch_span.record("batch_size", size); - - let batch = Batch { - id: self.next_batch_id, - requests: batch_requests, - size, - max_tokens: (prefill_tokens + decode_tokens), - }; - // Increment batch id - self.next_batch_id += 1; - - metrics::histogram!("tgi_batch_next_size").record(batch.size as f64); - - Some((batch_entries, batch, next_batch_span)) - } -} - -type NextBatch = (IntMap, Batch, Span); - -#[derive(Debug)] -enum QueueCommand { - Append(Box, Span), - NextBatch { - min_size: Option, - max_size: Option, - prefill_token_budget: u32, - token_budget: u32, - response_sender: oneshot::Sender>, - span: Span, - }, -} - -impl From for NextTokenChooserParameters { - fn from(value: ValidParameters) -> Self { - let (grammar, grammar_type) = match value.grammar { - None => (String::new(), GrammarType::None), - - Some(grammar) => match grammar { - ValidGrammar::Json(grammar_string) => (grammar_string, GrammarType::Json), - ValidGrammar::Regex(grammar_string) => (grammar_string, GrammarType::Regex), - }, - }; - - Self { - temperature: value.temperature, - top_k: value.top_k, - top_p: value.top_p, - typical_p: value.typical_p, - do_sample: value.do_sample, - seed: value.seed, - repetition_penalty: value.repetition_penalty, - frequency_penalty: value.frequency_penalty, - watermark: value.watermark, - grammar, - grammar_type: grammar_type.into(), - } - } -} - -impl From for StoppingCriteriaParameters { - fn from(value: ValidStoppingParameters) -> Self { - Self { - max_new_tokens: value.max_new_tokens, - stop_sequences: value.stop_sequences, - ignore_eos_token: value.ignore_eos_token, - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use tracing::info_span; - - fn default_entry() -> ( - Entry, - mpsc::UnboundedReceiver>, - ) { - let (response_tx, receiver_tx) = mpsc::unbounded_channel(); - - let entry = Entry { - request: ValidGenerateRequest { - inputs: vec![], - input_length: 0, - truncate: 0, - decoder_input_details: false, - parameters: ValidParameters { - temperature: 0.0, - top_k: 0, - top_p: 0.0, - typical_p: 0.0, - do_sample: false, - seed: 0, - repetition_penalty: 0.0, - frequency_penalty: 0.0, - watermark: false, - grammar: None, - }, - stopping_parameters: ValidStoppingParameters { - ignore_eos_token: false, - max_new_tokens: 1, - stop_sequences: vec![], - }, - top_n_tokens: 0, - adapter_id: None, - }, - response_tx, - span: info_span!("entry"), - temp_span: None, - queue_time: Instant::now(), - batch_time: None, - 
}; - (entry, receiver_tx) - } - - #[test] - fn test_append() { - let mut state = State::new(false, 1, None, 0); - let (entry, _guard) = default_entry(); - - assert_eq!(state.next_id, 0); - assert_eq!(state.entries.len(), 0); - - state.append(entry); - - assert_eq!(state.next_id, 1); - assert_eq!(state.entries.len(), 1); - let (id, _) = state.entries.remove(0).unwrap(); - assert_eq!(id, 0); - } - - #[test] - fn test_next_batch_empty() { - let mut state = State::new(false, 1, None, 0); - - assert!(state.next_batch(None, None, 1, 1).is_none()); - assert!(state.next_batch(Some(1), None, 1, 1).is_none()); - } - - #[test] - fn test_next_batch_min_size() { - let mut state = State::new(false, 1, None, 0); - let (entry1, _guard1) = default_entry(); - let (entry2, _guard2) = default_entry(); - state.append(entry1); - state.append(entry2); - - let (entries, batch, _) = state.next_batch(None, None, 2, 2).unwrap(); - assert_eq!(entries.len(), 2); - assert!(entries.contains_key(&0)); - assert!(entries.contains_key(&1)); - assert!(entries.get(&0).unwrap().batch_time.is_some()); - assert!(entries.get(&1).unwrap().batch_time.is_some()); - assert_eq!(batch.id, 0); - assert_eq!(batch.size, 2); - - assert_eq!(state.next_id, 2); - assert_eq!(state.entries.len(), 0); - assert_eq!(state.next_batch_id, 1); - - let (entry3, _guard3) = default_entry(); - state.append(entry3); - - assert!(state.next_batch(Some(2), None, 2, 2).is_none()); - - assert_eq!(state.next_id, 3); - assert_eq!(state.entries.len(), 1); - let (id, _) = state.entries.remove(0).unwrap(); - assert_eq!(id, 2); - } - - #[test] - fn test_next_batch_max_size() { - let mut state = State::new(false, 1, None, 0); - let (entry1, _guard1) = default_entry(); - let (entry2, _guard2) = default_entry(); - state.append(entry1); - state.append(entry2); - - let (entries, batch, _) = state.next_batch(None, Some(1), 2, 2).unwrap(); - assert_eq!(entries.len(), 1); - assert!(entries.contains_key(&0)); - assert!(entries.get(&0).unwrap().batch_time.is_some()); - assert_eq!(batch.id, 0); - assert_eq!(batch.size, 1); - - assert_eq!(state.next_id, 2); - assert_eq!(state.entries.len(), 1); - assert_eq!(state.next_batch_id, 1); - } - - #[test] - fn test_next_batch_token_budget() { - let mut state = State::new(false, 1, None, 0); - let (entry1, _guard1) = default_entry(); - let (entry2, _guard2) = default_entry(); - state.append(entry1); - state.append(entry2); - - let (entries, batch, _) = state.next_batch(None, None, 1, 1).unwrap(); - assert_eq!(entries.len(), 1); - assert!(entries.contains_key(&0)); - assert_eq!(batch.id, 0); - assert_eq!(batch.size, 1); - - assert_eq!(state.next_id, 2); - assert_eq!(state.entries.len(), 1); - assert_eq!(state.next_batch_id, 1); - - let (entry3, _guard3) = default_entry(); - state.append(entry3); - - let (entries, batch, _) = state.next_batch(None, None, 3, 3).unwrap(); - assert_eq!(entries.len(), 2); - assert!(entries.contains_key(&1)); - assert!(entries.contains_key(&2)); - assert_eq!(batch.id, 1); - assert_eq!(batch.size, 2); - - assert_eq!(state.next_id, 3); - assert_eq!(state.entries.len(), 0); - assert_eq!(state.next_batch_id, 2); - } - - #[tokio::test] - async fn test_queue_append() { - let queue = Queue::new(false, 1, None, 0); - let (entry, _guard) = default_entry(); - queue.append(entry); - } - - #[tokio::test] - async fn test_queue_next_batch_empty() { - let queue = Queue::new(false, 1, None, 0); - - assert!(queue.next_batch(None, None, 1, 1).await.is_none()); - assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none()); - 
} - - #[tokio::test] - async fn test_queue_next_batch_min_size() { - let queue = Queue::new(false, 1, None, 0); - let (entry1, _guard1) = default_entry(); - let (entry2, _guard2) = default_entry(); - queue.append(entry1); - queue.append(entry2); - - let (entries, batch, _) = queue.next_batch(None, None, 2, 2).await.unwrap(); - assert_eq!(entries.len(), 2); - assert!(entries.contains_key(&0)); - assert!(entries.contains_key(&1)); - assert!(entries.get(&0).unwrap().batch_time.is_some()); - assert!(entries.get(&1).unwrap().batch_time.is_some()); - assert_eq!(batch.id, 0); - assert_eq!(batch.size, 2); - - let (entry3, _guard3) = default_entry(); - queue.append(entry3); - - // Not enough requests pending - assert!(queue.next_batch(Some(2), None, 2, 2).await.is_none()); - // Not enough token budget - assert!(queue.next_batch(Some(1), None, 0, 0).await.is_none()); - // Ok - let (entries2, batch2, _) = queue.next_batch(Some(1), None, 2, 2).await.unwrap(); - assert_eq!(entries2.len(), 1); - assert!(entries2.contains_key(&2)); - assert!(entries2.get(&2).unwrap().batch_time.is_some()); - assert_eq!(batch2.id, 1); - assert_eq!(batch2.size, 1); - } - - #[tokio::test] - async fn test_queue_next_batch_max_size() { - let queue = Queue::new(false, 1, None, 0); - let (entry1, _guard1) = default_entry(); - let (entry2, _guard2) = default_entry(); - queue.append(entry1); - queue.append(entry2); - - let (entries, batch, _) = queue.next_batch(None, Some(1), 2, 2).await.unwrap(); - assert_eq!(entries.len(), 1); - assert!(entries.contains_key(&0)); - assert!(entries.get(&0).unwrap().batch_time.is_some()); - assert_eq!(batch.id, 0); - assert_eq!(batch.size, 1); - } - - #[tokio::test] - async fn test_queue_next_batch_token_budget() { - let queue = Queue::new(false, 1, None, 0); - let (entry1, _guard1) = default_entry(); - let (entry2, _guard2) = default_entry(); - queue.append(entry1); - queue.append(entry2); - - let (entries, batch, _) = queue.next_batch(None, None, 1, 1).await.unwrap(); - assert_eq!(entries.len(), 1); - assert!(entries.contains_key(&0)); - assert_eq!(batch.id, 0); - assert_eq!(batch.size, 1); - - let (entry3, _guard3) = default_entry(); - queue.append(entry3); - - let (entries, batch, _) = queue.next_batch(None, None, 3, 3).await.unwrap(); - assert_eq!(entries.len(), 2); - assert!(entries.contains_key(&1)); - assert!(entries.contains_key(&2)); - assert_eq!(batch.id, 1); - assert_eq!(batch.size, 2); - } - - #[tokio::test] - async fn test_queue_next_batch_token_speculate() { - let queue = Queue::new(false, 1, None, 2); - let (entry1, _guard1) = default_entry(); - let (entry2, _guard2) = default_entry(); - queue.append(entry1); - queue.append(entry2); - - // Budget of 1 is not enough - assert!(queue.next_batch(None, None, 1, 1).await.is_none()); - - let (entries, batch, _) = queue.next_batch(None, None, 6, 6).await.unwrap(); - assert_eq!(entries.len(), 2); - assert!(entries.contains_key(&0)); - assert!(entries.contains_key(&1)); - assert_eq!(batch.id, 0); - assert_eq!(batch.size, 2); - } - - #[tokio::test] - async fn test_queue_next_batch_dropped_receiver() { - let queue = Queue::new(false, 1, None, 0); - let (entry, _) = default_entry(); - queue.append(entry); - - assert!(queue.next_batch(None, None, 1, 1).await.is_none()); - } -} diff --git a/router/src/infer/v2/scheduler.rs b/router/src/infer/v2/scheduler.rs deleted file mode 100644 index 0e5fc8a3..00000000 --- a/router/src/infer/v2/scheduler.rs +++ /dev/null @@ -1,1201 +0,0 @@ -/// Batching and inference logic -use 
crate::infer::v2::queue::{Entry, Queue}; -use crate::infer::{ - Attention, Backend, GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse, -}; -use crate::validation::ValidGenerateRequest; -use crate::{Attention, FinishReason, PrefillToken, Token}; -use nohash_hasher::IntMap; -use std::sync::{ - atomic::{AtomicBool, Ordering}, - Arc, -}; -use text_generation_client::v2::{Batch, CachedBatch, Generation, ShardedClient}; -use text_generation_client::ClientError; -use tokio::sync::mpsc::error::SendError; -use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit}; -use tokio::time::Instant; -use tokio_stream::wrappers::UnboundedReceiverStream; -use tracing::{info_span, instrument, Instrument, Span}; - -pub(crate) struct BackendV2 { - /// Request queue - queue: Queue, - /// Notify batcher on queue appends - batching_task_notifier: Arc, -} - -impl BackendV2 { - #[allow(clippy::too_many_arguments)] - pub(crate) fn new( - client: ShardedClient, - waiting_served_ratio: f32, - max_batch_prefill_tokens: u32, - max_batch_total_tokens: u32, - max_waiting_tokens: usize, - max_batch_size: Option, - requires_padding: bool, - window_size: Option, - speculate: u32, - generation_health: Arc, - ) -> Self { - // Infer shared state - let attention = if let Ok(attention) = std::env::var("ATTENTION") { - attention - .parse() - .expect(&format!("Invalid attention was specified :`{attention}`")) - } else { - Attention::Paged - }; - let block_size = if attention == Attention::FlashDecoding { - 256 - } else { - 16 - }; - let queue = Queue::new(requires_padding, block_size, window_size, speculate); - let batching_task_notifier = Arc::new(Notify::new()); - - // Spawn batching background task that contains all the inference logic - tokio::spawn(batching_task( - client, - waiting_served_ratio, - max_batch_prefill_tokens, - max_batch_total_tokens, - max_waiting_tokens, - max_batch_size, - queue.clone(), - batching_task_notifier.clone(), - generation_health, - )); - - Self { - queue, - batching_task_notifier, - } - } -} - -impl Backend for BackendV2 { - #[instrument(skip_all)] - fn schedule( - &self, - request: ValidGenerateRequest, - permit: OwnedSemaphorePermit, - ) -> Result { - // MPSC channel to communicate with the background batching task - let (response_tx, response_rx) = mpsc::unbounded_channel(); - let input_length = request.input_length; - - // Append the request to the queue - self.queue.append(Entry { - request, - response_tx, - span: Span::current(), - temp_span: None, - queue_time: Instant::now(), - batch_time: None, - }); - - // Notify the background task that we have a new entry in the queue that needs - // to be batched - self.batching_task_notifier.notify_one(); - - // Return stream - Ok(( - permit, - input_length, - UnboundedReceiverStream::new(response_rx), - )) - } -} - -/// Batching logic -/// Will be launched in a background Tokio task -/// -/// Batches requests and sends them to the inference server -#[allow(clippy::too_many_arguments)] -pub(crate) async fn batching_task( - mut client: ShardedClient, - waiting_served_ratio: f32, - max_batch_prefill_tokens: u32, - max_batch_total_tokens: u32, - max_waiting_tokens: usize, - max_batch_size: Option, - queue: Queue, - notifier: Arc, - generation_health: Arc, -) { - // Infinite loop - loop { - // Wait for a notification from the Infer struct - notifier.notified().await; - - // Get the next batch from the queue - // This batch might be smaller than the maximum batch size if there are not enough requests - // waiting in the queue - while let 
Some((mut entries, batch, span)) = queue - .next_batch( - None, - max_batch_size, - max_batch_prefill_tokens, - max_batch_total_tokens, - ) - .await - { - let mut cached_batch = prefill(&mut client, batch, &mut entries, &generation_health) - .instrument(span) - .await; - let mut waiting_tokens = 1; - - // We loop until we do not receive any cached batch from the inference server (== until - // all requests have met their stopping criteria) - while let Some(batch) = cached_batch { - // Get current batch info - let batch_size = batch.size; - let batch_max_tokens = batch.max_tokens; - let mut batches = vec![batch]; - metrics::gauge!("tgi_batch_current_size").set(batch_size as f64); - metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64); - - let min_size = if waiting_tokens >= max_waiting_tokens { - // If we didn't onboard any new requests since >= max_waiting_tokens, we try - // to add a new batch even though its size might be small - None - } else { - // Minimum batch size - Some((batch_size as f32 * waiting_served_ratio).floor() as usize) - }; - - let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens); - let max_size = - max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize)); - // Try to get a new batch - if let Some((mut new_entries, new_batch, span)) = queue - .next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget) - .await - { - // Tracking metrics - if min_size.is_some() { - metrics::counter!("tgi_batch_concat", "reason" => "backpressure") - .increment(1); - } else { - metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded") - .increment(1); - } - - entries.iter_mut().for_each(|(_, entry)| { - // Create a new span to add the info that this entry is waiting - // because a new batch is being computed - let entry_waiting_span = info_span!(parent: &entry.span, "waiting"); - // Add relationships - span.follows_from(&entry_waiting_span); - entry_waiting_span.follows_from(&span); - // Update entry - entry.temp_span = Some(entry_waiting_span); - }); - - // Generate one token for this new batch to have the attention past in cache - let new_cached_batch = - prefill(&mut client, new_batch, &mut new_entries, &generation_health) - .instrument(span) - .await; - // Reset waiting counter - waiting_tokens = 1; - // Extend current batch with the new batch - if let Some(new_cached_batch) = new_cached_batch { - entries.extend(new_entries); - batches.push(new_cached_batch); - } - } - - // Create span for this batch to add context to inference calls - let next_batch_size = entries.len(); - let next_batch_span = - info_span!(parent: None, "batch", batch_size = next_batch_size); - entries.iter_mut().for_each(|(_, entry)| { - // Create a new span to link the batch back to this entry - let entry_batch_span = info_span!(parent: &entry.span, "infer"); - // Add relationships - next_batch_span.follows_from(&entry_batch_span); - entry_batch_span.follows_from(&next_batch_span); - // Update entry - entry.temp_span = Some(entry_batch_span); - }); - - cached_batch = decode(&mut client, batches, &mut entries, &generation_health) - .instrument(next_batch_span) - .await; - waiting_tokens += 1; - } - metrics::gauge!("tgi_batch_current_size").set(0.0); - metrics::gauge!("tgi_batch_current_max_tokens").set(0.0); - } - } -} - -#[instrument(skip_all)] -async fn prefill( - client: &mut ShardedClient, - batch: Batch, - entries: &mut IntMap, - generation_health: &Arc, -) -> Option { - let start_time = Instant::now(); - let batch_id = 
batch.id; - metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1); - - match client.prefill(batch).await { - Ok((generations, next_batch, timings)) => { - // Update health - generation_health.store(true, Ordering::SeqCst); - - let start_filtering_time = Instant::now(); - // Send generated tokens and filter stopped entries - filter_send_generations(generations, entries); - - // Filter next batch and remove requests that were stopped - let next_batch = filter_batch(client, next_batch, entries).await; - - metrics::histogram!("tgi_batch_forward_duration","method" => "prefill") - .record(timings.forward.as_secs_f64()); - metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill") - .record(timings.decode.as_secs_f64()); - metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill") - .record(start_filtering_time.elapsed().as_secs_f64()); - metrics::histogram!("tgi_batch_inference_duration","method" => "prefill") - .record(start_time.elapsed().as_secs_f64()); - metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1); - next_batch - } - // If we have an error, we discard the whole batch - Err(err) => { - // Update health - generation_health.store(false, Ordering::SeqCst); - let _ = client.clear_cache(Some(batch_id)).await; - send_errors(err, entries); - metrics::counter!("tgi_batch_inference_failure", "method" => "prefill").increment(1); - None - } - } -} - -#[instrument(skip_all)] -async fn decode( - client: &mut ShardedClient, - batches: Vec, - entries: &mut IntMap, - generation_health: &Arc, -) -> Option { - let start_time = Instant::now(); - let batch_ids: Vec = batches.iter().map(|b| b.id).collect(); - metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1); - - match client.decode(batches).await { - Ok((generations, next_batch, timings)) => { - // Update health - generation_health.store(true, Ordering::SeqCst); - - let start_filtering_time = Instant::now(); - // Send generated tokens and filter stopped entries - filter_send_generations(generations, entries); - - // Filter next batch and remove requests that were stopped - let next_batch = filter_batch(client, next_batch, entries).await; - - if let Some(concat_duration) = timings.concat { - metrics::histogram!("tgi_batch_concat_duration", "method" => "decode") - .record(concat_duration.as_secs_f64()); - } - metrics::histogram!("tgi_batch_forward_duration", "method" => "decode") - .record(timings.forward.as_secs_f64()); - metrics::histogram!("tgi_batch_decode_duration", "method" => "decode") - .record(timings.decode.as_secs_f64()); - metrics::histogram!("tgi_batch_filter_duration", "method" => "decode") - .record(start_filtering_time.elapsed().as_secs_f64()); - metrics::histogram!("tgi_batch_inference_duration", "method" => "decode") - .record(start_time.elapsed().as_secs_f64()); - metrics::counter!("tgi_batch_inference_success", "method" => "decode").increment(1); - next_batch - } - // If we have an error, we discard the whole batch - Err(err) => { - generation_health.store(false, Ordering::SeqCst); - for id in batch_ids { - let _ = client.clear_cache(Some(id)).await; - } - send_errors(err, entries); - metrics::counter!("tgi_batch_inference_failure", "method" => "decode").increment(1); - None - } - } -} - -/// Filter a `batch` and remove all requests not present in `entries` -#[instrument(skip_all)] -async fn filter_batch( - client: &mut ShardedClient, - next_batch: Option, - entries: &IntMap, -) -> Option { - let mut batch = 
next_batch?; - - // No need to filter - if batch.size as usize == entries.len() { - return Some(batch); - } - - let id = batch.id; - - // Retain only requests that are still in entries - batch.request_ids.retain(|id| entries.contains_key(id)); - - if batch.request_ids.is_empty() { - // All requests have been filtered out - // Next batch is now empty - // Clear it from the Python shards cache - // We unwrap here as we need to panic since we cannot recover if this method fails - client.clear_cache(Some(id)).await.unwrap(); - None - } else { - // Filter Python shard cache - // We unwrap here as we need to panic since we cannot recover if this method fails - client.filter_batch(id, batch.request_ids).await.unwrap() - } -} - -/// Send one or multiple `InferStreamResponse` to Infer for all `entries` -/// and filter entries -#[instrument(skip_all)] -fn filter_send_generations(generations: Vec, entries: &mut IntMap) { - generations.into_iter().for_each(|generation| { - let id = generation.request_id; - // Get entry - // We can `expect` here as the request id should always be in the entries - let entry = entries - .get(&id) - .expect("ID not found in entries. This is a bug."); - - // Create and enter a span to link this function back to the entry - let _span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered(); - // Send generation responses back to the infer task - // If the receive an error from the Flume channel, it means that the client dropped the - // request and we need to stop generating hence why we unwrap_or(true) - let stopped = send_responses(generation, entry).map_err(|err| { - tracing::error!("Entry response channel error."); - metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1); - err - }).unwrap_or(true); - if stopped { - entries.remove(&id).expect("ID not found in entries. 
This is a bug."); - } - }); -} - -/// Send responses through the `entry` response channel -fn send_responses( - generation: Generation, - entry: &Entry, -) -> Result>>> { - // Return directly if the channel is disconnected - if entry.response_tx.is_closed() { - metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1); - return Ok(true); - } - - let mut stopped = false; - - if let Some(prefill_tokens) = generation.prefill_tokens { - // Create Token objects - // We do that here instead of in the Python code as Rust for loops are faster - let prefill_tokens = prefill_tokens - .ids - .into_iter() - .zip(prefill_tokens.logprobs) - .zip(prefill_tokens.texts) - .map(|((id, logprob), text)| PrefillToken { id, text, logprob }) - .collect(); - - // Send message - entry - .response_tx - .send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?; - } - - // Create last Token - let tokens_ = generation.tokens.expect("Non empty tokens in generation"); - let n = tokens_.ids.len(); - metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64); - let mut iterator = tokens_ - .ids - .into_iter() - .zip(tokens_.logprobs) - .zip(tokens_.texts) - .zip(tokens_.is_special) - .enumerate() - .peekable(); - while let Some((i, (((id, logprob), text), special))) = iterator.next() { - let token = Token { - id, - text, - logprob, - special, - }; - let top_tokens = if let Some(top_tokens_) = generation.top_tokens.get(i) { - top_tokens_ - .ids - .iter() - .zip(top_tokens_.logprobs.iter()) - .zip(top_tokens_.texts.iter()) - .zip(top_tokens_.is_special.iter()) - .map(|(((&id, &logprob), text), &special)| Token { - id, - text: text.to_string(), - logprob, - special, - }) - .collect() - } else { - vec![] - }; - match (&generation.generated_text, iterator.peek()) { - (Some(generated_text), None) => { - // Generation has ended - stopped = true; - // Send message - entry.response_tx.send(Ok(InferStreamResponse::End { - token, - top_tokens, - generated_text: GeneratedText::from(generated_text.clone()), - queued: entry.queue_time, - start: entry.batch_time.unwrap(), - }))?; - } - _ => { - // Send message - entry - .response_tx - .send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?; - } - } - } - - Ok(stopped) -} - -/// Send errors to Infer for all `entries` -#[instrument(skip_all)] -fn send_errors(error: ClientError, entries: &mut IntMap) { - entries.drain().for_each(|(_, entry)| { - // Create and enter a span to link this function back to the entry - let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered(); - let err = InferError::GenerationError(error.to_string()); - metrics::counter!("tgi_request_failure", "err" => "generation").increment(1); - tracing::error!("{err}"); - - // unwrap_or is valid here as we don't care if the receiver is gone. 
- entry - .response_tx - .send(Err(err)) - .unwrap_or(()); - }); -} - -impl From for GeneratedText { - fn from(value: text_generation_client::v2::GeneratedText) -> Self { - let v2_finish_reason = - text_generation_client::v2::FinishReason::try_from(value.finish_reason).unwrap(); - let finish_reason = match v2_finish_reason { - text_generation_client::v2::FinishReason::Length => FinishReason::Length, - text_generation_client::v2::FinishReason::EosToken => FinishReason::EndOfSequenceToken, - text_generation_client::v2::FinishReason::StopSequence => FinishReason::StopSequence, - }; - - Self { - text: value.text, - generated_tokens: value.generated_tokens, - finish_reason, - seed: value.seed, - } - } -} - -// tests -#[cfg(test)] -mod tests { - use crate::infer::raise_exception; - use crate::{ChatTemplateInputs, TextMessage}; - use minijinja::Environment; - - #[test] - fn test_chat_template() { - let env = Environment::new(); - - let source = r#" - {% for message in messages %} - {% if message['role'] == 'system' %} - {% if message['content']%} - {{'### System:\n' + message['content']+'\n\n'}} - {% endif %} - {% elif message['role'] == 'user' %} - {{'### User:\n' + message['content']+'\n\n'}} - {% elif message['role'] == 'assistant' %} - {{'### Assistant:\n' + message['content']}} - {% endif %} - {% if loop.last and add_generation_prompt %} - {{ '### Assistant:\n' }} - {% endif %} - {% endfor %}"#; - - // trim all the whitespace - let source = source - .lines() - .map(|line| line.trim()) - .collect::>() - .join(""); - - let tmpl = env.template_from_str(&source); - - let chat_template_inputs = ChatTemplateInputs { - messages: vec![ - TextMessage { - role: "user".to_string(), - content: "Hi!".to_string(), - }, - TextMessage { - role: "assistant".to_string(), - content: "Hello how can I help?".to_string(), - }, - TextMessage { - role: "user".to_string(), - content: "What is Deep Learning?".to_string(), - }, - TextMessage { - role: "assistant".to_string(), - content: "magic!".to_string(), - }, - ], - bos_token: Some("[BOS]"), - eos_token: Some("[EOS]"), - add_generation_prompt: true, - ..Default::default() - }; - - let result = tmpl.unwrap().render(chat_template_inputs).unwrap(); - - assert_eq!( - result, - "### User:\nHi!\n\n### Assistant:\nHello how can I help?### User:\nWhat is Deep Learning?\n\n### Assistant:\nmagic!### Assistant:\n" - ); - } - - #[test] - fn test_chat_template_invalid_with_raise() { - let mut env = Environment::new(); - env.add_function("raise_exception", raise_exception); - - let source = r#" - {{ bos_token }} - {% for message in messages %} - {% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %} - {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }} - {% endif %} - {% if message['role'] == 'user' %} - {{ '[INST] ' + message['content'] + ' [/INST]' }} - {% elif message['role'] == 'assistant' %} - {{ message['content'] + eos_token}} - {% else %} - {{ raise_exception('Only user and assistant roles are supported!') }} - {% endif %} - {% endfor %}"#; - - // trim all the whitespace - let source = source - .lines() - .map(|line| line.trim()) - .collect::>() - .join(""); - - let tmpl = env.template_from_str(&source); - - let chat_template_inputs = ChatTemplateInputs { - messages: vec![ - TextMessage { - role: "user".to_string(), - content: "Hi!".to_string(), - }, - TextMessage { - role: "user".to_string(), - content: "Hi again!".to_string(), - }, - TextMessage { - role: "assistant".to_string(), - content: "Hello how can I 
help?".to_string(), - }, - TextMessage { - role: "user".to_string(), - content: "What is Deep Learning?".to_string(), - }, - TextMessage { - role: "assistant".to_string(), - content: "magic!".to_string(), - }, - ], - bos_token: Some("[BOS]"), - eos_token: Some("[EOS]"), - add_generation_prompt: true, - ..Default::default() - }; - - let result = tmpl.unwrap().render(chat_template_inputs); //.err().unwrap(); - - match result { - Ok(_) => panic!("Should have failed"), - Err(e) => { - assert_eq!( - e.detail().unwrap(), - "Conversation roles must alternate user/assistant/user/assistant/..." - ); - } - } - } - - #[test] - fn test_chat_template_valid_with_raise() { - let mut env = Environment::new(); - env.add_function("raise_exception", raise_exception); - - let source = r#" - {{ bos_token }} - {% for message in messages %} - {% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %} - {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }} - {% endif %} - {% if message['role'] == 'user' %} - {{ '[INST] ' + message['content'] + ' [/INST]' }} - {% elif message['role'] == 'assistant' %} - {{ message['content'] + eos_token}} - {% else %} - {{ raise_exception('Only user and assistant roles are supported!') }} - {% endif %} - {% endfor %}"#; - - // trim all the whitespace - let source = source - .lines() - .map(|line| line.trim()) - .collect::>() - .join(""); - - let tmpl = env.template_from_str(&source); - - let chat_template_inputs = ChatTemplateInputs { - messages: vec![ - TextMessage { - role: "user".to_string(), - content: "Hi!".to_string(), - }, - TextMessage { - role: "assistant".to_string(), - content: "Hello how can I help?".to_string(), - }, - TextMessage { - role: "user".to_string(), - content: "What is Deep Learning?".to_string(), - }, - TextMessage { - role: "assistant".to_string(), - content: "magic!".to_string(), - }, - ], - bos_token: Some("[BOS]"), - eos_token: Some("[EOS]"), - add_generation_prompt: true, - ..Default::default() - }; - - let result = tmpl.unwrap().render(chat_template_inputs).unwrap(); - assert_eq!(result, "[BOS][INST] Hi! [/INST]Hello how can I help?[EOS][INST] What is Deep Learning? 
[/INST]magic![EOS]"); - } - - #[test] - fn test_chat_template_valid_with_add_generation_prompt() { - let mut env = Environment::new(); - env.add_function("raise_exception", raise_exception); - - let source = r#" - {% for message in messages %} - {{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}} - {% endfor %} - {% if add_generation_prompt %} - {{ '<|im_start|>assistant\n' }} - {% endif %}"#; - - // trim all the whitespace - let source = source - .lines() - .map(|line| line.trim()) - .collect::>() - .join(""); - - let tmpl = env.template_from_str(&source); - - let chat_template_inputs = ChatTemplateInputs { - messages: vec![ - TextMessage { - role: "user".to_string(), - content: "Hi!".to_string(), - }, - TextMessage { - role: "assistant".to_string(), - content: "Hello how can I help?".to_string(), - }, - TextMessage { - role: "user".to_string(), - content: "What is Deep Learning?".to_string(), - }, - TextMessage { - role: "assistant".to_string(), - content: "magic!".to_string(), - }, - ], - bos_token: Some("[BOS]"), - eos_token: Some("[EOS]"), - add_generation_prompt: true, - ..Default::default() - }; - - let result = tmpl.unwrap().render(chat_template_inputs).unwrap(); - assert_eq!(result, "<|im_start|>user\nHi!<|im_end|>\n<|im_start|>assistant\nHello how can I help?<|im_end|>\n<|im_start|>user\nWhat is Deep Learning?<|im_end|>\n<|im_start|>assistant\nmagic!<|im_end|>\n<|im_start|>assistant\n"); - } - - struct ChatTemplateTestItem { - name: &'static str, - chat_template: &'static str, - input: ChatTemplateInputs<'static>, - target: &'static str, - } - - #[test] - fn test_many_chat_templates() { - let example_chat = vec![ - TextMessage { - role: "user".to_string(), - content: "Hello, how are you?".to_string(), - }, - TextMessage { - role: "assistant".to_string(), - content: "I'm doing great. How can I help you today?".to_string(), - }, - TextMessage { - role: "user".to_string(), - content: "I'd like to show off how chat templating works!".to_string(), - }, - ]; - - let example_chat_with_system = [TextMessage { - role: "system".to_string(), - content: "You are a friendly chatbot who always responds in the style of a pirate" - .to_string(), - }] - .iter() - .chain(&example_chat) - .cloned() - .collect::>(); - - let test_default_templates = vec![ - ChatTemplateTestItem { - name: "_base", - chat_template: "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}", - input: ChatTemplateInputs { - messages: example_chat.clone(), - add_generation_prompt: false, - bos_token: Some(""), - eos_token: Some(""), - ..Default::default() - }, - target: "<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n", - }, - ChatTemplateTestItem { - name: "blenderbot", - chat_template: "{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ ' ' }}{% endif %}{% endfor %}{{ eos_token }}", - input: ChatTemplateInputs { - messages: example_chat.clone(), - add_generation_prompt: false, - bos_token: Some(""), - eos_token: Some(""), - ..Default::default() - }, - target: " Hello, how are you? I'm doing great. How can I help you today? 
I'd like to show off how chat templating works!", - }, - ChatTemplateTestItem { - name: "blenderbot_small", - chat_template: "{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ ' ' }}{% endif %}{% endfor %}{{ eos_token }}", - input: ChatTemplateInputs { - messages: example_chat.clone(), - add_generation_prompt: false, - bos_token: Some(""), - eos_token: Some(""), - ..Default::default() - }, - target: " Hello, how are you? I'm doing great. How can I help you today? I'd like to show off how chat templating works!", - }, - ChatTemplateTestItem { - name: "bloom", - chat_template: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}", - input: ChatTemplateInputs { - messages: example_chat.clone(), - add_generation_prompt: false, - bos_token: Some(""), - eos_token: Some(""), - ..Default::default() - }, - target: "Hello, how are you?I'm doing great. How can I help you today?I'd like to show off how chat templating works!", - }, - ChatTemplateTestItem { - name: "gpt_neox", - chat_template: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}", - input: ChatTemplateInputs { - messages: example_chat.clone(), - add_generation_prompt: false, - bos_token: Some(""), - eos_token: Some("<|endoftext|>"), - ..Default::default() - }, - target: "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>", - }, - ChatTemplateTestItem { - name: "gpt2", - chat_template: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}", - input: ChatTemplateInputs { - messages: example_chat.clone(), - add_generation_prompt: false, - bos_token: Some(""), - eos_token: Some("<|endoftext|>"), - ..Default::default() - }, - target: "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>", - }, - ChatTemplateTestItem { - name: "llama", - // NOTE: the `.strip()` has been replaced with `| trim` in the following template - chat_template: "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif USE_DEFAULT_PROMPT == true and not '<>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token +'[INST] ' + content | trim + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<>\\n' + content | trim + '\\n<>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content | trim + ' ' + eos_token }}{% endif %}{% endfor %}", - input: ChatTemplateInputs { - messages: example_chat_with_system.clone(), - add_generation_prompt: true, - bos_token: Some(""), - eos_token: Some(""), - ..Default::default() - }, - target: "[INST] <>\nYou are a friendly chatbot who always responds in the style of a pirate\n<>\n\nHello, how are you? [/INST] I'm doing great. 
How can I help you today? [INST] I'd like to show off how chat templating works! [/INST]", - }, - ChatTemplateTestItem { - name: "whisper", - chat_template: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}", - input: ChatTemplateInputs { - messages: example_chat.clone(), - add_generation_prompt: true, - bos_token: Some(""), - eos_token: Some("<|endoftext|>"), - ..Default::default() - }, - target: "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>", - }, - ]; - - #[allow(unused_variables)] // name is unused - for ChatTemplateTestItem { - name, - chat_template, - input, - target, - } in test_default_templates - { - let mut env = Environment::new(); - env.add_function("raise_exception", raise_exception); - let tmpl = env.template_from_str(chat_template); - let result = tmpl.unwrap().render(input).unwrap(); - assert_eq!(result, target); - } - - let test_custom_templates = vec![ - ChatTemplateTestItem { - name: "HuggingFaceH4/zephyr-7b-beta (add_generation_prompt=false)", - chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}", - input: ChatTemplateInputs { - messages: example_chat_with_system.clone(), - add_generation_prompt: false, - bos_token: Some(""), - eos_token: Some(""), - ..Default::default() - }, - target: "<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate<|user|>\nHello, how are you?<|assistant|>\nI'm doing great. 
How can I help you today?<|user|>\nI'd like to show off how chat templating works!", - }, - ChatTemplateTestItem { - name: "HuggingFaceH4/zephyr-7b-beta (add_generation_prompt=true)", - chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}", - input: ChatTemplateInputs { - messages: vec![ - TextMessage { - role: "system".to_string(), - content: "You are a friendly chatbot who always responds in the style of a pirate".to_string(), - }, - TextMessage { - role: "user".to_string(), - content: "How many helicopters can a human eat in one sitting?".to_string(), - }, - ], - add_generation_prompt: true, - bos_token: Some(""), - eos_token: Some(""), - ..Default::default() - }, - target: "<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate<|user|>\nHow many helicopters can a human eat in one sitting?<|assistant|>", - }, - ChatTemplateTestItem { - name: "HuggingFaceH4/zephyr-7b-gemma-v0.1", - chat_template: "{% if messages[0]['role'] == 'user' or messages[0]['role'] == 'system' %}{{ bos_token }}{% endif %}{% for message in messages %}{{ '<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% elif messages[-1]['role'] == 'assistant' %}{{ eos_token }}{% endif %}", - input: ChatTemplateInputs { - messages: example_chat.clone(), - add_generation_prompt: false, - bos_token: Some(""), - eos_token: Some(""), - ..Default::default() - }, - target: "<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n", - }, - ChatTemplateTestItem { - name: "mistralai/Mistral-7B-Instruct-v0.1", - chat_template: "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", - input: ChatTemplateInputs { - messages: example_chat.clone(), - add_generation_prompt: false, - bos_token: Some(""), - eos_token: Some(""), - ..Default::default() - }, - target: "[INST] Hello, how are you? [/INST]I'm doing great. How can I help you today? [INST] I'd like to show off how chat templating works! 
[/INST]", - }, - ChatTemplateTestItem { - name: "mistralai/Mixtral-8x7B-Instruct-v0.1", - chat_template: "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", - input: ChatTemplateInputs { - messages: example_chat.clone(), - add_generation_prompt: false, - bos_token: Some(""), - eos_token: Some(""), - ..Default::default() - }, - target: "[INST] Hello, how are you? [/INST]I'm doing great. How can I help you today?[INST] I'd like to show off how chat templating works! [/INST]", - }, - ChatTemplateTestItem { - name: "cognitivecomputations/dolphin-2.5-mixtral-8x7b", - chat_template: "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}", - input: ChatTemplateInputs { - messages: example_chat.clone(), - add_generation_prompt: false, - bos_token: Some(""), - eos_token: Some(""), - ..Default::default() - }, - target: "<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n", - }, - ChatTemplateTestItem { - name: "openchat/openchat-3.5-0106", - // `.title()` has been replaced with `| upper` in the following template - chat_template: "{{ bos_token }}{% for message in messages %}{{ 'GPT4 Correct ' + (message['role'] | title) + ': ' + message['content'] + '<|end_of_turn|>'}}{% endfor %}{% if add_generation_prompt %}{{ 'GPT4 Correct Assistant:' }}{% endif %}", - input: ChatTemplateInputs { - messages: example_chat.clone(), - add_generation_prompt: false, - bos_token: Some(""), - eos_token: Some(""), - ..Default::default() - }, - target: "GPT4 Correct User: Hello, how are you?<|end_of_turn|>GPT4 Correct Assistant: I'm doing great. How can I help you today?<|end_of_turn|>GPT4 Correct User: I'd like to show off how chat templating works!<|end_of_turn|>", - }, - ChatTemplateTestItem { - name: "upstage/SOLAR-10.7B-Instruct-v1.0", - chat_template: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}", - input: ChatTemplateInputs { - messages: example_chat.clone(), - add_generation_prompt: false, - bos_token: Some(""), - eos_token: Some(""), - ..Default::default() - }, - target: "Hello, how are you?I'm doing great. 
How can I help you today?I'd like to show off how chat templating works!", - }, - ChatTemplateTestItem { - name: "codellama/CodeLlama-70b-Instruct-hf", - // NOTE: `.strip()` has been replaced with `| trim` in the following template - chat_template: "{% if messages[0]['role'] == 'system' %}{% set user_index = 1 %}{% else %}{% set user_index = 0 %}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != ((loop.index0 + user_index) % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 %}{{ '' }}{% endif %}{% set content = 'Source: ' + message['role'] + '\\n\\n ' + message['content'] | trim %}{{ content + ' ' }}{% endfor %}{{'Source: assistant\\nDestination: user\\n\\n '}}", - input: ChatTemplateInputs { - messages: example_chat.clone(), - add_generation_prompt: false, - bos_token: Some(""), - eos_token: Some(""), - ..Default::default() - }, - target: "Source: user\n\n Hello, how are you? Source: assistant\n\n I'm doing great. How can I help you today? Source: user\n\n I'd like to show off how chat templating works! Source: assistant\nDestination: user\n\n ", - }, - ChatTemplateTestItem { - name: "Deci/DeciLM-7B-instruct", - chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '### User:\\n' + message['content'] }}\n{% elif message['role'] == 'system' %}\n{{ '### System:\\n' + message['content'] }}\n{% elif message['role'] == 'assistant' %}\n{{ '### Assistant:\\n' + message['content'] }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '### Assistant:' }}\n{% endif %}\n{% endfor %}", - input: ChatTemplateInputs { - messages: example_chat.clone(), - add_generation_prompt: false, - bos_token: Some(""), - eos_token: Some(""), - ..Default::default() - }, - target: "### User:\nHello, how are you?### Assistant:\nI'm doing great. How can I help you today?### User:\nI'd like to show off how chat templating works!", - }, - ChatTemplateTestItem { - name: "Qwen/Qwen1.5-72B-Chat", - chat_template: "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\\nYou are a helpful assistant<|im_end|>\\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\\n'}}{% endif %}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}", - input: ChatTemplateInputs { - messages: example_chat.clone(), - add_generation_prompt: false, - bos_token: Some(""), - eos_token: Some(""), - ..Default::default() - }, - target: "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. 
How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!", - }, - ChatTemplateTestItem { - name: "deepseek-ai/deepseek-llm-7b-chat", - chat_template: "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] + '\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\\n\\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}", - input: ChatTemplateInputs { - messages: example_chat.clone(), - add_generation_prompt: false, - bos_token: Some("<|begin▁of▁sentence|>"), - eos_token: Some("<|end▁of▁sentence|>"), - ..Default::default() - }, - target: "<|begin▁of▁sentence|>User: Hello, how are you?\n\nAssistant: I'm doing great. How can I help you today?<|end▁of▁sentence|>User: I'd like to show off how chat templating works!\n\n", - }, - ChatTemplateTestItem { - name: "h2oai/h2o-danube-1.8b-chat", - chat_template: "{% for message in messages %}{% if message['role'] == 'user' %}{{ '<|prompt|>' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ '<|system|>' + message['content'] + eos_token }}{% elif message['role'] == 'assistant' %}{{ '<|answer|>' + message['content'] + eos_token }}{% endif %}{% if loop.last and add_generation_prompt %}{{ '<|answer|>' }}{% endif %}{% endfor %}", - input: ChatTemplateInputs { - messages: example_chat.clone(), - add_generation_prompt: false, - bos_token: Some(""), - eos_token: Some(""), - ..Default::default() - }, - target: "<|prompt|>Hello, how are you?<|answer|>I'm doing great. How can I help you today?<|prompt|>I'd like to show off how chat templating works!", - }, - ChatTemplateTestItem { - name: "internlm/internlm2-chat-7b", - chat_template: "{% if messages[0]['role'] == 'user' or messages[0]['role'] == 'system' %}{{ bos_token }}{% endif %}{% for message in messages %}{{ '<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% elif messages[-1]['role'] == 'assistant' %}{{ eos_token }}{% endif %}", - input: ChatTemplateInputs { - messages: example_chat.clone(), - add_generation_prompt: false, - bos_token: Some(""), - eos_token: Some(""), - ..Default::default() - }, - target: "<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n", - }, - ChatTemplateTestItem { - name: "TheBloke/deepseek-coder-33B-instruct-AWQ", - chat_template: "{%- set found_item = false -%}\n{%- for message in messages -%}\n {%- if message['role'] == 'system' -%}\n {%- set found_item = true -%}\n {%- endif -%}\n{%- endfor -%}\n{%- if not found_item -%}\n{{'You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. 
For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer.\\n'}}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n{{ message['content'] }}\n {%- else %}\n {%- if message['role'] == 'user' %}\n{{'### Instruction:\\n' + message['content'] + '\\n'}}\n {%- else %}\n{{'### Response:\\n' + message['content'] + '\\n<|EOT|>\\n'}}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{{'### Response:\\n'}}\n", - input: ChatTemplateInputs { - messages: example_chat.clone(), - add_generation_prompt: false, - bos_token: Some("<|begin▁of▁sentence|>"), - eos_token: Some("<|EOT|>"), - ..Default::default() - }, - target: "You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer.\n### Instruction:\nHello, how are you?\n### Response:\nI'm doing great. How can I help you today?\n<|EOT|>\n### Instruction:\nI'd like to show off how chat templating works!\n### Response:\n", - }, - ChatTemplateTestItem { - name: "ericzzz/falcon-rw-1b-chat", - // `.strip()` has been replaced with `| trim` in the following template - chat_template: "{% for message in messages %}{% if loop.index > 1 and loop.previtem['role'] != 'assistant' %}{{ ' ' }}{% endif %}{% if message['role'] == 'system' %}{{ '[SYS] ' + message['content'] | trim }}{% elif message['role'] == 'user' %}{{ '[INST] ' + message['content'] | trim }}{% elif message['role'] == 'assistant' %}{{ '[RESP] ' + message['content'] + eos_token }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ ' [RESP] ' }}{% endif %}", - input: ChatTemplateInputs { - messages: example_chat.clone(), - add_generation_prompt: false, - bos_token: Some("<|endoftext|>"), - eos_token: Some("<|endoftext|>"), - ..Default::default() - }, - target: "[INST] Hello, how are you? [RESP] I'm doing great. How can I help you today?<|endoftext|>[INST] I'd like to show off how chat templating works!", - }, - ChatTemplateTestItem { - name: "abacusai/Smaug-34B-v0.1", - chat_template: "{%- for idx in range(0, messages|length) -%}\n{%- if messages[idx]['role'] == 'user' -%}\n{%- if idx > 1 -%}\n{{- bos_token + '[INST] ' + messages[idx]['content'] + ' [/INST]' -}}\n{%- else -%}\n{{- messages[idx]['content'] + ' [/INST]' -}}\n{%- endif -%}\n{% elif messages[idx]['role'] == 'system' %}\n{{- '[INST] <>\\n' + messages[idx]['content'] + '\\n<>\\n\\n' -}}\n{%- elif messages[idx]['role'] == 'assistant' -%}\n{{- ' ' + messages[idx]['content'] + ' ' + eos_token -}}\n{% endif %}\n{% endfor %}", - input: ChatTemplateInputs { - messages: example_chat.clone(), - add_generation_prompt: false, - bos_token: Some(""), - eos_token: Some(""), - ..Default::default() - }, - target: "Hello, how are you? [/INST] I'm doing great. How can I help you today? [INST] I'd like to show off how chat templating works! [/INST]", - }, - ChatTemplateTestItem { - name: "maywell/Synatra-Mixtral-8x7B", - chat_template: "Below is an instruction that describes a task. 
Write a response that appropriately completes the request.\n\n{% for message in messages %}{% if message['role'] == 'user' %}### Instruction:\n{{ message['content']|trim -}}{% if not loop.last %}{% endif %}\n{% elif message['role'] == 'assistant' %}### Response:\n{{ message['content']|trim -}}{% if not loop.last %}{% endif %}\n{% elif message['role'] == 'system' %}{{ message['content']|trim -}}{% if not loop.last %}{% endif %}\n{% endif %}\n{% endfor %}\n{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}\n### Response:\n{% endif %}", - input: ChatTemplateInputs { - messages: example_chat.clone(), - add_generation_prompt: false, - bos_token: Some(""), - eos_token: Some(""), - ..Default::default() - }, - target: "Below is an instruction that describes a task. Write a response that appropriately completes the request.### Instruction:Hello, how are you?### Response:I'm doing great. How can I help you today?### Instruction:I'd like to show off how chat templating works!", - }, - ChatTemplateTestItem { - name: "deepseek-ai/deepseek-coder-33b-instruct", - chat_template: "{% if not add_generation_prompt is defined %}\n{% set add_generation_prompt = false %}\n{% endif %}\n{%- set ns = namespace(found=false) -%}\n{%- for message in messages -%}\n {%- if message['role'] == 'system' -%}\n {%- set ns.found = true -%}\n {%- endif -%}\n{%- endfor -%}\n{{bos_token}}{%- if not ns.found -%}\n{{'You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\\n'}}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n{{ message['content'] }}\n {%- else %}\n {%- if message['role'] == 'user' %}\n{{'### Instruction:\\n' + message['content'] + '\\n'}}\n {%- else %}\n{{'### Response:\\n' + message['content'] + '\\n<|EOT|>\\n'}}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{% if add_generation_prompt %}\n{{'### Response:'}}\n{% endif %}", - input: ChatTemplateInputs { - messages: example_chat.clone(), - add_generation_prompt: false, - bos_token: Some("<|begin▁of▁sentence|>"), - eos_token: Some(""), - ..Default::default() - }, - target: "<|begin▁of▁sentence|>You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\n### Instruction:\nHello, how are you?\n### Response:\nI'm doing great. 
How can I help you today?\n<|EOT|>\n### Instruction:\nI'd like to show off how chat templating works!\n",
- },
- // NOT INCLUDED
- // - meetkai/functionary-medium-v2.2
- // - fireworks-ai/firefunction-v1
- // https://github
- ChatTemplateTestItem {
- name: "maywell/PiVoT-MoE",
- chat_template: "{{ (messages|selectattr('role', 'equalto', 'system')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'system')|list) else '' }}{% for message in messages %}{% if message['role'] == 'system' %}{{ message['content']|trim }}{% elif message['role'] == 'user' %}### Instruction: {{ message['content']|trim }}{% elif message['role'] == 'assistant' %}### Response: {{ message['content']|trim }}{% elif message['role'] == 'user_context' %}### Input: {{ message['content']|trim }}{% endif %}{% if not loop.last %}\n{% endif %}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}### Response:{% endif %}",
- input: ChatTemplateInputs {
- messages: example_chat_with_system.clone(),
- add_generation_prompt: false,
- bos_token: Some(""),
- eos_token: Some(""),
- ..Default::default()
- },
- target: "You are a friendly chatbot who always responds in the style of a pirateYou are a friendly chatbot who always responds in the style of a pirate### Instruction: Hello, how are you?### Response: I'm doing great. How can I help you today?### Instruction: I'd like to show off how chat templating works!",
- },
- ];
-
- #[allow(unused_variables)] // name is unused
- for ChatTemplateTestItem {
- name,
- chat_template,
- input,
- target,
- } in test_custom_templates
- {
- let mut env = Environment::new();
- env.add_function("raise_exception", raise_exception);
- // trim all the whitespace
- let chat_template = chat_template
- .lines()
- .map(|line| line.trim())
- .collect::<Vec<&str>>()
- .join("");
-
- let tmpl = env.template_from_str(&chat_template);
- let result = tmpl.unwrap().render(input).unwrap();
- assert_eq!(result, target);
- }
- }
-}