mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
added v2
This commit is contained in:
parent
6e105c8eb8
commit
a50e90e7e2
@ -13,7 +13,7 @@ use tokio::time::Instant;
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
use tracing::{info_span, instrument, Instrument, Span};
|
||||
|
||||
pub struct BackendV3 {
|
||||
pub struct BackendV2 {
|
||||
/// Request queue
|
||||
queue: Queue,
|
||||
/// Notify batcher on queue appends
|
||||
@ -22,7 +22,7 @@ pub struct BackendV3 {
|
||||
client: ShardedClient,
|
||||
}
|
||||
|
||||
impl BackendV3 {
|
||||
impl BackendV2 {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) fn new(
|
||||
client: ShardedClient,
|
||||
@ -35,24 +35,20 @@ impl BackendV3 {
|
||||
window_size: Option<u32>,
|
||||
speculate: u32,
|
||||
) -> Self {
|
||||
let prefix_caching =
|
||||
std::env::var("USE_PREFIX_CACHING").expect("Expect prefix caching env var");
|
||||
let prefix_caching = matches!(prefix_caching.as_str(), "true" | "1");
|
||||
let attention: String = std::env::var("ATTENTION").expect("attention env var");
|
||||
|
||||
let attention: Attention = attention
|
||||
// Infer shared state
|
||||
let attention = if let Ok(attention) = std::env::var("ATTENTION") {
|
||||
attention
|
||||
.parse()
|
||||
.unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"));
|
||||
let block_size = attention.block_size();
|
||||
|
||||
let queue = Queue::new(
|
||||
requires_padding,
|
||||
block_size,
|
||||
prefix_caching,
|
||||
window_size,
|
||||
speculate,
|
||||
max_batch_total_tokens,
|
||||
);
|
||||
.unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"))
|
||||
} else {
|
||||
Attention::Paged
|
||||
};
|
||||
let block_size = if attention == Attention::FlashDecoding {
|
||||
256
|
||||
} else {
|
||||
16
|
||||
};
|
||||
let queue = Queue::new(requires_padding, block_size, window_size, speculate);
|
||||
let batching_task_notifier = Arc::new(Notify::new());
|
||||
|
||||
// Spawn batching background task that contains all the inference logic
|
||||
@ -76,7 +72,7 @@ impl BackendV3 {
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Backend for BackendV3 {
|
||||
impl Backend for BackendV2 {
|
||||
#[instrument(skip_all)]
|
||||
fn schedule(
|
||||
&self,
|
||||
@ -93,7 +89,6 @@ impl Backend for BackendV3 {
|
||||
temp_span: None,
|
||||
queue_time: Instant::now(),
|
||||
batch_time: None,
|
||||
block_allocation: None,
|
||||
});
|
||||
|
||||
// Notify the background task that we have a new entry in the queue that needs
|
||||
@ -168,15 +163,12 @@ pub(crate) async fn batching_task(
|
||||
None
|
||||
} else {
|
||||
// Minimum batch size
|
||||
// TODO: temporarily disable to avoid incorrect deallocation +
|
||||
// reallocation when using prefix caching.
|
||||
Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
|
||||
};
|
||||
|
||||
let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
|
||||
let max_size =
|
||||
max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize));
|
||||
|
||||
// Try to get a new batch
|
||||
if let Some((mut new_entries, new_batch, span)) = queue
|
||||
.next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget)
|
||||
@ -259,13 +251,13 @@ async fn prefill(
|
||||
// Filter next batch and remove requests that were stopped
|
||||
let next_batch = filter_batch(client, next_batch, entries).await;
|
||||
|
||||
metrics::histogram!("tgi_batch_forward_duration", "method" => "prefill")
|
||||
metrics::histogram!("tgi_batch_forward_duration","method" => "prefill")
|
||||
.record(timings.forward.as_secs_f64());
|
||||
metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill")
|
||||
.record(timings.decode.as_secs_f64());
|
||||
metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill")
|
||||
.record(start_filtering_time.elapsed().as_secs_f64());
|
||||
metrics::histogram!("tgi_batch_inference_duration", "method" => "prefill")
|
||||
metrics::histogram!("tgi_batch_inference_duration","method" => "prefill")
|
||||
.record(start_time.elapsed().as_secs_f64());
|
||||
metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1);
|
||||
next_batch
|
||||
@ -497,8 +489,8 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
|
||||
|
||||
impl From<crate::client::GeneratedText> for GeneratedText {
|
||||
fn from(value: crate::client::GeneratedText) -> Self {
|
||||
let v3_finish_reason = crate::client::FinishReason::try_from(value.finish_reason).unwrap();
|
||||
let finish_reason = match v3_finish_reason {
|
||||
let v2_finish_reason = crate::client::FinishReason::try_from(value.finish_reason).unwrap();
|
||||
let finish_reason = match v2_finish_reason {
|
||||
crate::client::FinishReason::Length => FinishReason::Length,
|
||||
crate::client::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
|
||||
crate::client::FinishReason::StopSequence => FinishReason::StopSequence,
|
||||
|
@ -1,209 +0,0 @@
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
|
||||
use crate::radix::RadixAllocator;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct BlockAllocation {
|
||||
pub allocation_id: u64,
|
||||
pub blocks: Vec<u32>,
|
||||
pub slots: Vec<u32>,
|
||||
|
||||
/// Prefix that was cached and for which the KV does not have to
|
||||
/// be recomputed.
|
||||
pub prefix_len: u32,
|
||||
|
||||
pub(crate) block_allocator: Option<BlockAllocator>,
|
||||
}
|
||||
|
||||
impl Drop for BlockAllocation {
|
||||
fn drop(&mut self) {
|
||||
if let Some(block_allocator) = self.block_allocator.as_mut() {
|
||||
block_allocator.free(self.blocks.clone(), self.allocation_id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct BlockAllocator {
|
||||
/// Channel to communicate with the background task
|
||||
block_allocator: mpsc::UnboundedSender<BlockAllocatorCommand>,
|
||||
}
|
||||
|
||||
impl BlockAllocator {
|
||||
pub(crate) fn new(
|
||||
max_batch_total_tokens: u32,
|
||||
block_size: u32,
|
||||
prefix_caching: bool,
|
||||
window_size: Option<u32>,
|
||||
) -> Self {
|
||||
// Create channel
|
||||
let (sender, receiver) = mpsc::unbounded_channel();
|
||||
|
||||
// Launch background queue task
|
||||
tokio::spawn(block_allocator_task(
|
||||
max_batch_total_tokens / block_size,
|
||||
block_size,
|
||||
prefix_caching,
|
||||
window_size,
|
||||
receiver,
|
||||
));
|
||||
|
||||
Self {
|
||||
block_allocator: sender,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn allocate(
|
||||
&self,
|
||||
tokens: u32,
|
||||
prefill_tokens: Option<Arc<Vec<u32>>>,
|
||||
) -> Option<BlockAllocation> {
|
||||
let (response_sender, response_receiver) = oneshot::channel();
|
||||
self.block_allocator
|
||||
.send(BlockAllocatorCommand::Allocate {
|
||||
tokens,
|
||||
prefill_tokens,
|
||||
response_sender,
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
response_receiver.await.unwrap().map(|mut allocation| {
|
||||
allocation.block_allocator = Some(self.clone());
|
||||
allocation
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn free(&self, blocks: Vec<u32>, allocation_id: u64) {
|
||||
self.block_allocator
|
||||
.send(BlockAllocatorCommand::Free {
|
||||
allocation_id,
|
||||
blocks,
|
||||
})
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
async fn block_allocator_task(
|
||||
blocks: u32,
|
||||
block_size: u32,
|
||||
prefix_caching: bool,
|
||||
window_size: Option<u32>,
|
||||
mut receiver: mpsc::UnboundedReceiver<BlockAllocatorCommand>,
|
||||
) {
|
||||
let mut allocator: Box<dyn Allocator + Send> = if prefix_caching {
|
||||
Box::new(RadixAllocator::new(block_size, blocks, window_size))
|
||||
} else {
|
||||
Box::new(SimpleAllocator::new(blocks, block_size, window_size))
|
||||
};
|
||||
while let Some(cmd) = receiver.recv().await {
|
||||
match cmd {
|
||||
BlockAllocatorCommand::Free {
|
||||
blocks,
|
||||
allocation_id,
|
||||
} => allocator.free(blocks, allocation_id),
|
||||
BlockAllocatorCommand::Allocate {
|
||||
tokens,
|
||||
prefill_tokens,
|
||||
response_sender,
|
||||
} => {
|
||||
response_sender
|
||||
.send(allocator.allocate(tokens, prefill_tokens))
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum BlockAllocatorCommand {
|
||||
Free {
|
||||
blocks: Vec<u32>,
|
||||
allocation_id: u64,
|
||||
},
|
||||
Allocate {
|
||||
tokens: u32,
|
||||
prefill_tokens: Option<Arc<Vec<u32>>>,
|
||||
response_sender: oneshot::Sender<Option<BlockAllocation>>,
|
||||
},
|
||||
}
|
||||
|
||||
pub trait Allocator {
|
||||
fn allocate(
|
||||
&mut self,
|
||||
tokens: u32,
|
||||
prefill_tokens: Option<Arc<Vec<u32>>>,
|
||||
) -> Option<BlockAllocation>;
|
||||
|
||||
fn free(&mut self, blocks: Vec<u32>, allocation_id: u64);
|
||||
}
|
||||
pub struct SimpleAllocator {
|
||||
free_blocks: Vec<u32>,
|
||||
block_size: u32,
|
||||
window_size: Option<u32>,
|
||||
}
|
||||
|
||||
impl SimpleAllocator {
|
||||
fn new(blocks: u32, block_size: u32, window_size: Option<u32>) -> Self {
|
||||
SimpleAllocator {
|
||||
block_size,
|
||||
// Block 0 is reserved for health checks
|
||||
free_blocks: (1..blocks).collect(),
|
||||
window_size,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Allocator for SimpleAllocator {
|
||||
fn allocate(
|
||||
&mut self,
|
||||
tokens: u32,
|
||||
_prefill_tokens: Option<Arc<Vec<u32>>>,
|
||||
) -> Option<BlockAllocation> {
|
||||
// Apply window size
|
||||
let (required_blocks, repeats) = {
|
||||
let (tokens, repeats) = match self.window_size {
|
||||
None => (tokens, 1),
|
||||
Some(window_size) => {
|
||||
let repeats = (tokens + window_size - 1) / window_size;
|
||||
let tokens = core::cmp::min(tokens, window_size);
|
||||
(tokens, repeats as usize)
|
||||
}
|
||||
};
|
||||
// Pad to a multiple of block size
|
||||
let required_blocks = (tokens + self.block_size - 1) / self.block_size;
|
||||
(required_blocks, repeats)
|
||||
};
|
||||
|
||||
let tokens = tokens as usize;
|
||||
if required_blocks > self.free_blocks.len() as u32 {
|
||||
None
|
||||
} else {
|
||||
let blocks = self
|
||||
.free_blocks
|
||||
.split_off(self.free_blocks.len() - required_blocks as usize);
|
||||
let mut slots =
|
||||
Vec::with_capacity((required_blocks * self.block_size * repeats as u32) as usize);
|
||||
|
||||
'slots: for block_id in blocks.repeat(repeats).iter() {
|
||||
for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) {
|
||||
slots.push(s);
|
||||
if slots.len() == tokens {
|
||||
break 'slots;
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(BlockAllocation {
|
||||
allocation_id: 0,
|
||||
blocks,
|
||||
slots,
|
||||
prefix_len: 0,
|
||||
block_allocator: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn free(&mut self, blocks: Vec<u32>, _allocation_id: u64) {
|
||||
self.free_blocks.extend(blocks)
|
||||
}
|
||||
}
|
@ -1,11 +1,9 @@
|
||||
/// Single shard Client
|
||||
use crate::client::{pb, Chunk};
|
||||
use crate::client::pb;
|
||||
use crate::client::{ClientError, Result, WARMUP_IMAGE_BASE64};
|
||||
use base64::engine::general_purpose::STANDARD;
|
||||
use base64::Engine;
|
||||
use grpc_metadata::InjectTelemetryContext;
|
||||
use pb::generate::v3::text_generation_service_client::TextGenerationServiceClient;
|
||||
use pb::generate::v3::*;
|
||||
use pb::generate::v2::text_generation_service_client::TextGenerationServiceClient;
|
||||
use pb::generate::v2::*;
|
||||
use std::cmp::min;
|
||||
use std::time::Duration;
|
||||
use tonic::transport::{Channel, Uri};
|
||||
@ -47,7 +45,7 @@ impl Client {
|
||||
pub async fn service_discovery(&mut self) -> Result<Vec<String>> {
|
||||
let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context();
|
||||
let response = self.stub.service_discovery(request).await.map_err(|_| {
|
||||
ClientError::Connection("Server does not support v3 interface".to_string())
|
||||
ClientError::Connection("Server does not support v2 interface".to_string())
|
||||
})?;
|
||||
let urls = response
|
||||
.into_inner()
|
||||
@ -119,23 +117,6 @@ impl Client {
|
||||
while n_tokens < max_prefill_tokens {
|
||||
let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
|
||||
|
||||
let mut input_chunks = Vec::new();
|
||||
input_chunks
|
||||
.push(Chunk::Text("_test ".to_string().repeat(max_input_length as usize)).into());
|
||||
if n_tokens == 0 {
|
||||
input_chunks.push(
|
||||
Chunk::Image(Image {
|
||||
// Safe unwrap, because we control the data.
|
||||
data: STANDARD.decode(WARMUP_IMAGE_BASE64).unwrap(),
|
||||
mimetype: "image/jpeg;base64".to_string(),
|
||||
})
|
||||
.into(),
|
||||
);
|
||||
}
|
||||
|
||||
// Send stringly-typed inputs for compatibility for backends that haven't
|
||||
// been updated to support chunks.
|
||||
|
||||
let mut inputs = String::new();
|
||||
inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
|
||||
if n_tokens == 0 {
|
||||
@ -149,16 +130,8 @@ impl Client {
|
||||
requests.push(Request {
|
||||
id: 0,
|
||||
inputs,
|
||||
add_special_tokens: true,
|
||||
input_chunks: Some(Input {
|
||||
chunks: input_chunks,
|
||||
}),
|
||||
// We truncate the input on the server side to be sure that it has the correct size
|
||||
truncate,
|
||||
// Blocks and slots will be set on the server side if we use paged attention
|
||||
blocks: vec![],
|
||||
slots: vec![],
|
||||
prefix_len: 0,
|
||||
// Set sampling parameters to also take these ops into account in the max memory
|
||||
parameters: Some(NextTokenChooserParameters {
|
||||
temperature: 0.9,
|
||||
@ -180,7 +153,6 @@ impl Client {
|
||||
}),
|
||||
prefill_logprobs: true,
|
||||
top_n_tokens: 20,
|
||||
adapter_id: None,
|
||||
});
|
||||
n_tokens += max_input_length;
|
||||
|
||||
@ -194,8 +166,7 @@ impl Client {
|
||||
id: 0,
|
||||
size: requests.len() as u32,
|
||||
requests,
|
||||
max_tokens: max_input_length,
|
||||
max_blocks: 0,
|
||||
max_tokens: 0,
|
||||
};
|
||||
|
||||
let request = tonic::Request::new(WarmupRequest {
|
||||
|
@ -12,10 +12,9 @@ mod grpc_client;
|
||||
mod sharded_client;
|
||||
|
||||
pub use grpc_client::Client;
|
||||
pub use pb::generate::v3::{
|
||||
input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
|
||||
HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request,
|
||||
StoppingCriteriaParameters,
|
||||
pub use pb::generate::v2::{
|
||||
Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, HealthResponse,
|
||||
InfoResponse, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
||||
};
|
||||
pub use sharded_client::ShardedClient;
|
||||
|
||||
@ -64,13 +63,6 @@ impl From<transport::Error> for ClientError {
|
||||
}
|
||||
}
|
||||
|
||||
// Small convenience re-wrapping of `Chunk`.
|
||||
impl From<Chunk> for InputChunk {
|
||||
fn from(chunk: Chunk) -> Self {
|
||||
InputChunk { chunk: Some(chunk) }
|
||||
}
|
||||
}
|
||||
|
||||
static WARMUP_IMAGE_BASE64 :&str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=";
|
||||
|
||||
pub type Result<T> = std::result::Result<T, ClientError>;
|
||||
|
@ -1,13 +1,13 @@
|
||||
use crate::client::{ClientError, Result};
|
||||
/// Multi shard Client
|
||||
use crate::client::{ClientError, Result};
|
||||
use crate::client::{Health, ShardInfo};
|
||||
|
||||
use crate::client::grpc_client::{DecodeTimings, PrefillTimings};
|
||||
use crate::client::InfoResponse;
|
||||
use crate::client::{
|
||||
Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse,
|
||||
NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
||||
};
|
||||
use crate::client::{Chunk, InfoResponse, Input};
|
||||
use async_trait::async_trait;
|
||||
use futures::future::join_all;
|
||||
use tonic::transport::Uri;
|
||||
@ -218,11 +218,7 @@ impl Health for ShardedClient {
|
||||
let liveness_request = Request {
|
||||
id: u64::MAX,
|
||||
inputs: "liveness".to_string(),
|
||||
input_chunks: Some(Input {
|
||||
chunks: vec![Chunk::Text("liveness".into()).into()],
|
||||
}),
|
||||
truncate: 10,
|
||||
add_special_tokens: true,
|
||||
prefill_logprobs: false,
|
||||
parameters: Some(NextTokenChooserParameters {
|
||||
temperature: 1.0,
|
||||
@ -243,18 +239,12 @@ impl Health for ShardedClient {
|
||||
ignore_eos_token: false,
|
||||
}),
|
||||
top_n_tokens: 0,
|
||||
// Block 0 is reserved for health checks
|
||||
blocks: vec![0],
|
||||
slots: (0..16).collect(),
|
||||
prefix_len: 0,
|
||||
adapter_id: None,
|
||||
};
|
||||
let batch = Batch {
|
||||
id: u64::MAX,
|
||||
requests: vec![liveness_request],
|
||||
size: 1,
|
||||
max_tokens: 2,
|
||||
max_blocks: 1,
|
||||
};
|
||||
self.clone().prefill(batch).await?;
|
||||
Ok(())
|
||||
|
@ -1,11 +1,9 @@
|
||||
mod backend;
|
||||
pub mod block_allocator;
|
||||
mod client;
|
||||
mod queue;
|
||||
pub mod radix;
|
||||
|
||||
use crate::client::{ClientError, ShardedClient};
|
||||
pub(crate) use backend::BackendV3;
|
||||
pub(crate) use backend::BackendV2;
|
||||
use serde::Serialize;
|
||||
use thiserror::Error;
|
||||
use utoipa::ToSchema;
|
||||
@ -41,7 +39,7 @@ pub async fn connect_backend(
|
||||
max_batch_total_tokens: Option<u32>,
|
||||
max_waiting_tokens: usize,
|
||||
max_batch_size: Option<usize>,
|
||||
) -> Result<(BackendV3, BackendInfo), V3Error> {
|
||||
) -> Result<(BackendV2, BackendInfo), V2Error> {
|
||||
// Helper function
|
||||
let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option<u32>| {
|
||||
match max_supported_batch_total_tokens {
|
||||
@ -65,7 +63,7 @@ pub async fn connect_backend(
|
||||
);
|
||||
}
|
||||
if max_total_tokens as u32 > max_supported_batch_total_tokens {
|
||||
return Err(V3Error::NotEnoughMemory(max_total_tokens));
|
||||
return Err(V2Error::NotEnoughMemory(max_total_tokens));
|
||||
}
|
||||
|
||||
Ok(max_supported_batch_total_tokens)
|
||||
@ -75,16 +73,16 @@ pub async fn connect_backend(
|
||||
|
||||
let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
|
||||
.await
|
||||
.map_err(V3Error::Connection)?;
|
||||
.map_err(V2Error::Connection)?;
|
||||
|
||||
// server is running on v3
|
||||
// server is running on v2
|
||||
// Clear the cache; useful if the webserver rebooted
|
||||
sharded_client
|
||||
.clear_cache(None)
|
||||
.await
|
||||
.map_err(V3Error::Cache)?;
|
||||
.map_err(V2Error::Cache)?;
|
||||
// Get info from the shard
|
||||
let shard_info = sharded_client.info().await.map_err(V3Error::Info)?;
|
||||
let shard_info = sharded_client.info().await.map_err(V2Error::Info)?;
|
||||
|
||||
// Warmup model
|
||||
tracing::info!("Warming up model");
|
||||
@ -97,7 +95,7 @@ pub async fn connect_backend(
|
||||
max_batch_size,
|
||||
)
|
||||
.await
|
||||
.map_err(V3Error::Warmup)?,
|
||||
.map_err(V2Error::Warmup)?,
|
||||
)?;
|
||||
tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}");
|
||||
|
||||
@ -111,7 +109,7 @@ pub async fn connect_backend(
|
||||
speculate: shard_info.speculate as usize,
|
||||
};
|
||||
|
||||
let backend = BackendV3::new(
|
||||
let backend = BackendV2::new(
|
||||
sharded_client,
|
||||
waiting_served_ratio,
|
||||
max_batch_prefill_tokens,
|
||||
@ -129,7 +127,7 @@ pub async fn connect_backend(
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum V3Error {
|
||||
pub enum V2Error {
|
||||
#[error("Unable to clear the Python model shards cache: {0}")]
|
||||
Cache(ClientError),
|
||||
#[error("Unable to connect to the Python model shards: {0}")]
|
||||
|
@ -1,6 +1,6 @@
|
||||
use clap::{Parser, Subcommand};
|
||||
use text_generation_router::{server, usage_stats};
|
||||
use text_generation_router_v3::{connect_backend, V3Error};
|
||||
use text_generation_router_v2::{connect_backend, V2Error};
|
||||
use thiserror::Error;
|
||||
|
||||
/// App Configuration
|
||||
@ -204,7 +204,7 @@ enum RouterError {
|
||||
#[error("Argument validation error: {0}")]
|
||||
ArgumentValidation(String),
|
||||
#[error("Backend failed: {0}")]
|
||||
Backend(#[from] V3Error),
|
||||
Backend(#[from] V2Error),
|
||||
#[error("WebServer error: {0}")]
|
||||
WebServer(#[from] server::WebServerError),
|
||||
#[error("Tokio runtime failed to start: {0}")]
|
||||
|
@ -1,20 +1,17 @@
|
||||
use crate::block_allocator::{BlockAllocation, BlockAllocator};
|
||||
use crate::client;
|
||||
use crate::client::{
|
||||
Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
||||
};
|
||||
use nohash_hasher::{BuildNoHashHasher, IntMap};
|
||||
use std::cmp::{max, min};
|
||||
use std::cmp::min;
|
||||
use std::collections::VecDeque;
|
||||
use text_generation_router::infer::InferError;
|
||||
use text_generation_router::infer::InferStreamResponse;
|
||||
use text_generation_router::validation::{
|
||||
Chunk, ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters,
|
||||
ValidStoppingParameters,
|
||||
ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters,
|
||||
};
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
use tokio::time::Instant;
|
||||
use tracing::{info_span, instrument, Instrument, Span};
|
||||
use tracing::{info_span, instrument, Span};
|
||||
|
||||
/// Queue entry
|
||||
#[derive(Debug)]
|
||||
@ -31,8 +28,6 @@ pub(crate) struct Entry {
|
||||
pub queue_time: Instant,
|
||||
/// Instant when this entry was added to a batch
|
||||
pub batch_time: Option<Instant>,
|
||||
/// Block Allocation
|
||||
pub block_allocation: Option<BlockAllocation>,
|
||||
}
|
||||
|
||||
/// Request Queue
|
||||
@ -46,10 +41,8 @@ impl Queue {
|
||||
pub(crate) fn new(
|
||||
requires_padding: bool,
|
||||
block_size: u32,
|
||||
prefix_caching: bool,
|
||||
window_size: Option<u32>,
|
||||
speculate: u32,
|
||||
max_batch_total_tokens: u32,
|
||||
) -> Self {
|
||||
// Create channel
|
||||
let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
|
||||
@ -58,17 +51,14 @@ impl Queue {
|
||||
tokio::spawn(queue_task(
|
||||
requires_padding,
|
||||
block_size,
|
||||
prefix_caching,
|
||||
window_size,
|
||||
speculate,
|
||||
max_batch_total_tokens,
|
||||
queue_receiver,
|
||||
));
|
||||
|
||||
Self { queue_sender }
|
||||
}
|
||||
|
||||
/// Append an entry to the queue
|
||||
#[instrument(skip_all)]
|
||||
pub(crate) fn append(&self, entry: Entry) {
|
||||
// Send append command to the background task managing the state
|
||||
@ -111,20 +101,11 @@ impl Queue {
|
||||
async fn queue_task(
|
||||
requires_padding: bool,
|
||||
block_size: u32,
|
||||
prefix_caching: bool,
|
||||
window_size: Option<u32>,
|
||||
speculate: u32,
|
||||
max_batch_total_tokens: u32,
|
||||
mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
|
||||
) {
|
||||
let mut state = State::new(
|
||||
requires_padding,
|
||||
block_size,
|
||||
prefix_caching,
|
||||
window_size,
|
||||
speculate,
|
||||
max_batch_total_tokens,
|
||||
);
|
||||
let mut state = State::new(requires_padding, block_size, window_size, speculate);
|
||||
|
||||
while let Some(cmd) = receiver.recv().await {
|
||||
match cmd {
|
||||
@ -139,14 +120,12 @@ async fn queue_task(
|
||||
token_budget,
|
||||
response_sender,
|
||||
span,
|
||||
} => {
|
||||
let next_batch = state
|
||||
.next_batch(min_size, max_size, prefill_token_budget, token_budget)
|
||||
.instrument(span)
|
||||
.await;
|
||||
} => span.in_scope(|| {
|
||||
let next_batch =
|
||||
state.next_batch(min_size, max_size, prefill_token_budget, token_budget);
|
||||
response_sender.send(next_batch).unwrap();
|
||||
metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64);
|
||||
}
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -163,6 +142,9 @@ struct State {
|
||||
/// Id of the next batch
|
||||
next_batch_id: u64,
|
||||
|
||||
/// Whether the model is using padding
|
||||
requires_padding: bool,
|
||||
|
||||
/// Paged Attention block size
|
||||
block_size: u32,
|
||||
|
||||
@ -171,37 +153,23 @@ struct State {
|
||||
|
||||
/// Speculation amount
|
||||
speculate: u32,
|
||||
|
||||
/// Paged Attention Block Allocation
|
||||
block_allocator: Option<BlockAllocator>,
|
||||
}
|
||||
|
||||
impl State {
|
||||
fn new(
|
||||
requires_padding: bool,
|
||||
block_size: u32,
|
||||
prefix_caching: bool,
|
||||
window_size: Option<u32>,
|
||||
speculate: u32,
|
||||
max_batch_total_tokens: u32,
|
||||
) -> Self {
|
||||
let block_allocator = (!requires_padding).then(|| {
|
||||
BlockAllocator::new(
|
||||
max_batch_total_tokens,
|
||||
block_size,
|
||||
prefix_caching,
|
||||
window_size,
|
||||
)
|
||||
});
|
||||
|
||||
Self {
|
||||
entries: VecDeque::with_capacity(128),
|
||||
next_id: 0,
|
||||
next_batch_id: 0,
|
||||
requires_padding,
|
||||
block_size,
|
||||
window_size,
|
||||
speculate,
|
||||
block_allocator,
|
||||
}
|
||||
}
|
||||
|
||||
@ -217,7 +185,7 @@ impl State {
|
||||
}
|
||||
|
||||
// Get the next batch
|
||||
async fn next_batch(
|
||||
fn next_batch(
|
||||
&mut self,
|
||||
min_size: Option<usize>,
|
||||
max_size: Option<usize>,
|
||||
@ -252,14 +220,16 @@ impl State {
|
||||
let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
|
||||
next_batch_span.follows_from(Span::current());
|
||||
|
||||
let mut batch = Vec::with_capacity(self.entries.len());
|
||||
let mut batch_requests = Vec::with_capacity(self.entries.len());
|
||||
let mut batch_entries =
|
||||
IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());
|
||||
|
||||
let mut max_input_length = 0;
|
||||
let mut prefill_tokens: u32 = 0;
|
||||
let mut decode_tokens: u32 = 0;
|
||||
let mut max_blocks = 0;
|
||||
|
||||
// Pop entries starting from the front of the queue
|
||||
'entry_loop: while let Some((id, entry)) = self.entries.pop_front() {
|
||||
while let Some((id, mut entry)) = self.entries.pop_front() {
|
||||
// Filter entries where the response receiver was dropped (== entries where the request
|
||||
// was dropped by the client)
|
||||
if entry.response_tx.is_closed() {
|
||||
@ -268,27 +238,21 @@ impl State {
|
||||
continue;
|
||||
}
|
||||
|
||||
let block_allocation = match &self.block_allocator {
|
||||
None => {
|
||||
if self.requires_padding {
|
||||
// We pad to max input length in the Python shards
|
||||
// We need to take these padding tokens into the equation
|
||||
max_input_length = max_input_length.max(entry.request.input_length);
|
||||
prefill_tokens = (batch.len() + 1) as u32 * max_input_length;
|
||||
prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length
|
||||
} else {
|
||||
// pad to block size
|
||||
prefill_tokens += ((entry.request.input_length + self.block_size - 1)
|
||||
/ self.block_size)
|
||||
* self.block_size;
|
||||
}
|
||||
|
||||
if self.requires_padding {
|
||||
decode_tokens += entry.request.stopping_parameters.max_new_tokens;
|
||||
let total_tokens = prefill_tokens + decode_tokens + self.speculate;
|
||||
|
||||
if prefill_tokens > prefill_token_budget || total_tokens > token_budget {
|
||||
// Entry is over budget
|
||||
// Add it back to the front
|
||||
tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
|
||||
self.entries.push_front((id, entry));
|
||||
break 'entry_loop;
|
||||
}
|
||||
None
|
||||
}
|
||||
Some(_block_allocator) => {
|
||||
prefill_tokens += entry.request.input_length;
|
||||
} else {
|
||||
let max_new_tokens = match self.window_size {
|
||||
None => entry.request.stopping_parameters.max_new_tokens,
|
||||
Some(window_size) => min(
|
||||
@ -296,7 +260,11 @@ impl State {
|
||||
entry.request.stopping_parameters.max_new_tokens,
|
||||
),
|
||||
};
|
||||
decode_tokens += max_new_tokens;
|
||||
|
||||
// pad to block size
|
||||
decode_tokens +=
|
||||
((max_new_tokens + self.block_size - 1) / self.block_size) * self.block_size;
|
||||
}
|
||||
|
||||
if prefill_tokens > prefill_token_budget
|
||||
|| (prefill_tokens + decode_tokens + self.speculate) > token_budget
|
||||
@ -308,73 +276,6 @@ impl State {
|
||||
break;
|
||||
}
|
||||
|
||||
let tokens = entry.request.input_length
|
||||
+ entry.request.stopping_parameters.max_new_tokens
|
||||
+ self.speculate
|
||||
- 1;
|
||||
|
||||
// If users wants the prefill logprobs, we cannot reuse the cache.
|
||||
// So no input_ids for the radix tree.
|
||||
let input_ids = if entry.request.decoder_input_details {
|
||||
None
|
||||
} else {
|
||||
entry.request.input_ids.clone()
|
||||
};
|
||||
|
||||
Some((tokens, input_ids))
|
||||
}
|
||||
};
|
||||
batch.push((id, entry, block_allocation));
|
||||
if Some(batch.len()) == max_size {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Empty batch
|
||||
if batch.is_empty() {
|
||||
tracing::debug!("Filterered out all entries");
|
||||
return None;
|
||||
}
|
||||
|
||||
// XXX We haven't allocated yet, so we're allowed to ditch the results.
|
||||
// Check if our batch is big enough
|
||||
if let Some(min_size) = min_size {
|
||||
// Batch is too small
|
||||
if batch.len() < min_size {
|
||||
// Add back entries to the queue in the correct order
|
||||
for (id, entry, _) in batch.into_iter().rev() {
|
||||
self.entries.push_front((id, entry));
|
||||
}
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
let mut batch_requests = Vec::with_capacity(self.entries.len());
|
||||
let mut batch_entries =
|
||||
IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());
|
||||
|
||||
for (id, mut entry, block_allocation) in batch {
|
||||
let block_allocation = if let (Some((tokens, input_ids)), Some(block_allocator)) =
|
||||
(block_allocation, &self.block_allocator)
|
||||
{
|
||||
tracing::debug!("Allocating {tokens} with {input_ids:?}");
|
||||
match block_allocator.allocate(tokens, input_ids).await {
|
||||
None => {
|
||||
// Entry is over budget
|
||||
// Add it back to the front
|
||||
tracing::debug!("Over budget: not enough free blocks");
|
||||
self.entries.push_front((id, entry));
|
||||
continue;
|
||||
}
|
||||
Some(block_allocation) => {
|
||||
tracing::debug!("Allocation: {block_allocation:?}");
|
||||
max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
|
||||
Some(block_allocation)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
tracing::debug!("Accepting entry");
|
||||
// Create a new span to link the batch back to this entry
|
||||
let entry_batch_span = info_span!(parent: &entry.span, "infer");
|
||||
@ -384,40 +285,11 @@ impl State {
|
||||
// Update entry
|
||||
entry.temp_span = Some(entry_batch_span);
|
||||
|
||||
let (blocks, slots, prefix_len) = match &block_allocation {
|
||||
None => (Vec::new(), Vec::new(), 0),
|
||||
Some(block_allocation) => (
|
||||
block_allocation.blocks.clone(),
|
||||
block_allocation.slots.clone(),
|
||||
block_allocation.prefix_len,
|
||||
),
|
||||
};
|
||||
|
||||
entry.block_allocation = block_allocation;
|
||||
|
||||
batch_requests.push(Request {
|
||||
id,
|
||||
prefill_logprobs: entry.request.decoder_input_details,
|
||||
input_chunks: Some(client::Input {
|
||||
chunks: entry
|
||||
.request
|
||||
.inputs
|
||||
.clone()
|
||||
.into_iter()
|
||||
.map(|c| client::InputChunk {
|
||||
chunk: Some(match c {
|
||||
Chunk::Text(text) => client::Chunk::Text(text),
|
||||
Chunk::Image(image) => client::Chunk::Image(client::Image {
|
||||
data: image.data,
|
||||
mimetype: image.mimetype,
|
||||
}),
|
||||
}),
|
||||
})
|
||||
.collect(),
|
||||
}),
|
||||
inputs: entry.request.inputs.chunks_to_string(),
|
||||
truncate: entry.request.truncate,
|
||||
add_special_tokens: entry.request.add_special_tokens,
|
||||
parameters: Some(NextTokenChooserParameters::from(
|
||||
entry.request.parameters.clone(),
|
||||
)),
|
||||
@ -425,23 +297,39 @@ impl State {
|
||||
entry.request.stopping_parameters.clone(),
|
||||
)),
|
||||
top_n_tokens: entry.request.top_n_tokens,
|
||||
blocks,
|
||||
slots,
|
||||
prefix_len,
|
||||
adapter_id: entry.request.adapter_id.clone(),
|
||||
});
|
||||
// Set batch_time
|
||||
entry.batch_time = Some(Instant::now());
|
||||
// Insert in batch_entries IntMap
|
||||
batch_entries.insert(id, entry);
|
||||
|
||||
// Check if max_size
|
||||
if Some(batch_requests.len()) == max_size {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Empty batch
|
||||
if batch_requests.is_empty() {
|
||||
tracing::debug!("Filterered out all entries");
|
||||
tracing::debug!("Filtered out all entries");
|
||||
return None;
|
||||
}
|
||||
|
||||
// Check if our batch is big enough
|
||||
if let Some(min_size) = min_size {
|
||||
// Batch is too small
|
||||
if batch_requests.len() < min_size {
|
||||
// Add back entries to the queue in the correct order
|
||||
for r in batch_requests.into_iter().rev() {
|
||||
let id = r.id;
|
||||
let entry = batch_entries.remove(&id).unwrap();
|
||||
self.entries.push_front((id, entry));
|
||||
}
|
||||
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
// Final batch size
|
||||
let size = batch_requests.len() as u32;
|
||||
next_batch_span.record("batch_size", size);
|
||||
@ -451,7 +339,6 @@ impl State {
|
||||
requests: batch_requests,
|
||||
size,
|
||||
max_tokens: (prefill_tokens + decode_tokens),
|
||||
max_blocks,
|
||||
};
|
||||
// Increment batch id
|
||||
self.next_batch_id += 1;
|
||||
@ -516,9 +403,8 @@ impl From<ValidStoppingParameters> for StoppingCriteriaParameters {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::*;
|
||||
use std::sync::Arc;
|
||||
use tracing::info_span;
|
||||
|
||||
fn default_entry() -> (
|
||||
@ -560,14 +446,13 @@ mod tests {
|
||||
temp_span: None,
|
||||
queue_time: Instant::now(),
|
||||
batch_time: None,
|
||||
block_allocation: None,
|
||||
};
|
||||
(entry, receiver_tx)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_append() {
|
||||
let mut state = State::new(false, 1, false, None, 0, 16);
|
||||
#[test]
|
||||
fn test_append() {
|
||||
let mut state = State::new(false, 1, None, 0);
|
||||
let (entry, _guard) = default_entry();
|
||||
|
||||
assert_eq!(state.next_id, 0);
|
||||
@ -581,23 +466,23 @@ mod tests {
|
||||
assert_eq!(id, 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_next_batch_empty() {
|
||||
let mut state = State::new(false, 1, false, None, 0, 16);
|
||||
#[test]
|
||||
fn test_next_batch_empty() {
|
||||
let mut state = State::new(false, 1, None, 0);
|
||||
|
||||
assert!(state.next_batch(None, None, 1, 1).await.is_none());
|
||||
assert!(state.next_batch(Some(1), None, 1, 1).await.is_none());
|
||||
assert!(state.next_batch(None, None, 1, 1).is_none());
|
||||
assert!(state.next_batch(Some(1), None, 1, 1).is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_next_batch_min_size() {
|
||||
let mut state = State::new(false, 1, false, None, 0, 16);
|
||||
#[test]
|
||||
fn test_next_batch_min_size() {
|
||||
let mut state = State::new(false, 1, None, 0);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
state.append(entry1);
|
||||
state.append(entry2);
|
||||
|
||||
let (entries, batch, _) = state.next_batch(None, None, 2, 2).await.unwrap();
|
||||
let (entries, batch, _) = state.next_batch(None, None, 2, 2).unwrap();
|
||||
assert_eq!(entries.len(), 2);
|
||||
assert!(entries.contains_key(&0));
|
||||
assert!(entries.contains_key(&1));
|
||||
@ -613,7 +498,7 @@ mod tests {
|
||||
let (entry3, _guard3) = default_entry();
|
||||
state.append(entry3);
|
||||
|
||||
assert!(state.next_batch(Some(2), None, 2, 2).await.is_none());
|
||||
assert!(state.next_batch(Some(2), None, 2, 2).is_none());
|
||||
|
||||
assert_eq!(state.next_id, 3);
|
||||
assert_eq!(state.entries.len(), 1);
|
||||
@ -621,15 +506,15 @@ mod tests {
|
||||
assert_eq!(id, 2);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_next_batch_max_size() {
|
||||
let mut state = State::new(false, 1, false, None, 0, 16);
|
||||
#[test]
|
||||
fn test_next_batch_max_size() {
|
||||
let mut state = State::new(false, 1, None, 0);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
state.append(entry1);
|
||||
state.append(entry2);
|
||||
|
||||
let (entries, batch, _) = state.next_batch(None, Some(1), 2, 2).await.unwrap();
|
||||
let (entries, batch, _) = state.next_batch(None, Some(1), 2, 2).unwrap();
|
||||
assert_eq!(entries.len(), 1);
|
||||
assert!(entries.contains_key(&0));
|
||||
assert!(entries.get(&0).unwrap().batch_time.is_some());
|
||||
@ -641,15 +526,15 @@ mod tests {
|
||||
assert_eq!(state.next_batch_id, 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_next_batch_token_budget() {
|
||||
let mut state = State::new(false, 1, false, None, 0, 2);
|
||||
#[test]
|
||||
fn test_next_batch_token_budget() {
|
||||
let mut state = State::new(false, 1, None, 0);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
state.append(entry1);
|
||||
state.append(entry2);
|
||||
|
||||
let (entries, batch, _) = state.next_batch(None, None, 1, 1).await.unwrap();
|
||||
let (entries, batch, _) = state.next_batch(None, None, 1, 1).unwrap();
|
||||
assert_eq!(entries.len(), 1);
|
||||
assert!(entries.contains_key(&0));
|
||||
assert_eq!(batch.id, 0);
|
||||
@ -662,7 +547,7 @@ mod tests {
|
||||
let (entry3, _guard3) = default_entry();
|
||||
state.append(entry3);
|
||||
|
||||
let (entries, batch, _) = state.next_batch(None, None, 3, 3).await.unwrap();
|
||||
let (entries, batch, _) = state.next_batch(None, None, 3, 3).unwrap();
|
||||
assert_eq!(entries.len(), 2);
|
||||
assert!(entries.contains_key(&1));
|
||||
assert!(entries.contains_key(&2));
|
||||
@ -676,14 +561,14 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_queue_append() {
|
||||
let queue = Queue::new(false, 1, false, None, 0, 16);
|
||||
let queue = Queue::new(false, 1, None, 0);
|
||||
let (entry, _guard) = default_entry();
|
||||
queue.append(entry);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_empty() {
|
||||
let queue = Queue::new(false, 1, false, None, 0, 16);
|
||||
let queue = Queue::new(false, 1, None, 0);
|
||||
|
||||
assert!(queue.next_batch(None, None, 1, 1).await.is_none());
|
||||
assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
|
||||
@ -691,7 +576,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_min_size() {
|
||||
let queue = Queue::new(false, 1, false, None, 0, 16);
|
||||
let queue = Queue::new(false, 1, None, 0);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
queue.append(entry1);
|
||||
@ -724,7 +609,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_max_size() {
|
||||
let queue = Queue::new(false, 1, false, None, 0, 16);
|
||||
let queue = Queue::new(false, 1, None, 0);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
queue.append(entry1);
|
||||
@ -740,7 +625,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_token_budget() {
|
||||
let queue = Queue::new(false, 1, false, None, 0, 16);
|
||||
let queue = Queue::new(false, 1, None, 0);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
queue.append(entry1);
|
||||
@ -765,7 +650,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_token_speculate() {
|
||||
let queue = Queue::new(false, 1, false, None, 2, 16);
|
||||
let queue = Queue::new(false, 1, None, 2);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
queue.append(entry1);
|
||||
@ -784,7 +669,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_dropped_receiver() {
|
||||
let queue = Queue::new(false, 1, false, None, 0, 16);
|
||||
let queue = Queue::new(false, 1, None, 0);
|
||||
let (entry, _) = default_entry();
|
||||
queue.append(entry);
|
||||
|
||||
|
@ -1,876 +0,0 @@
|
||||
use crate::block_allocator::{Allocator, BlockAllocation};
|
||||
use slotmap::{DefaultKey, SlotMap};
|
||||
use std::hash::{Hash, Hasher};
|
||||
use std::{
|
||||
collections::{BTreeSet, HashMap},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
fn hash(slice: &[u32]) -> u64 {
|
||||
assert!(!slice.is_empty());
|
||||
if slice.len() == 1 {
|
||||
slice[0] as u64
|
||||
} else {
|
||||
let mut s = std::hash::DefaultHasher::new();
|
||||
slice.hash(&mut s);
|
||||
s.finish()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RadixAllocator {
|
||||
allocation_id: u64,
|
||||
|
||||
allocations: HashMap<u64, RadixAllocation>,
|
||||
|
||||
cache_blocks: RadixTrie,
|
||||
|
||||
/// Blocks that are immediately available for allocation.
|
||||
free_blocks: Vec<u32>,
|
||||
|
||||
#[allow(dead_code)]
|
||||
// This isn't used because the prefix need to match without the windowing
|
||||
// mecanism. This at worst is overallocating, not necessarily being wrong.
|
||||
window_size: Option<u32>,
|
||||
|
||||
block_size: u32,
|
||||
}
|
||||
|
||||
impl RadixAllocator {
|
||||
pub fn new(block_size: u32, n_blocks: u32, window_size: Option<u32>) -> Self {
|
||||
RadixAllocator {
|
||||
allocation_id: 0,
|
||||
allocations: HashMap::new(),
|
||||
cache_blocks: RadixTrie::new(block_size as usize),
|
||||
|
||||
// Block 0 is reserved for health checks.
|
||||
free_blocks: (1..n_blocks).collect(),
|
||||
window_size,
|
||||
block_size,
|
||||
}
|
||||
}
|
||||
|
||||
fn alloc_or_reclaim(&mut self, n_blocks_needed: usize) -> Option<Vec<u32>> {
|
||||
if self.free_blocks.len() < n_blocks_needed {
|
||||
// This is a bit annoying, we first extend the free list and then
|
||||
// split it off again below. This is because we need to put it on
|
||||
// the free list if we cannot allocate enough blocks. This is only
|
||||
// temporary, the trie needs to be able to report whether it can
|
||||
// allocate the requested amount. Just not implemented yet.
|
||||
tracing::debug!(
|
||||
"Free blocks {} need {n_blocks_needed}",
|
||||
self.free_blocks.len()
|
||||
);
|
||||
self.free_blocks.extend(
|
||||
self.cache_blocks
|
||||
.evict(n_blocks_needed - self.free_blocks.len()),
|
||||
);
|
||||
}
|
||||
|
||||
if self.free_blocks.len() >= n_blocks_needed {
|
||||
Some(
|
||||
self.free_blocks
|
||||
.split_off(self.free_blocks.len() - n_blocks_needed),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Allocator trait
|
||||
impl Allocator for RadixAllocator {
|
||||
fn allocate(
|
||||
&mut self,
|
||||
tokens: u32,
|
||||
prefill_tokens: Option<Arc<Vec<u32>>>,
|
||||
) -> Option<BlockAllocation> {
|
||||
let mut blocks = vec![];
|
||||
let prefix_node = if let Some(prefill_tokens) = prefill_tokens.as_ref() {
|
||||
let node_id = self
|
||||
.cache_blocks
|
||||
.find(prefill_tokens.as_slice(), &mut blocks);
|
||||
node_id
|
||||
} else {
|
||||
self.cache_blocks.root_id()
|
||||
};
|
||||
|
||||
// Even if this allocation fails below, we need to increase he
|
||||
// refcount to ensure that the prefix that was found is not evicted.
|
||||
self.cache_blocks
|
||||
.incref(prefix_node)
|
||||
.expect("Failed to increment refcount");
|
||||
|
||||
let prefix_len = blocks.len() * self.block_size as usize;
|
||||
let suffix_len = tokens - prefix_len as u32;
|
||||
|
||||
let suffix_blocks = (suffix_len + self.block_size - 1) / self.block_size;
|
||||
|
||||
tracing::info!("Prefix {prefix_len} - Suffix {suffix_len}");
|
||||
|
||||
match self.alloc_or_reclaim(suffix_blocks as usize) {
|
||||
Some(suffix_blocks) => blocks.extend(suffix_blocks),
|
||||
None => {
|
||||
tracing::debug!("Cannot allocate {:?}", self.cache_blocks);
|
||||
tracing::debug!("Found {prefix_len} prefix tokens need {suffix_blocks} suffix blocks for {tokens} tokens");
|
||||
tracing::debug!("Block size {}", self.block_size);
|
||||
self.cache_blocks
|
||||
.decref(prefix_node)
|
||||
.expect("Failed to decrement refcount");
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
// 1:1 mapping of blocks and slots.
|
||||
let slots = if self.block_size == 1 {
|
||||
blocks.clone()
|
||||
} else {
|
||||
let mut slots = Vec::with_capacity(blocks.len() * self.block_size as usize);
|
||||
'slots: for block_id in &blocks {
|
||||
for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) {
|
||||
slots.push(s);
|
||||
if slots.len() as u32 == tokens {
|
||||
break 'slots;
|
||||
}
|
||||
}
|
||||
}
|
||||
slots
|
||||
};
|
||||
|
||||
let allocation = RadixAllocation {
|
||||
prefix_node,
|
||||
cached_prefix_len: prefix_len,
|
||||
prefill_tokens: prefill_tokens.clone(),
|
||||
};
|
||||
|
||||
self.allocation_id += 1;
|
||||
self.allocations.insert(self.allocation_id, allocation);
|
||||
|
||||
Some(BlockAllocation {
|
||||
allocation_id: self.allocation_id,
|
||||
block_allocator: None,
|
||||
blocks,
|
||||
slots,
|
||||
prefix_len: prefix_len as u32,
|
||||
})
|
||||
}
|
||||
|
||||
fn free(&mut self, blocks: Vec<u32>, allocation_id: u64) {
|
||||
let allocation = match self.allocations.remove(&allocation_id) {
|
||||
Some(allocation) => allocation,
|
||||
None => unreachable!("Tried to free an unknown allocation."),
|
||||
};
|
||||
|
||||
self.cache_blocks
|
||||
.decref(allocation.prefix_node)
|
||||
.expect("Failed to decrement refcount");
|
||||
|
||||
if let Some(prefill_tokens) = allocation.prefill_tokens {
|
||||
let prefill_tokens = prefill_tokens.as_slice();
|
||||
|
||||
// If there are prefill tokens that did not come from the cache,
|
||||
// add them to the cache.
|
||||
if prefill_tokens.len() > allocation.cached_prefix_len {
|
||||
let aligned =
|
||||
(prefill_tokens.len() / self.block_size as usize) * self.block_size as usize;
|
||||
if aligned > 0 {
|
||||
let prefix_len = self
|
||||
.cache_blocks
|
||||
.insert(
|
||||
&prefill_tokens[..aligned],
|
||||
&blocks[..aligned / self.block_size as usize],
|
||||
)
|
||||
// Unwrap, failing is a programming error.
|
||||
.expect("Failed to store prefill tokens");
|
||||
// We can have a prefill with the following structure:
|
||||
//
|
||||
// |---| From the prefix cache.
|
||||
// A B C D E F G
|
||||
//|--------| Found in the trie during insertion.
|
||||
//
|
||||
// This means that while processing this request there was a
|
||||
// partially overlapping request that had A..=E in its
|
||||
// prefill. In this case we need to free the blocks D E.
|
||||
if prefix_len > allocation.cached_prefix_len {
|
||||
self.free_blocks.extend(
|
||||
&blocks[allocation.cached_prefix_len / self.block_size as usize
|
||||
..prefix_len / self.block_size as usize],
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Free non-prefill blocks.
|
||||
self.free_blocks
|
||||
.extend(&blocks[prefill_tokens.len() / self.block_size as usize..]);
|
||||
} else {
|
||||
self.free_blocks.extend(blocks);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct RadixAllocation {
|
||||
prefix_node: NodeId,
|
||||
cached_prefix_len: usize,
|
||||
prefill_tokens: Option<Arc<Vec<u32>>>,
|
||||
}
|
||||
|
||||
// Radix trie that is heavily inspired by radix attention from sglang.
|
||||
//
|
||||
// The trie is optimized for prefix caching:
|
||||
//
|
||||
// - A normal radix trie stores discrete values. In this radix trie,
|
||||
// inserting *abc* with value *xyz* will also enable lookup for
|
||||
// *a* (*x*) and *ab* (*xy*).
|
||||
// - As a result, every value is required to have the same length as
|
||||
// the key.
|
||||
// - We store additional information in each node, such as last access
|
||||
// time and a reference count.
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum TrieError {
|
||||
InvalidNodeId,
|
||||
RefCountUnderflow,
|
||||
}
|
||||
|
||||
pub type NodeId = DefaultKey;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct RadixTrie {
|
||||
/// Identifier of the root nod.
|
||||
root: DefaultKey,
|
||||
|
||||
/// Leave node identifiers ordered by increasing recency.
|
||||
leaves: BTreeSet<(u64, NodeId)>,
|
||||
|
||||
/// All trie nodes.
|
||||
nodes: SlotMap<NodeId, TrieNode>,
|
||||
|
||||
/// Time as a monotonically increating counter to avoid the system
|
||||
/// call that a real time lookup would require.
|
||||
time: u64,
|
||||
|
||||
/// All blocks need to be aligned with this
|
||||
block_size: usize,
|
||||
}
|
||||
|
||||
impl RadixTrie {
|
||||
/// Construct a new radix trie.
|
||||
pub fn new(block_size: usize) -> Self {
|
||||
let root = TrieNode::new(vec![], vec![], 0, None);
|
||||
let mut nodes = SlotMap::new();
|
||||
let root = nodes.insert(root);
|
||||
RadixTrie {
|
||||
leaves: BTreeSet::new(),
|
||||
nodes,
|
||||
root,
|
||||
time: 0,
|
||||
block_size,
|
||||
}
|
||||
}
|
||||
|
||||
/// Find the prefix of the given tokens.
|
||||
///
|
||||
/// The blocks corresponding to the part of the prefix that could be found
|
||||
/// are written to `blocks`. The number of blocks is in `0..=tokens.len()`.
|
||||
/// Returns the identifier of the trie node that contains the longest
|
||||
/// prefix. The node identifier can be used by callers to e.g. increase its
|
||||
/// reference count.
|
||||
///
|
||||
/// Using this method will update the access time of the traversed nodes.
|
||||
pub fn find(&mut self, key: &[u32], blocks: &mut Vec<u32>) -> NodeId {
|
||||
self.time += 1;
|
||||
self.find_(self.root, key, blocks)
|
||||
}
|
||||
|
||||
/// Find worker.
|
||||
fn find_(&mut self, mut node_id: NodeId, key: &[u32], blocks: &mut Vec<u32>) -> NodeId {
|
||||
let node = &self.nodes[node_id];
|
||||
|
||||
if key.len() >= self.block_size {
|
||||
let node_key = hash(&key[..self.block_size]);
|
||||
if let Some(&child_id) = node.children.get(&node_key) {
|
||||
self.update_access_time(child_id);
|
||||
let child = self.nodes.get(child_id).expect("Invalid child identifier");
|
||||
let shared_prefix_len = shared_prefix(&child.key, key, self.block_size);
|
||||
assert_eq!(shared_prefix_len % self.block_size, 0);
|
||||
blocks.extend(&child.blocks[..shared_prefix_len / self.block_size]);
|
||||
|
||||
let key = &key[shared_prefix_len..];
|
||||
if !key.is_empty() {
|
||||
node_id = self.find_(child_id, key, blocks);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
node_id
|
||||
}
|
||||
|
||||
/// Decrease the reference count of a node.
|
||||
pub fn decref(&mut self, node_id: NodeId) -> Result<(), TrieError> {
|
||||
// We don't care about refcounting for root, since it will never
|
||||
// be evicted.
|
||||
if node_id == self.root {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let node = self
|
||||
.nodes
|
||||
.get_mut(node_id)
|
||||
.ok_or(TrieError::InvalidNodeId)?;
|
||||
if node.ref_count == 0 {
|
||||
return Err(TrieError::RefCountUnderflow);
|
||||
}
|
||||
|
||||
node.ref_count -= 1;
|
||||
if node.ref_count == 0 {
|
||||
assert!(
|
||||
node.children.is_empty(),
|
||||
"Nodes with children must have refcount > 0"
|
||||
);
|
||||
|
||||
self.leaves.insert((node.last_accessed, node_id));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Increase the reference count of a node.
|
||||
pub fn incref(&mut self, node_id: NodeId) -> Result<(), TrieError> {
|
||||
if node_id == self.root {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let node = self
|
||||
.nodes
|
||||
.get_mut(node_id)
|
||||
.ok_or(TrieError::InvalidNodeId)?;
|
||||
if node.ref_count == 0 {
|
||||
self.leaves.remove(&(node.last_accessed, node_id));
|
||||
}
|
||||
node.ref_count += 1;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Evict `n_blocks` from the trie.
|
||||
///
|
||||
/// Returns the evicted blocks. When the length is less than `n_blocks`,
|
||||
/// not enough blocks could be evicted.
|
||||
pub fn evict(&mut self, n_blocks: usize) -> Vec<u32> {
|
||||
// NOTE: we don't return Result here. If any of the unwrapping fails,
|
||||
// it's a programming error in the trie implementation, not a user
|
||||
// error caused by e.g. an invalid argument.
|
||||
|
||||
// TODO: add some bookkeeping in the future to check whether we can
|
||||
// evict n_blocks and return `None` if we can't. We are now needlessly
|
||||
// evicting prefixes from the cache in such a case.
|
||||
let mut evicted = Vec::new();
|
||||
tracing::debug!("Evicting in search of {n_blocks}");
|
||||
|
||||
while let Some((last_access, node_id)) = self.leaves.pop_first() {
|
||||
let blocks_needed = n_blocks.saturating_sub(evicted.len());
|
||||
tracing::debug!("Evicting node {node_id:?} ");
|
||||
|
||||
let node = self.nodes.get(node_id).expect("Leave does not exist");
|
||||
assert_eq!(
|
||||
node.ref_count, 0,
|
||||
"Leaf must have refcount of 0, got {}",
|
||||
node.ref_count
|
||||
);
|
||||
|
||||
if blocks_needed >= node.blocks.len() {
|
||||
// We need to evict the whole node if we need more blocks than it has.
|
||||
let node = self.remove_node(node_id);
|
||||
evicted.extend(node.blocks);
|
||||
|
||||
if evicted.len() >= n_blocks {
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
// The node has more blocks than needed, so we'll just remove
|
||||
// the required number of blocks and leave the remaining blocks
|
||||
// untouched.
|
||||
let node = self.nodes.get_mut(node_id).expect("Leave does not exist");
|
||||
|
||||
let truncate_blocks = node.blocks.len() - blocks_needed;
|
||||
let truncate_tokens = truncate_blocks * self.block_size;
|
||||
node.key.truncate(truncate_tokens);
|
||||
evicted.extend(node.blocks.split_off(truncate_blocks));
|
||||
self.leaves.insert((last_access, node_id));
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
evicted
|
||||
}
|
||||
|
||||
/// Insert a prefill along with its blocks.
|
||||
///
|
||||
/// This method returns the length of the prefix that was already
|
||||
/// in the trie. E.g. if the length is 10, this means that for
|
||||
/// the first 10 elements of the tree **the blocks are not updated**.
|
||||
pub fn insert(&mut self, tokens: &[u32], blocks: &[u32]) -> Result<usize, TrieError> {
|
||||
self.time += 1;
|
||||
let common = self.insert_(self.root, tokens, blocks)?;
|
||||
Ok(common)
|
||||
}
|
||||
|
||||
/// Insertion worker.
|
||||
fn insert_(
|
||||
&mut self,
|
||||
node_id: NodeId,
|
||||
tokens: &[u32],
|
||||
blocks: &[u32],
|
||||
) -> Result<usize, TrieError> {
|
||||
// TODO: in the future we may want to check that the blocks match for
|
||||
// the part of the prefix that is already in the trie to detect
|
||||
// mismatches.
|
||||
|
||||
assert_eq!(tokens.len(), blocks.len() * self.block_size);
|
||||
|
||||
let node_key = hash(&tokens[..self.block_size]);
|
||||
if let Some(&child_id) = self.nodes[node_id].children.get(&node_key) {
|
||||
self.update_access_time(child_id);
|
||||
let child = self
|
||||
.nodes
|
||||
.get_mut(child_id)
|
||||
// Unwrap here, since failure is a bug.
|
||||
.expect("Child node does not exist");
|
||||
let shared_prefix_len = shared_prefix(&child.key, tokens, self.block_size);
|
||||
|
||||
// We are done, the prefix is already in the trie.
|
||||
if shared_prefix_len == tokens.len() || shared_prefix_len == 0 {
|
||||
return Ok(shared_prefix_len);
|
||||
}
|
||||
|
||||
// The node's prefix is a prefix of the insertion prefix.
|
||||
if shared_prefix_len == child.key.len() {
|
||||
return Ok(shared_prefix_len
|
||||
+ self.insert_(
|
||||
child_id,
|
||||
&tokens[shared_prefix_len..],
|
||||
&blocks[shared_prefix_len / self.block_size..],
|
||||
)?);
|
||||
}
|
||||
|
||||
// The node's prefix and the insertion prefix only match partially,
|
||||
// split the node to just contain the matching part. Then insert the
|
||||
// remainder of the prefix into the node again
|
||||
let child_id = self.split_node(child_id, shared_prefix_len);
|
||||
let key = &tokens[shared_prefix_len..];
|
||||
let blocks = &blocks[shared_prefix_len / self.block_size..];
|
||||
Ok(shared_prefix_len + self.insert_(child_id, key, blocks)?)
|
||||
} else {
|
||||
self.add_node(node_id, tokens, blocks);
|
||||
Ok(0)
|
||||
}
|
||||
}
|
||||
|
||||
fn split_node(&mut self, node_id: NodeId, prefix_len: usize) -> NodeId {
|
||||
// We have to make the current node a child to ensure that its
|
||||
// properties and node id stay the same.
|
||||
|
||||
// This funcion unwraps, an invalid node_id is a programming error.
|
||||
|
||||
let node = self
|
||||
.nodes
|
||||
.get_mut(node_id)
|
||||
.expect("Node to-be split does not exist");
|
||||
let mut parent_key = node.key.split_off(prefix_len);
|
||||
let prefix_blocks = prefix_len / self.block_size;
|
||||
let mut parent_blocks = node.blocks.split_off(prefix_blocks);
|
||||
|
||||
// Move first part of the prefix to the parent. We swap to avoid
|
||||
// an allocation + copy for both splits of the key/blocks.
|
||||
std::mem::swap(&mut node.key, &mut parent_key);
|
||||
std::mem::swap(&mut node.blocks, &mut parent_blocks);
|
||||
|
||||
let node_key = hash(&node.key[..self.block_size]);
|
||||
|
||||
let grandparent_id = node.parent.expect("Node does not have a parent");
|
||||
let parent_id = self.add_node(grandparent_id, parent_key, parent_blocks);
|
||||
self.add_node_to_parent(parent_id, node_key, node_id);
|
||||
|
||||
// Reborrow to make the borrow checker happy.
|
||||
let node = self
|
||||
.nodes
|
||||
.get_mut(node_id)
|
||||
.expect("Node to-be split does not exist");
|
||||
node.parent = Some(parent_id);
|
||||
|
||||
parent_id
|
||||
}
|
||||
|
||||
/// Create a node and add it to the parent.
|
||||
fn add_node(
|
||||
&mut self,
|
||||
parent_id: NodeId,
|
||||
key: impl Into<Vec<u32>>,
|
||||
blocks: impl Into<Vec<u32>>,
|
||||
) -> NodeId {
|
||||
let key = key.into();
|
||||
let blocks = blocks.into();
|
||||
let first = hash(&key[..self.block_size]);
|
||||
|
||||
let child = TrieNode::new(key, blocks, self.time, Some(parent_id));
|
||||
let child_id = self.nodes.insert(child);
|
||||
|
||||
self.add_node_to_parent(parent_id, first, child_id);
|
||||
self.leaves.insert((self.time, child_id));
|
||||
|
||||
child_id
|
||||
}
|
||||
|
||||
/// Add a node to the parent.
|
||||
fn add_node_to_parent(&mut self, parent_id: NodeId, hash: u64, child_id: NodeId) {
|
||||
// Unwrap here, passing in an unknown id is a programming error.
|
||||
let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node");
|
||||
if parent.children.insert(hash, child_id).is_none() {
|
||||
// Only increase reference count if child does not replace another child.
|
||||
self.incref(parent_id)
|
||||
.expect("Failed to increase parent refcount");
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove a node from the trie.
|
||||
fn remove_node(&mut self, node_id: NodeId) -> TrieNode {
|
||||
// Unwrap here, passing in an unknown id is a programming error.
|
||||
let node = self.nodes.remove(node_id).expect("Unknown node");
|
||||
assert!(
|
||||
node.children.is_empty(),
|
||||
"Tried to remove a node with {} children",
|
||||
node.children.len()
|
||||
);
|
||||
let parent_id = node.parent.expect("Attempted to remove root node");
|
||||
let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node");
|
||||
|
||||
let node_key = hash(&node.key[..self.block_size]);
|
||||
parent.children.remove(&node_key);
|
||||
self.decref(parent_id)
|
||||
.expect("Failed to decrease parent refcount");
|
||||
node
|
||||
}
|
||||
|
||||
fn update_access_time(&mut self, node_id: NodeId) {
|
||||
// Unwrap here, passing in an unknown id is a programming error.
|
||||
let node = self.nodes.get_mut(node_id).expect("Unknown node");
|
||||
|
||||
// Update the ordered leaves set if the node is a leave.
|
||||
if self.leaves.remove(&(node.last_accessed, node_id)) {
|
||||
self.leaves.insert((self.time, node_id));
|
||||
}
|
||||
|
||||
node.last_accessed = self.time;
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
#[doc(hidden)]
|
||||
/// Print debugging output for the trie.
|
||||
///
|
||||
/// In contrast to `Debug` nicely formatted.
|
||||
pub fn print_debug(&self) {
|
||||
self.print_debug_(self.root, 0);
|
||||
}
|
||||
|
||||
fn print_debug_(&self, node_id: NodeId, indent: usize) {
|
||||
let node = &self.nodes[node_id];
|
||||
eprintln!(
|
||||
"{}{:?}, key: {:?}, blocks: {:?}, ref_count: {}, last_accessed: {}, parent: {:?}, children: {:?}",
|
||||
" ".repeat(indent),
|
||||
node_id,
|
||||
node.key,
|
||||
node.blocks,
|
||||
node.ref_count,
|
||||
node.last_accessed,
|
||||
node.parent,
|
||||
node.children
|
||||
);
|
||||
for child_id in self.nodes[node_id].children.values() {
|
||||
self.print_debug_(*child_id, indent + 2);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn root_id(&self) -> DefaultKey {
|
||||
self.root
|
||||
}
|
||||
}
|
||||
|
||||
/// Trie node.
|
||||
#[derive(Debug)]
|
||||
struct TrieNode {
|
||||
blocks: Vec<u32>,
|
||||
children: HashMap<u64, NodeId>,
|
||||
key: Vec<u32>,
|
||||
last_accessed: u64,
|
||||
parent: Option<NodeId>,
|
||||
ref_count: usize,
|
||||
}
|
||||
|
||||
impl TrieNode {
|
||||
fn new(key: Vec<u32>, blocks: Vec<u32>, last_accessed: u64, parent: Option<NodeId>) -> Self {
|
||||
TrieNode {
|
||||
children: HashMap::new(),
|
||||
key,
|
||||
blocks,
|
||||
last_accessed,
|
||||
parent,
|
||||
ref_count: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn shared_prefix(left: &[u32], right: &[u32], block_size: usize) -> usize {
|
||||
let full = left.iter().zip(right).take_while(|(a, b)| a == b).count();
|
||||
// NOTE: this is the case because the child node was chosen based on
|
||||
// matching the first character of the key/prefix.
|
||||
assert!(full > 0, "Prefixes must at least share 1 token");
|
||||
(full / block_size) * block_size
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn allocator_block_size() {
|
||||
let mut cache = RadixAllocator::new(2, 12, None);
|
||||
let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
|
||||
assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
|
||||
assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22, 23]);
|
||||
assert_eq!(allocation.prefix_len, 0);
|
||||
cache.free(allocation.blocks.clone(), allocation.allocation_id);
|
||||
|
||||
let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
|
||||
assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
|
||||
assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22, 23]);
|
||||
assert_eq!(allocation.prefix_len, 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allocator_block_size_non_aligned() {
|
||||
let mut cache = RadixAllocator::new(2, 12, None);
|
||||
let allocation = cache.allocate(7, Some(Arc::new(vec![0, 1, 2]))).unwrap();
|
||||
assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
|
||||
assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22]);
|
||||
assert_eq!(allocation.prefix_len, 0);
|
||||
cache.free(allocation.blocks.clone(), allocation.allocation_id);
|
||||
|
||||
let allocation = cache.allocate(7, Some(Arc::new(vec![0, 1, 2]))).unwrap();
|
||||
assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
|
||||
assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22]);
|
||||
assert_eq!(allocation.prefix_len, 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allocator_reuses_prefixes() {
|
||||
let mut cache = RadixAllocator::new(1, 12, None);
|
||||
let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
|
||||
assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]);
|
||||
assert_eq!(allocation.blocks, allocation.slots);
|
||||
assert_eq!(allocation.prefix_len, 0);
|
||||
cache.free(allocation.blocks.clone(), allocation.allocation_id);
|
||||
|
||||
let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
|
||||
assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]);
|
||||
assert_eq!(allocation.prefix_len, 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allocator_collects_older_prefixes_first() {
|
||||
let mut cache = RadixAllocator::new(1, 7, None);
|
||||
let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
|
||||
assert_eq!(allocation1.blocks, vec![3, 4, 5, 6]);
|
||||
assert_eq!(allocation1.prefix_len, 0);
|
||||
|
||||
let allocation2 = cache.allocate(2, Some(Arc::new(vec![4, 5]))).unwrap();
|
||||
assert_eq!(allocation2.blocks, vec![1, 2]);
|
||||
assert_eq!(allocation2.prefix_len, 0);
|
||||
|
||||
cache.free(allocation1.blocks.clone(), allocation1.allocation_id);
|
||||
cache.free(allocation2.blocks.clone(), allocation2.allocation_id);
|
||||
|
||||
// We should get the blocks of the first allocation, since they are more recent.
|
||||
let allocation3 = cache.allocate(4, Some(Arc::new(vec![6, 7, 8, 9]))).unwrap();
|
||||
assert_eq!(allocation3.blocks, vec![3, 4, 5, 6]);
|
||||
assert_eq!(allocation3.prefix_len, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allocator_frees_fully_overlapping_prefills() {
|
||||
let mut cache = RadixAllocator::new(1, 10, None);
|
||||
let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
|
||||
let allocation2 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
|
||||
|
||||
cache.free(allocation2.blocks.clone(), allocation2.allocation_id);
|
||||
cache.free(allocation1.blocks.clone(), allocation1.allocation_id);
|
||||
|
||||
let allocation3 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
|
||||
assert_eq!(allocation3.prefix_len, 4);
|
||||
|
||||
// 10 blocks, of which 1 reserved for health checks, 4 for the cached blocks.
|
||||
assert_eq!(cache.free_blocks.len(), 5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allocator_frees_partially_overlapping_prefills() {
|
||||
let mut cache = RadixAllocator::new(1, 20, None);
|
||||
let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1]))).unwrap();
|
||||
assert_eq!(allocation1.blocks, vec![16, 17, 18, 19]);
|
||||
assert_eq!(allocation1.prefix_len, 0);
|
||||
|
||||
cache.free(allocation1.blocks.clone(), allocation1.allocation_id);
|
||||
|
||||
let allocation2 = cache
|
||||
.allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 4, 5])))
|
||||
.unwrap();
|
||||
assert_eq!(allocation2.blocks, vec![16, 17, 12, 13, 14, 15, 18, 19]);
|
||||
assert_eq!(allocation2.prefix_len, 2);
|
||||
|
||||
let allocation3 = cache
|
||||
.allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 6, 7])))
|
||||
.unwrap();
|
||||
assert_eq!(allocation3.blocks, vec![16, 17, 6, 7, 8, 9, 10, 11]);
|
||||
assert_eq!(allocation3.prefix_len, 2);
|
||||
|
||||
cache.free(allocation3.blocks.clone(), allocation3.allocation_id);
|
||||
cache.free(allocation2.blocks.clone(), allocation2.allocation_id);
|
||||
|
||||
// 20 blocks, of which 1 reserved for health checks, 6 for allocation3, 2 for allocation2.
|
||||
assert_eq!(cache.free_blocks.len(), 11);
|
||||
|
||||
let allocation4 = cache
|
||||
.allocate(6, Some(Arc::new(vec![0, 1, 2, 3, 4, 5])))
|
||||
.unwrap();
|
||||
assert_eq!(allocation4.blocks, vec![16, 17, 6, 7, 14, 15]);
|
||||
assert_eq!(allocation4.prefix_len, 6);
|
||||
assert_eq!(cache.free_blocks.len(), 11);
|
||||
|
||||
let allocation5 = cache
|
||||
.allocate(6, Some(Arc::new(vec![0, 1, 2, 3, 6, 7])))
|
||||
.unwrap();
|
||||
assert_eq!(allocation5.blocks, vec![16, 17, 6, 7, 8, 9]);
|
||||
assert_eq!(allocation5.prefix_len, 6);
|
||||
assert_eq!(cache.free_blocks.len(), 11);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn trie_insertions_have_correct_prefix_len() {
|
||||
let mut trie = RadixTrie::new(1);
|
||||
|
||||
assert_eq!(trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(), 0);
|
||||
|
||||
// Already exists.
|
||||
assert_eq!(trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(), 3);
|
||||
|
||||
// Completely new at root-level
|
||||
assert_eq!(trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap(), 0);
|
||||
|
||||
// Contains full prefix, but longer.
|
||||
assert_eq!(trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap(), 3);
|
||||
|
||||
// Shares partial prefix, we need a split.
|
||||
assert_eq!(
|
||||
trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7])
|
||||
.unwrap(),
|
||||
4
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn trie_insertions_block_size() {
|
||||
let mut trie = RadixTrie::new(2);
|
||||
|
||||
assert_eq!(trie.insert(&[0, 1, 2, 3], &[0, 1]).unwrap(), 0);
|
||||
|
||||
// Already exists.
|
||||
// But needs to be block_size aligned
|
||||
assert_eq!(trie.insert(&[0, 1, 2, 3], &[0, 1]).unwrap(), 4);
|
||||
|
||||
// Completely new at root-level
|
||||
assert_eq!(trie.insert(&[1, 2, 3, 4], &[1, 2]).unwrap(), 0);
|
||||
|
||||
// Contains full prefix, but longer.
|
||||
assert_eq!(trie.insert(&[0, 1, 2, 3, 4, 5], &[0, 1, 2]).unwrap(), 4);
|
||||
|
||||
// Shares partial prefix, we need a split.
|
||||
assert_eq!(
|
||||
trie.insert(&[0, 1, 3, 4, 5, 6, 7, 8], &[0, 1, 2, 3])
|
||||
.unwrap(),
|
||||
2
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn trie_get_returns_correct_blocks() {
|
||||
let mut trie = RadixTrie::new(1);
|
||||
trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap();
|
||||
trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap();
|
||||
trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap();
|
||||
trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7])
|
||||
.unwrap();
|
||||
|
||||
let mut blocks = Vec::new();
|
||||
trie.find(&[0], &mut blocks);
|
||||
assert_eq!(blocks, vec![0]);
|
||||
|
||||
blocks.clear();
|
||||
trie.find(&[0, 1, 2], &mut blocks);
|
||||
assert_eq!(blocks, vec![0, 1, 2]);
|
||||
|
||||
blocks.clear();
|
||||
trie.find(&[1, 2, 3], &mut blocks);
|
||||
assert_eq!(blocks, vec![1, 2, 3]);
|
||||
|
||||
blocks.clear();
|
||||
trie.find(&[0, 1, 2, 3], &mut blocks);
|
||||
assert_eq!(blocks, vec![0, 1, 2, 3]);
|
||||
|
||||
blocks.clear();
|
||||
trie.find(&[0, 1, 2, 3, 4], &mut blocks);
|
||||
assert_eq!(blocks, vec![0, 1, 2, 3, 4]);
|
||||
|
||||
blocks.clear();
|
||||
trie.find(&[0, 1, 2, 3, 5], &mut blocks);
|
||||
assert_eq!(blocks, vec![0, 1, 2, 3, 5]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn trie_evict_removes_correct_blocks() {
|
||||
let mut trie = RadixTrie::new(1);
|
||||
trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap();
|
||||
trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7])
|
||||
.unwrap();
|
||||
trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap();
|
||||
trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap();
|
||||
|
||||
let mut blocks = Vec::new();
|
||||
|
||||
// Remove less than the leave blocks.
|
||||
assert_eq!(trie.evict(1), vec![7]);
|
||||
trie.find(&[0, 1, 2, 3, 5, 6, 7], &mut blocks);
|
||||
assert_eq!(blocks, vec![0, 1, 2, 3, 5, 6]);
|
||||
|
||||
// Refresh other leaf.
|
||||
trie.find(&[0, 1, 2, 3, 4], &mut blocks);
|
||||
trie.find(&[1, 2, 3], &mut blocks);
|
||||
|
||||
// Remove the leave blocks exactly.
|
||||
assert_eq!(trie.evict(2), vec![5, 6]);
|
||||
blocks.clear();
|
||||
trie.find(&[0, 1, 2, 3, 5, 6, 7], &mut blocks);
|
||||
assert_eq!(blocks, vec![0, 1, 2, 3]);
|
||||
|
||||
trie.find(&[1, 2, 3], &mut blocks);
|
||||
|
||||
// Remove more than the leave blocks.
|
||||
assert_eq!(trie.evict(3), vec![4, 3, 2]);
|
||||
blocks.clear();
|
||||
trie.find(&[0, 1, 2, 3, 4], &mut blocks);
|
||||
assert_eq!(blocks, vec![0, 1]);
|
||||
|
||||
// Clear out the whole trie.
|
||||
assert_eq!(trie.evict(10), vec![1, 2, 3, 0, 1]);
|
||||
}
|
||||
}
|
@ -8,9 +8,11 @@ use crate::{
|
||||
ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, HubTokenizerConfig,
|
||||
Message, PrefillToken, Token,
|
||||
};
|
||||
use async_stream::stream;
|
||||
use async_trait::async_trait;
|
||||
use chat_template::ChatTemplate;
|
||||
use futures::future::try_join_all;
|
||||
use futures::Stream;
|
||||
use minijinja::ErrorKind;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
@ -87,7 +89,14 @@ impl Infer {
|
||||
pub(crate) async fn generate_stream<'a>(
|
||||
&'a self,
|
||||
request: GenerateRequest,
|
||||
) -> Result<GenerateStreamResponse, InferError> {
|
||||
) -> Result<
|
||||
(
|
||||
OwnedSemaphorePermit,
|
||||
u32, // input_length
|
||||
impl Stream<Item = Result<InferStreamResponse, InferError>> + 'a,
|
||||
),
|
||||
InferError,
|
||||
> {
|
||||
// Limit concurrent requests by acquiring a permit from the semaphore
|
||||
let permit = self
|
||||
.clone()
|
||||
@ -107,9 +116,18 @@ impl Infer {
|
||||
})?;
|
||||
|
||||
let input_length = valid_request.input_length;
|
||||
let generation_stream = self.backend.schedule(valid_request)?;
|
||||
let mut generation_stream = self.backend.schedule(valid_request)?;
|
||||
|
||||
Ok((permit, input_length, generation_stream))
|
||||
// Wrap generation stream to update the backend health if the stream contains an error
|
||||
let final_stream = stream! {
|
||||
while let Some(response) = generation_stream.next().await {
|
||||
yield response.inspect_err(|_err| {
|
||||
self.backend_health.store(false, Ordering::SeqCst);
|
||||
})
|
||||
}
|
||||
};
|
||||
|
||||
Ok((permit, input_length, final_stream))
|
||||
}
|
||||
|
||||
/// Tokenizer the input
|
||||
@ -278,13 +296,6 @@ impl Infer {
|
||||
}
|
||||
}
|
||||
|
||||
/// Type alias for generation responses
|
||||
pub(crate) type GenerateStreamResponse = (
|
||||
OwnedSemaphorePermit,
|
||||
u32, // input_length
|
||||
UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
|
||||
);
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct GeneratedText {
|
||||
pub text: String,
|
||||
|
@ -1,4 +0,0 @@
|
||||
mod queue;
|
||||
mod scheduler;
|
||||
|
||||
pub(crate) use scheduler::BackendV2;
|
@ -1,675 +0,0 @@
|
||||
use crate::infer::{InferError, InferStreamResponse};
|
||||
use crate::validation::{
|
||||
ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters,
|
||||
};
|
||||
use nohash_hasher::{BuildNoHashHasher, IntMap};
|
||||
use std::cmp::min;
|
||||
use std::collections::VecDeque;
|
||||
use text_generation_client::v2::{
|
||||
Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
||||
};
|
||||
use text_generation_client::ChunksToString;
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
use tokio::time::Instant;
|
||||
use tracing::{info_span, instrument, Span};
|
||||
|
||||
/// Queue entry
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct Entry {
|
||||
/// Request
|
||||
pub request: ValidGenerateRequest,
|
||||
/// Response sender to communicate between the Infer struct and the batching_task
|
||||
pub response_tx: mpsc::UnboundedSender<Result<InferStreamResponse, InferError>>,
|
||||
/// Span that will live as long as entry
|
||||
pub span: Span,
|
||||
/// Temporary span used as a guard when logging inference, wait times...
|
||||
pub temp_span: Option<Span>,
|
||||
/// Instant when this entry was queued
|
||||
pub queue_time: Instant,
|
||||
/// Instant when this entry was added to a batch
|
||||
pub batch_time: Option<Instant>,
|
||||
}
|
||||
|
||||
/// Request Queue
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct Queue {
|
||||
/// Channel to communicate with the background queue task
|
||||
queue_sender: mpsc::UnboundedSender<QueueCommand>,
|
||||
}
|
||||
|
||||
impl Queue {
|
||||
pub(crate) fn new(
|
||||
requires_padding: bool,
|
||||
block_size: u32,
|
||||
window_size: Option<u32>,
|
||||
speculate: u32,
|
||||
) -> Self {
|
||||
// Create channel
|
||||
let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
|
||||
|
||||
// Launch background queue task
|
||||
tokio::spawn(queue_task(
|
||||
requires_padding,
|
||||
block_size,
|
||||
window_size,
|
||||
speculate,
|
||||
queue_receiver,
|
||||
));
|
||||
|
||||
Self { queue_sender }
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
pub(crate) fn append(&self, entry: Entry) {
|
||||
// Send append command to the background task managing the state
|
||||
// Unwrap is safe here
|
||||
self.queue_sender
|
||||
.send(QueueCommand::Append(Box::new(entry), Span::current()))
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
// Get the next batch
|
||||
#[instrument(skip(self))]
|
||||
pub(crate) async fn next_batch(
|
||||
&self,
|
||||
min_size: Option<usize>,
|
||||
max_size: Option<usize>,
|
||||
prefill_token_budget: u32,
|
||||
token_budget: u32,
|
||||
) -> Option<NextBatch> {
|
||||
// Create response channel
|
||||
let (response_sender, response_receiver) = oneshot::channel();
|
||||
// Send next batch command to the background task managing the state
|
||||
// Unwrap is safe here
|
||||
self.queue_sender
|
||||
.send(QueueCommand::NextBatch {
|
||||
min_size,
|
||||
max_size,
|
||||
prefill_token_budget,
|
||||
token_budget,
|
||||
response_sender,
|
||||
span: Span::current(),
|
||||
})
|
||||
.unwrap();
|
||||
// Await on response channel
|
||||
// Unwrap is safe here
|
||||
response_receiver.await.unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
// Background task responsible of the queue state
|
||||
async fn queue_task(
|
||||
requires_padding: bool,
|
||||
block_size: u32,
|
||||
window_size: Option<u32>,
|
||||
speculate: u32,
|
||||
mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
|
||||
) {
|
||||
let mut state = State::new(requires_padding, block_size, window_size, speculate);
|
||||
|
||||
while let Some(cmd) = receiver.recv().await {
|
||||
match cmd {
|
||||
QueueCommand::Append(entry, span) => {
|
||||
span.in_scope(|| state.append(*entry));
|
||||
metrics::gauge!("tgi_queue_size").increment(1.0);
|
||||
}
|
||||
QueueCommand::NextBatch {
|
||||
min_size,
|
||||
max_size,
|
||||
prefill_token_budget,
|
||||
token_budget,
|
||||
response_sender,
|
||||
span,
|
||||
} => span.in_scope(|| {
|
||||
let next_batch =
|
||||
state.next_batch(min_size, max_size, prefill_token_budget, token_budget);
|
||||
response_sender.send(next_batch).unwrap();
|
||||
metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64);
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Queue State
|
||||
#[derive(Debug)]
|
||||
struct State {
|
||||
/// Queue entries organized in a Vec
|
||||
entries: VecDeque<(u64, Entry)>,
|
||||
|
||||
/// Id of the next entry
|
||||
next_id: u64,
|
||||
|
||||
/// Id of the next batch
|
||||
next_batch_id: u64,
|
||||
|
||||
/// Whether the model is using padding
|
||||
requires_padding: bool,
|
||||
|
||||
/// Paged Attention block size
|
||||
block_size: u32,
|
||||
|
||||
/// Sliding window
|
||||
window_size: Option<u32>,
|
||||
|
||||
/// Speculation amount
|
||||
speculate: u32,
|
||||
}
|
||||
|
||||
impl State {
|
||||
fn new(
|
||||
requires_padding: bool,
|
||||
block_size: u32,
|
||||
window_size: Option<u32>,
|
||||
speculate: u32,
|
||||
) -> Self {
|
||||
Self {
|
||||
entries: VecDeque::with_capacity(128),
|
||||
next_id: 0,
|
||||
next_batch_id: 0,
|
||||
requires_padding,
|
||||
block_size,
|
||||
window_size,
|
||||
speculate,
|
||||
}
|
||||
}
|
||||
|
||||
/// Append an entry to the queue
|
||||
fn append(&mut self, mut entry: Entry) {
|
||||
// Create a span that will live as long as the entry is in the queue waiting to be batched
|
||||
let queue_span = info_span!(parent: &entry.span, "queued");
|
||||
entry.temp_span = Some(queue_span);
|
||||
|
||||
// Push entry in the queue
|
||||
self.entries.push_back((self.next_id, entry));
|
||||
self.next_id += 1;
|
||||
}
|
||||
|
||||
// Get the next batch
|
||||
fn next_batch(
|
||||
&mut self,
|
||||
min_size: Option<usize>,
|
||||
max_size: Option<usize>,
|
||||
prefill_token_budget: u32,
|
||||
token_budget: u32,
|
||||
) -> Option<NextBatch> {
|
||||
if self.entries.is_empty() {
|
||||
tracing::debug!("No queue");
|
||||
return None;
|
||||
}
|
||||
|
||||
// Check if we have enough entries
|
||||
if let Some(min_size) = min_size {
|
||||
if self.entries.len() < min_size {
|
||||
tracing::debug!("Not enough entries");
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(max_size) = max_size {
|
||||
if max_size == 0 {
|
||||
tracing::debug!("No capacity");
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
// Pad prefill_token_budget to be a multiple of block size
|
||||
let prefill_token_budget =
|
||||
((prefill_token_budget + self.block_size - 1) / self.block_size) * self.block_size;
|
||||
|
||||
// Create span for this batch to add context to inference calls
|
||||
let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
|
||||
next_batch_span.follows_from(&Span::current());
|
||||
|
||||
let mut batch_requests = Vec::with_capacity(self.entries.len());
|
||||
let mut batch_entries =
|
||||
IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());
|
||||
|
||||
let mut max_input_length = 0;
|
||||
let mut prefill_tokens: u32 = 0;
|
||||
let mut decode_tokens: u32 = 0;
|
||||
|
||||
// Pop entries starting from the front of the queue
|
||||
while let Some((id, mut entry)) = self.entries.pop_front() {
|
||||
// Filter entries where the response receiver was dropped (== entries where the request
|
||||
// was dropped by the client)
|
||||
if entry.response_tx.is_closed() {
|
||||
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
|
||||
tracing::debug!("Dropping entry");
|
||||
continue;
|
||||
}
|
||||
|
||||
if self.requires_padding {
|
||||
// We pad to max input length in the Python shards
|
||||
// We need to take these padding tokens into the equation
|
||||
max_input_length = max_input_length.max(entry.request.input_length);
|
||||
prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length
|
||||
} else {
|
||||
// pad to block size
|
||||
prefill_tokens += ((entry.request.input_length + self.block_size - 1)
|
||||
/ self.block_size)
|
||||
* self.block_size;
|
||||
}
|
||||
|
||||
if self.requires_padding {
|
||||
decode_tokens += entry.request.stopping_parameters.max_new_tokens;
|
||||
} else {
|
||||
let max_new_tokens = match self.window_size {
|
||||
None => entry.request.stopping_parameters.max_new_tokens,
|
||||
Some(window_size) => min(
|
||||
window_size.saturating_sub(entry.request.input_length),
|
||||
entry.request.stopping_parameters.max_new_tokens,
|
||||
),
|
||||
};
|
||||
|
||||
// pad to block size
|
||||
decode_tokens +=
|
||||
((max_new_tokens + self.block_size - 1) / self.block_size) * self.block_size;
|
||||
}
|
||||
|
||||
if prefill_tokens > prefill_token_budget
|
||||
|| (prefill_tokens + decode_tokens + self.speculate) > token_budget
|
||||
{
|
||||
// Entry is over budget
|
||||
// Add it back to the front
|
||||
tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
|
||||
self.entries.push_front((id, entry));
|
||||
break;
|
||||
}
|
||||
|
||||
tracing::debug!("Accepting entry");
|
||||
// Create a new span to link the batch back to this entry
|
||||
let entry_batch_span = info_span!(parent: &entry.span, "infer");
|
||||
// Add relationships
|
||||
next_batch_span.follows_from(&entry_batch_span);
|
||||
entry_batch_span.follows_from(&next_batch_span);
|
||||
// Update entry
|
||||
entry.temp_span = Some(entry_batch_span);
|
||||
|
||||
batch_requests.push(Request {
|
||||
id,
|
||||
prefill_logprobs: entry.request.decoder_input_details,
|
||||
inputs: entry.request.inputs.chunks_to_string(),
|
||||
truncate: entry.request.truncate,
|
||||
parameters: Some(NextTokenChooserParameters::from(
|
||||
entry.request.parameters.clone(),
|
||||
)),
|
||||
stopping_parameters: Some(StoppingCriteriaParameters::from(
|
||||
entry.request.stopping_parameters.clone(),
|
||||
)),
|
||||
top_n_tokens: entry.request.top_n_tokens,
|
||||
});
|
||||
// Set batch_time
|
||||
entry.batch_time = Some(Instant::now());
|
||||
// Insert in batch_entries IntMap
|
||||
batch_entries.insert(id, entry);
|
||||
|
||||
// Check if max_size
|
||||
if Some(batch_requests.len()) == max_size {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Empty batch
|
||||
if batch_requests.is_empty() {
|
||||
tracing::debug!("Filtered out all entries");
|
||||
return None;
|
||||
}
|
||||
|
||||
// Check if our batch is big enough
|
||||
if let Some(min_size) = min_size {
|
||||
// Batch is too small
|
||||
if batch_requests.len() < min_size {
|
||||
// Add back entries to the queue in the correct order
|
||||
for r in batch_requests.into_iter().rev() {
|
||||
let id = r.id;
|
||||
let entry = batch_entries.remove(&id).unwrap();
|
||||
self.entries.push_front((id, entry));
|
||||
}
|
||||
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
// Final batch size
|
||||
let size = batch_requests.len() as u32;
|
||||
next_batch_span.record("batch_size", size);
|
||||
|
||||
let batch = Batch {
|
||||
id: self.next_batch_id,
|
||||
requests: batch_requests,
|
||||
size,
|
||||
max_tokens: (prefill_tokens + decode_tokens),
|
||||
};
|
||||
// Increment batch id
|
||||
self.next_batch_id += 1;
|
||||
|
||||
metrics::histogram!("tgi_batch_next_size").record(batch.size as f64);
|
||||
|
||||
Some((batch_entries, batch, next_batch_span))
|
||||
}
|
||||
}
|
||||
|
||||
type NextBatch = (IntMap<u64, Entry>, Batch, Span);
|
||||
|
||||
#[derive(Debug)]
|
||||
enum QueueCommand {
|
||||
Append(Box<Entry>, Span),
|
||||
NextBatch {
|
||||
min_size: Option<usize>,
|
||||
max_size: Option<usize>,
|
||||
prefill_token_budget: u32,
|
||||
token_budget: u32,
|
||||
response_sender: oneshot::Sender<Option<NextBatch>>,
|
||||
span: Span,
|
||||
},
|
||||
}
|
||||
|
||||
impl From<ValidParameters> for NextTokenChooserParameters {
|
||||
fn from(value: ValidParameters) -> Self {
|
||||
let (grammar, grammar_type) = match value.grammar {
|
||||
None => (String::new(), GrammarType::None),
|
||||
|
||||
Some(grammar) => match grammar {
|
||||
ValidGrammar::Json(grammar_string) => (grammar_string, GrammarType::Json),
|
||||
ValidGrammar::Regex(grammar_string) => (grammar_string, GrammarType::Regex),
|
||||
},
|
||||
};
|
||||
|
||||
Self {
|
||||
temperature: value.temperature,
|
||||
top_k: value.top_k,
|
||||
top_p: value.top_p,
|
||||
typical_p: value.typical_p,
|
||||
do_sample: value.do_sample,
|
||||
seed: value.seed,
|
||||
repetition_penalty: value.repetition_penalty,
|
||||
frequency_penalty: value.frequency_penalty,
|
||||
watermark: value.watermark,
|
||||
grammar,
|
||||
grammar_type: grammar_type.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ValidStoppingParameters> for StoppingCriteriaParameters {
|
||||
fn from(value: ValidStoppingParameters) -> Self {
|
||||
Self {
|
||||
max_new_tokens: value.max_new_tokens,
|
||||
stop_sequences: value.stop_sequences,
|
||||
ignore_eos_token: value.ignore_eos_token,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tracing::info_span;
|
||||
|
||||
fn default_entry() -> (
|
||||
Entry,
|
||||
mpsc::UnboundedReceiver<Result<InferStreamResponse, InferError>>,
|
||||
) {
|
||||
let (response_tx, receiver_tx) = mpsc::unbounded_channel();
|
||||
|
||||
let entry = Entry {
|
||||
request: ValidGenerateRequest {
|
||||
inputs: vec![],
|
||||
input_length: 0,
|
||||
truncate: 0,
|
||||
decoder_input_details: false,
|
||||
parameters: ValidParameters {
|
||||
temperature: 0.0,
|
||||
top_k: 0,
|
||||
top_p: 0.0,
|
||||
typical_p: 0.0,
|
||||
do_sample: false,
|
||||
seed: 0,
|
||||
repetition_penalty: 0.0,
|
||||
frequency_penalty: 0.0,
|
||||
watermark: false,
|
||||
grammar: None,
|
||||
},
|
||||
stopping_parameters: ValidStoppingParameters {
|
||||
ignore_eos_token: false,
|
||||
max_new_tokens: 1,
|
||||
stop_sequences: vec![],
|
||||
},
|
||||
top_n_tokens: 0,
|
||||
adapter_id: None,
|
||||
},
|
||||
response_tx,
|
||||
span: info_span!("entry"),
|
||||
temp_span: None,
|
||||
queue_time: Instant::now(),
|
||||
batch_time: None,
|
||||
};
|
||||
(entry, receiver_tx)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_append() {
|
||||
let mut state = State::new(false, 1, None, 0);
|
||||
let (entry, _guard) = default_entry();
|
||||
|
||||
assert_eq!(state.next_id, 0);
|
||||
assert_eq!(state.entries.len(), 0);
|
||||
|
||||
state.append(entry);
|
||||
|
||||
assert_eq!(state.next_id, 1);
|
||||
assert_eq!(state.entries.len(), 1);
|
||||
let (id, _) = state.entries.remove(0).unwrap();
|
||||
assert_eq!(id, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_next_batch_empty() {
|
||||
let mut state = State::new(false, 1, None, 0);
|
||||
|
||||
assert!(state.next_batch(None, None, 1, 1).is_none());
|
||||
assert!(state.next_batch(Some(1), None, 1, 1).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_next_batch_min_size() {
|
||||
let mut state = State::new(false, 1, None, 0);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
state.append(entry1);
|
||||
state.append(entry2);
|
||||
|
||||
let (entries, batch, _) = state.next_batch(None, None, 2, 2).unwrap();
|
||||
assert_eq!(entries.len(), 2);
|
||||
assert!(entries.contains_key(&0));
|
||||
assert!(entries.contains_key(&1));
|
||||
assert!(entries.get(&0).unwrap().batch_time.is_some());
|
||||
assert!(entries.get(&1).unwrap().batch_time.is_some());
|
||||
assert_eq!(batch.id, 0);
|
||||
assert_eq!(batch.size, 2);
|
||||
|
||||
assert_eq!(state.next_id, 2);
|
||||
assert_eq!(state.entries.len(), 0);
|
||||
assert_eq!(state.next_batch_id, 1);
|
||||
|
||||
let (entry3, _guard3) = default_entry();
|
||||
state.append(entry3);
|
||||
|
||||
assert!(state.next_batch(Some(2), None, 2, 2).is_none());
|
||||
|
||||
assert_eq!(state.next_id, 3);
|
||||
assert_eq!(state.entries.len(), 1);
|
||||
let (id, _) = state.entries.remove(0).unwrap();
|
||||
assert_eq!(id, 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_next_batch_max_size() {
|
||||
let mut state = State::new(false, 1, None, 0);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
state.append(entry1);
|
||||
state.append(entry2);
|
||||
|
||||
let (entries, batch, _) = state.next_batch(None, Some(1), 2, 2).unwrap();
|
||||
assert_eq!(entries.len(), 1);
|
||||
assert!(entries.contains_key(&0));
|
||||
assert!(entries.get(&0).unwrap().batch_time.is_some());
|
||||
assert_eq!(batch.id, 0);
|
||||
assert_eq!(batch.size, 1);
|
||||
|
||||
assert_eq!(state.next_id, 2);
|
||||
assert_eq!(state.entries.len(), 1);
|
||||
assert_eq!(state.next_batch_id, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_next_batch_token_budget() {
|
||||
let mut state = State::new(false, 1, None, 0);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
state.append(entry1);
|
||||
state.append(entry2);
|
||||
|
||||
let (entries, batch, _) = state.next_batch(None, None, 1, 1).unwrap();
|
||||
assert_eq!(entries.len(), 1);
|
||||
assert!(entries.contains_key(&0));
|
||||
assert_eq!(batch.id, 0);
|
||||
assert_eq!(batch.size, 1);
|
||||
|
||||
assert_eq!(state.next_id, 2);
|
||||
assert_eq!(state.entries.len(), 1);
|
||||
assert_eq!(state.next_batch_id, 1);
|
||||
|
||||
let (entry3, _guard3) = default_entry();
|
||||
state.append(entry3);
|
||||
|
||||
let (entries, batch, _) = state.next_batch(None, None, 3, 3).unwrap();
|
||||
assert_eq!(entries.len(), 2);
|
||||
assert!(entries.contains_key(&1));
|
||||
assert!(entries.contains_key(&2));
|
||||
assert_eq!(batch.id, 1);
|
||||
assert_eq!(batch.size, 2);
|
||||
|
||||
assert_eq!(state.next_id, 3);
|
||||
assert_eq!(state.entries.len(), 0);
|
||||
assert_eq!(state.next_batch_id, 2);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_queue_append() {
|
||||
let queue = Queue::new(false, 1, None, 0);
|
||||
let (entry, _guard) = default_entry();
|
||||
queue.append(entry);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_empty() {
|
||||
let queue = Queue::new(false, 1, None, 0);
|
||||
|
||||
assert!(queue.next_batch(None, None, 1, 1).await.is_none());
|
||||
assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_min_size() {
|
||||
let queue = Queue::new(false, 1, None, 0);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
queue.append(entry1);
|
||||
queue.append(entry2);
|
||||
|
||||
let (entries, batch, _) = queue.next_batch(None, None, 2, 2).await.unwrap();
|
||||
assert_eq!(entries.len(), 2);
|
||||
assert!(entries.contains_key(&0));
|
||||
assert!(entries.contains_key(&1));
|
||||
assert!(entries.get(&0).unwrap().batch_time.is_some());
|
||||
assert!(entries.get(&1).unwrap().batch_time.is_some());
|
||||
assert_eq!(batch.id, 0);
|
||||
assert_eq!(batch.size, 2);
|
||||
|
||||
let (entry3, _guard3) = default_entry();
|
||||
queue.append(entry3);
|
||||
|
||||
// Not enough requests pending
|
||||
assert!(queue.next_batch(Some(2), None, 2, 2).await.is_none());
|
||||
// Not enough token budget
|
||||
assert!(queue.next_batch(Some(1), None, 0, 0).await.is_none());
|
||||
// Ok
|
||||
let (entries2, batch2, _) = queue.next_batch(Some(1), None, 2, 2).await.unwrap();
|
||||
assert_eq!(entries2.len(), 1);
|
||||
assert!(entries2.contains_key(&2));
|
||||
assert!(entries2.get(&2).unwrap().batch_time.is_some());
|
||||
assert_eq!(batch2.id, 1);
|
||||
assert_eq!(batch2.size, 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_max_size() {
|
||||
let queue = Queue::new(false, 1, None, 0);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
queue.append(entry1);
|
||||
queue.append(entry2);
|
||||
|
||||
let (entries, batch, _) = queue.next_batch(None, Some(1), 2, 2).await.unwrap();
|
||||
assert_eq!(entries.len(), 1);
|
||||
assert!(entries.contains_key(&0));
|
||||
assert!(entries.get(&0).unwrap().batch_time.is_some());
|
||||
assert_eq!(batch.id, 0);
|
||||
assert_eq!(batch.size, 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_token_budget() {
|
||||
let queue = Queue::new(false, 1, None, 0);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
queue.append(entry1);
|
||||
queue.append(entry2);
|
||||
|
||||
let (entries, batch, _) = queue.next_batch(None, None, 1, 1).await.unwrap();
|
||||
assert_eq!(entries.len(), 1);
|
||||
assert!(entries.contains_key(&0));
|
||||
assert_eq!(batch.id, 0);
|
||||
assert_eq!(batch.size, 1);
|
||||
|
||||
let (entry3, _guard3) = default_entry();
|
||||
queue.append(entry3);
|
||||
|
||||
let (entries, batch, _) = queue.next_batch(None, None, 3, 3).await.unwrap();
|
||||
assert_eq!(entries.len(), 2);
|
||||
assert!(entries.contains_key(&1));
|
||||
assert!(entries.contains_key(&2));
|
||||
assert_eq!(batch.id, 1);
|
||||
assert_eq!(batch.size, 2);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_token_speculate() {
|
||||
let queue = Queue::new(false, 1, None, 2);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
queue.append(entry1);
|
||||
queue.append(entry2);
|
||||
|
||||
// Budget of 1 is not enough
|
||||
assert!(queue.next_batch(None, None, 1, 1).await.is_none());
|
||||
|
||||
let (entries, batch, _) = queue.next_batch(None, None, 6, 6).await.unwrap();
|
||||
assert_eq!(entries.len(), 2);
|
||||
assert!(entries.contains_key(&0));
|
||||
assert!(entries.contains_key(&1));
|
||||
assert_eq!(batch.id, 0);
|
||||
assert_eq!(batch.size, 2);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_dropped_receiver() {
|
||||
let queue = Queue::new(false, 1, None, 0);
|
||||
let (entry, _) = default_entry();
|
||||
queue.append(entry);
|
||||
|
||||
assert!(queue.next_batch(None, None, 1, 1).await.is_none());
|
||||
}
|
||||
}
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user