mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
added v2
This commit is contained in:
parent
6e105c8eb8
commit
a50e90e7e2
@ -13,7 +13,7 @@ use tokio::time::Instant;
|
|||||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||||
use tracing::{info_span, instrument, Instrument, Span};
|
use tracing::{info_span, instrument, Instrument, Span};
|
||||||
|
|
||||||
pub struct BackendV3 {
|
pub struct BackendV2 {
|
||||||
/// Request queue
|
/// Request queue
|
||||||
queue: Queue,
|
queue: Queue,
|
||||||
/// Notify batcher on queue appends
|
/// Notify batcher on queue appends
|
||||||
@ -22,7 +22,7 @@ pub struct BackendV3 {
|
|||||||
client: ShardedClient,
|
client: ShardedClient,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl BackendV3 {
|
impl BackendV2 {
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub(crate) fn new(
|
pub(crate) fn new(
|
||||||
client: ShardedClient,
|
client: ShardedClient,
|
||||||
@ -35,24 +35,20 @@ impl BackendV3 {
|
|||||||
window_size: Option<u32>,
|
window_size: Option<u32>,
|
||||||
speculate: u32,
|
speculate: u32,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let prefix_caching =
|
// Infer shared state
|
||||||
std::env::var("USE_PREFIX_CACHING").expect("Expect prefix caching env var");
|
let attention = if let Ok(attention) = std::env::var("ATTENTION") {
|
||||||
let prefix_caching = matches!(prefix_caching.as_str(), "true" | "1");
|
attention
|
||||||
let attention: String = std::env::var("ATTENTION").expect("attention env var");
|
.parse()
|
||||||
|
.unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"))
|
||||||
let attention: Attention = attention
|
} else {
|
||||||
.parse()
|
Attention::Paged
|
||||||
.unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"));
|
};
|
||||||
let block_size = attention.block_size();
|
let block_size = if attention == Attention::FlashDecoding {
|
||||||
|
256
|
||||||
let queue = Queue::new(
|
} else {
|
||||||
requires_padding,
|
16
|
||||||
block_size,
|
};
|
||||||
prefix_caching,
|
let queue = Queue::new(requires_padding, block_size, window_size, speculate);
|
||||||
window_size,
|
|
||||||
speculate,
|
|
||||||
max_batch_total_tokens,
|
|
||||||
);
|
|
||||||
let batching_task_notifier = Arc::new(Notify::new());
|
let batching_task_notifier = Arc::new(Notify::new());
|
||||||
|
|
||||||
// Spawn batching background task that contains all the inference logic
|
// Spawn batching background task that contains all the inference logic
|
||||||
@ -76,7 +72,7 @@ impl BackendV3 {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl Backend for BackendV3 {
|
impl Backend for BackendV2 {
|
||||||
#[instrument(skip_all)]
|
#[instrument(skip_all)]
|
||||||
fn schedule(
|
fn schedule(
|
||||||
&self,
|
&self,
|
||||||
@ -93,7 +89,6 @@ impl Backend for BackendV3 {
|
|||||||
temp_span: None,
|
temp_span: None,
|
||||||
queue_time: Instant::now(),
|
queue_time: Instant::now(),
|
||||||
batch_time: None,
|
batch_time: None,
|
||||||
block_allocation: None,
|
|
||||||
});
|
});
|
||||||
|
|
||||||
// Notify the background task that we have a new entry in the queue that needs
|
// Notify the background task that we have a new entry in the queue that needs
|
||||||
@ -168,15 +163,12 @@ pub(crate) async fn batching_task(
|
|||||||
None
|
None
|
||||||
} else {
|
} else {
|
||||||
// Minimum batch size
|
// Minimum batch size
|
||||||
// TODO: temporarily disable to avoid incorrect deallocation +
|
|
||||||
// reallocation when using prefix caching.
|
|
||||||
Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
|
Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
|
||||||
};
|
};
|
||||||
|
|
||||||
let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
|
let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
|
||||||
let max_size =
|
let max_size =
|
||||||
max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize));
|
max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize));
|
||||||
|
|
||||||
// Try to get a new batch
|
// Try to get a new batch
|
||||||
if let Some((mut new_entries, new_batch, span)) = queue
|
if let Some((mut new_entries, new_batch, span)) = queue
|
||||||
.next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget)
|
.next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget)
|
||||||
@ -259,13 +251,13 @@ async fn prefill(
|
|||||||
// Filter next batch and remove requests that were stopped
|
// Filter next batch and remove requests that were stopped
|
||||||
let next_batch = filter_batch(client, next_batch, entries).await;
|
let next_batch = filter_batch(client, next_batch, entries).await;
|
||||||
|
|
||||||
metrics::histogram!("tgi_batch_forward_duration", "method" => "prefill")
|
metrics::histogram!("tgi_batch_forward_duration","method" => "prefill")
|
||||||
.record(timings.forward.as_secs_f64());
|
.record(timings.forward.as_secs_f64());
|
||||||
metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill")
|
metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill")
|
||||||
.record(timings.decode.as_secs_f64());
|
.record(timings.decode.as_secs_f64());
|
||||||
metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill")
|
metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill")
|
||||||
.record(start_filtering_time.elapsed().as_secs_f64());
|
.record(start_filtering_time.elapsed().as_secs_f64());
|
||||||
metrics::histogram!("tgi_batch_inference_duration", "method" => "prefill")
|
metrics::histogram!("tgi_batch_inference_duration","method" => "prefill")
|
||||||
.record(start_time.elapsed().as_secs_f64());
|
.record(start_time.elapsed().as_secs_f64());
|
||||||
metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1);
|
metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1);
|
||||||
next_batch
|
next_batch
|
||||||
@ -497,8 +489,8 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
|
|||||||
|
|
||||||
impl From<crate::client::GeneratedText> for GeneratedText {
|
impl From<crate::client::GeneratedText> for GeneratedText {
|
||||||
fn from(value: crate::client::GeneratedText) -> Self {
|
fn from(value: crate::client::GeneratedText) -> Self {
|
||||||
let v3_finish_reason = crate::client::FinishReason::try_from(value.finish_reason).unwrap();
|
let v2_finish_reason = crate::client::FinishReason::try_from(value.finish_reason).unwrap();
|
||||||
let finish_reason = match v3_finish_reason {
|
let finish_reason = match v2_finish_reason {
|
||||||
crate::client::FinishReason::Length => FinishReason::Length,
|
crate::client::FinishReason::Length => FinishReason::Length,
|
||||||
crate::client::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
|
crate::client::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
|
||||||
crate::client::FinishReason::StopSequence => FinishReason::StopSequence,
|
crate::client::FinishReason::StopSequence => FinishReason::StopSequence,
|
||||||
|
@ -1,209 +0,0 @@
|
|||||||
use std::sync::Arc;
|
|
||||||
use tokio::sync::{mpsc, oneshot};
|
|
||||||
|
|
||||||
use crate::radix::RadixAllocator;
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct BlockAllocation {
|
|
||||||
pub allocation_id: u64,
|
|
||||||
pub blocks: Vec<u32>,
|
|
||||||
pub slots: Vec<u32>,
|
|
||||||
|
|
||||||
/// Prefix that was cached and for which the KV does not have to
|
|
||||||
/// be recomputed.
|
|
||||||
pub prefix_len: u32,
|
|
||||||
|
|
||||||
pub(crate) block_allocator: Option<BlockAllocator>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Drop for BlockAllocation {
|
|
||||||
fn drop(&mut self) {
|
|
||||||
if let Some(block_allocator) = self.block_allocator.as_mut() {
|
|
||||||
block_allocator.free(self.blocks.clone(), self.allocation_id)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct BlockAllocator {
|
|
||||||
/// Channel to communicate with the background task
|
|
||||||
block_allocator: mpsc::UnboundedSender<BlockAllocatorCommand>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl BlockAllocator {
|
|
||||||
pub(crate) fn new(
|
|
||||||
max_batch_total_tokens: u32,
|
|
||||||
block_size: u32,
|
|
||||||
prefix_caching: bool,
|
|
||||||
window_size: Option<u32>,
|
|
||||||
) -> Self {
|
|
||||||
// Create channel
|
|
||||||
let (sender, receiver) = mpsc::unbounded_channel();
|
|
||||||
|
|
||||||
// Launch background queue task
|
|
||||||
tokio::spawn(block_allocator_task(
|
|
||||||
max_batch_total_tokens / block_size,
|
|
||||||
block_size,
|
|
||||||
prefix_caching,
|
|
||||||
window_size,
|
|
||||||
receiver,
|
|
||||||
));
|
|
||||||
|
|
||||||
Self {
|
|
||||||
block_allocator: sender,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) async fn allocate(
|
|
||||||
&self,
|
|
||||||
tokens: u32,
|
|
||||||
prefill_tokens: Option<Arc<Vec<u32>>>,
|
|
||||||
) -> Option<BlockAllocation> {
|
|
||||||
let (response_sender, response_receiver) = oneshot::channel();
|
|
||||||
self.block_allocator
|
|
||||||
.send(BlockAllocatorCommand::Allocate {
|
|
||||||
tokens,
|
|
||||||
prefill_tokens,
|
|
||||||
response_sender,
|
|
||||||
})
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
response_receiver.await.unwrap().map(|mut allocation| {
|
|
||||||
allocation.block_allocator = Some(self.clone());
|
|
||||||
allocation
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn free(&self, blocks: Vec<u32>, allocation_id: u64) {
|
|
||||||
self.block_allocator
|
|
||||||
.send(BlockAllocatorCommand::Free {
|
|
||||||
allocation_id,
|
|
||||||
blocks,
|
|
||||||
})
|
|
||||||
.unwrap();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn block_allocator_task(
|
|
||||||
blocks: u32,
|
|
||||||
block_size: u32,
|
|
||||||
prefix_caching: bool,
|
|
||||||
window_size: Option<u32>,
|
|
||||||
mut receiver: mpsc::UnboundedReceiver<BlockAllocatorCommand>,
|
|
||||||
) {
|
|
||||||
let mut allocator: Box<dyn Allocator + Send> = if prefix_caching {
|
|
||||||
Box::new(RadixAllocator::new(block_size, blocks, window_size))
|
|
||||||
} else {
|
|
||||||
Box::new(SimpleAllocator::new(blocks, block_size, window_size))
|
|
||||||
};
|
|
||||||
while let Some(cmd) = receiver.recv().await {
|
|
||||||
match cmd {
|
|
||||||
BlockAllocatorCommand::Free {
|
|
||||||
blocks,
|
|
||||||
allocation_id,
|
|
||||||
} => allocator.free(blocks, allocation_id),
|
|
||||||
BlockAllocatorCommand::Allocate {
|
|
||||||
tokens,
|
|
||||||
prefill_tokens,
|
|
||||||
response_sender,
|
|
||||||
} => {
|
|
||||||
response_sender
|
|
||||||
.send(allocator.allocate(tokens, prefill_tokens))
|
|
||||||
.unwrap();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
enum BlockAllocatorCommand {
|
|
||||||
Free {
|
|
||||||
blocks: Vec<u32>,
|
|
||||||
allocation_id: u64,
|
|
||||||
},
|
|
||||||
Allocate {
|
|
||||||
tokens: u32,
|
|
||||||
prefill_tokens: Option<Arc<Vec<u32>>>,
|
|
||||||
response_sender: oneshot::Sender<Option<BlockAllocation>>,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
pub trait Allocator {
|
|
||||||
fn allocate(
|
|
||||||
&mut self,
|
|
||||||
tokens: u32,
|
|
||||||
prefill_tokens: Option<Arc<Vec<u32>>>,
|
|
||||||
) -> Option<BlockAllocation>;
|
|
||||||
|
|
||||||
fn free(&mut self, blocks: Vec<u32>, allocation_id: u64);
|
|
||||||
}
|
|
||||||
pub struct SimpleAllocator {
|
|
||||||
free_blocks: Vec<u32>,
|
|
||||||
block_size: u32,
|
|
||||||
window_size: Option<u32>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl SimpleAllocator {
|
|
||||||
fn new(blocks: u32, block_size: u32, window_size: Option<u32>) -> Self {
|
|
||||||
SimpleAllocator {
|
|
||||||
block_size,
|
|
||||||
// Block 0 is reserved for health checks
|
|
||||||
free_blocks: (1..blocks).collect(),
|
|
||||||
window_size,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Allocator for SimpleAllocator {
|
|
||||||
fn allocate(
|
|
||||||
&mut self,
|
|
||||||
tokens: u32,
|
|
||||||
_prefill_tokens: Option<Arc<Vec<u32>>>,
|
|
||||||
) -> Option<BlockAllocation> {
|
|
||||||
// Apply window size
|
|
||||||
let (required_blocks, repeats) = {
|
|
||||||
let (tokens, repeats) = match self.window_size {
|
|
||||||
None => (tokens, 1),
|
|
||||||
Some(window_size) => {
|
|
||||||
let repeats = (tokens + window_size - 1) / window_size;
|
|
||||||
let tokens = core::cmp::min(tokens, window_size);
|
|
||||||
(tokens, repeats as usize)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
// Pad to a multiple of block size
|
|
||||||
let required_blocks = (tokens + self.block_size - 1) / self.block_size;
|
|
||||||
(required_blocks, repeats)
|
|
||||||
};
|
|
||||||
|
|
||||||
let tokens = tokens as usize;
|
|
||||||
if required_blocks > self.free_blocks.len() as u32 {
|
|
||||||
None
|
|
||||||
} else {
|
|
||||||
let blocks = self
|
|
||||||
.free_blocks
|
|
||||||
.split_off(self.free_blocks.len() - required_blocks as usize);
|
|
||||||
let mut slots =
|
|
||||||
Vec::with_capacity((required_blocks * self.block_size * repeats as u32) as usize);
|
|
||||||
|
|
||||||
'slots: for block_id in blocks.repeat(repeats).iter() {
|
|
||||||
for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) {
|
|
||||||
slots.push(s);
|
|
||||||
if slots.len() == tokens {
|
|
||||||
break 'slots;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Some(BlockAllocation {
|
|
||||||
allocation_id: 0,
|
|
||||||
blocks,
|
|
||||||
slots,
|
|
||||||
prefix_len: 0,
|
|
||||||
block_allocator: None,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn free(&mut self, blocks: Vec<u32>, _allocation_id: u64) {
|
|
||||||
self.free_blocks.extend(blocks)
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,11 +1,9 @@
|
|||||||
/// Single shard Client
|
/// Single shard Client
|
||||||
use crate::client::{pb, Chunk};
|
use crate::client::pb;
|
||||||
use crate::client::{ClientError, Result, WARMUP_IMAGE_BASE64};
|
use crate::client::{ClientError, Result, WARMUP_IMAGE_BASE64};
|
||||||
use base64::engine::general_purpose::STANDARD;
|
|
||||||
use base64::Engine;
|
|
||||||
use grpc_metadata::InjectTelemetryContext;
|
use grpc_metadata::InjectTelemetryContext;
|
||||||
use pb::generate::v3::text_generation_service_client::TextGenerationServiceClient;
|
use pb::generate::v2::text_generation_service_client::TextGenerationServiceClient;
|
||||||
use pb::generate::v3::*;
|
use pb::generate::v2::*;
|
||||||
use std::cmp::min;
|
use std::cmp::min;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use tonic::transport::{Channel, Uri};
|
use tonic::transport::{Channel, Uri};
|
||||||
@ -47,7 +45,7 @@ impl Client {
|
|||||||
pub async fn service_discovery(&mut self) -> Result<Vec<String>> {
|
pub async fn service_discovery(&mut self) -> Result<Vec<String>> {
|
||||||
let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context();
|
let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context();
|
||||||
let response = self.stub.service_discovery(request).await.map_err(|_| {
|
let response = self.stub.service_discovery(request).await.map_err(|_| {
|
||||||
ClientError::Connection("Server does not support v3 interface".to_string())
|
ClientError::Connection("Server does not support v2 interface".to_string())
|
||||||
})?;
|
})?;
|
||||||
let urls = response
|
let urls = response
|
||||||
.into_inner()
|
.into_inner()
|
||||||
@ -119,23 +117,6 @@ impl Client {
|
|||||||
while n_tokens < max_prefill_tokens {
|
while n_tokens < max_prefill_tokens {
|
||||||
let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
|
let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
|
||||||
|
|
||||||
let mut input_chunks = Vec::new();
|
|
||||||
input_chunks
|
|
||||||
.push(Chunk::Text("_test ".to_string().repeat(max_input_length as usize)).into());
|
|
||||||
if n_tokens == 0 {
|
|
||||||
input_chunks.push(
|
|
||||||
Chunk::Image(Image {
|
|
||||||
// Safe unwrap, because we control the data.
|
|
||||||
data: STANDARD.decode(WARMUP_IMAGE_BASE64).unwrap(),
|
|
||||||
mimetype: "image/jpeg;base64".to_string(),
|
|
||||||
})
|
|
||||||
.into(),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send stringly-typed inputs for compatibility for backends that haven't
|
|
||||||
// been updated to support chunks.
|
|
||||||
|
|
||||||
let mut inputs = String::new();
|
let mut inputs = String::new();
|
||||||
inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
|
inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
|
||||||
if n_tokens == 0 {
|
if n_tokens == 0 {
|
||||||
@ -149,16 +130,8 @@ impl Client {
|
|||||||
requests.push(Request {
|
requests.push(Request {
|
||||||
id: 0,
|
id: 0,
|
||||||
inputs,
|
inputs,
|
||||||
add_special_tokens: true,
|
|
||||||
input_chunks: Some(Input {
|
|
||||||
chunks: input_chunks,
|
|
||||||
}),
|
|
||||||
// We truncate the input on the server side to be sure that it has the correct size
|
// We truncate the input on the server side to be sure that it has the correct size
|
||||||
truncate,
|
truncate,
|
||||||
// Blocks and slots will be set on the server side if we use paged attention
|
|
||||||
blocks: vec![],
|
|
||||||
slots: vec![],
|
|
||||||
prefix_len: 0,
|
|
||||||
// Set sampling parameters to also take these ops into account in the max memory
|
// Set sampling parameters to also take these ops into account in the max memory
|
||||||
parameters: Some(NextTokenChooserParameters {
|
parameters: Some(NextTokenChooserParameters {
|
||||||
temperature: 0.9,
|
temperature: 0.9,
|
||||||
@ -180,7 +153,6 @@ impl Client {
|
|||||||
}),
|
}),
|
||||||
prefill_logprobs: true,
|
prefill_logprobs: true,
|
||||||
top_n_tokens: 20,
|
top_n_tokens: 20,
|
||||||
adapter_id: None,
|
|
||||||
});
|
});
|
||||||
n_tokens += max_input_length;
|
n_tokens += max_input_length;
|
||||||
|
|
||||||
@ -194,8 +166,7 @@ impl Client {
|
|||||||
id: 0,
|
id: 0,
|
||||||
size: requests.len() as u32,
|
size: requests.len() as u32,
|
||||||
requests,
|
requests,
|
||||||
max_tokens: max_input_length,
|
max_tokens: 0,
|
||||||
max_blocks: 0,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let request = tonic::Request::new(WarmupRequest {
|
let request = tonic::Request::new(WarmupRequest {
|
||||||
|
@ -12,10 +12,9 @@ mod grpc_client;
|
|||||||
mod sharded_client;
|
mod sharded_client;
|
||||||
|
|
||||||
pub use grpc_client::Client;
|
pub use grpc_client::Client;
|
||||||
pub use pb::generate::v3::{
|
pub use pb::generate::v2::{
|
||||||
input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
|
Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, HealthResponse,
|
||||||
HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request,
|
InfoResponse, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
||||||
StoppingCriteriaParameters,
|
|
||||||
};
|
};
|
||||||
pub use sharded_client::ShardedClient;
|
pub use sharded_client::ShardedClient;
|
||||||
|
|
||||||
@ -64,13 +63,6 @@ impl From<transport::Error> for ClientError {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Small convenience re-wrapping of `Chunk`.
|
|
||||||
impl From<Chunk> for InputChunk {
|
|
||||||
fn from(chunk: Chunk) -> Self {
|
|
||||||
InputChunk { chunk: Some(chunk) }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static WARMUP_IMAGE_BASE64 :&str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=";
|
static WARMUP_IMAGE_BASE64 :&str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=";
|
||||||
|
|
||||||
pub type Result<T> = std::result::Result<T, ClientError>;
|
pub type Result<T> = std::result::Result<T, ClientError>;
|
||||||
|
@ -1,13 +1,13 @@
|
|||||||
use crate::client::{ClientError, Result};
|
|
||||||
/// Multi shard Client
|
/// Multi shard Client
|
||||||
|
use crate::client::{ClientError, Result};
|
||||||
use crate::client::{Health, ShardInfo};
|
use crate::client::{Health, ShardInfo};
|
||||||
|
|
||||||
use crate::client::grpc_client::{DecodeTimings, PrefillTimings};
|
use crate::client::grpc_client::{DecodeTimings, PrefillTimings};
|
||||||
|
use crate::client::InfoResponse;
|
||||||
use crate::client::{
|
use crate::client::{
|
||||||
Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse,
|
Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse,
|
||||||
NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
||||||
};
|
};
|
||||||
use crate::client::{Chunk, InfoResponse, Input};
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use futures::future::join_all;
|
use futures::future::join_all;
|
||||||
use tonic::transport::Uri;
|
use tonic::transport::Uri;
|
||||||
@ -218,11 +218,7 @@ impl Health for ShardedClient {
|
|||||||
let liveness_request = Request {
|
let liveness_request = Request {
|
||||||
id: u64::MAX,
|
id: u64::MAX,
|
||||||
inputs: "liveness".to_string(),
|
inputs: "liveness".to_string(),
|
||||||
input_chunks: Some(Input {
|
|
||||||
chunks: vec![Chunk::Text("liveness".into()).into()],
|
|
||||||
}),
|
|
||||||
truncate: 10,
|
truncate: 10,
|
||||||
add_special_tokens: true,
|
|
||||||
prefill_logprobs: false,
|
prefill_logprobs: false,
|
||||||
parameters: Some(NextTokenChooserParameters {
|
parameters: Some(NextTokenChooserParameters {
|
||||||
temperature: 1.0,
|
temperature: 1.0,
|
||||||
@ -243,18 +239,12 @@ impl Health for ShardedClient {
|
|||||||
ignore_eos_token: false,
|
ignore_eos_token: false,
|
||||||
}),
|
}),
|
||||||
top_n_tokens: 0,
|
top_n_tokens: 0,
|
||||||
// Block 0 is reserved for health checks
|
|
||||||
blocks: vec![0],
|
|
||||||
slots: (0..16).collect(),
|
|
||||||
prefix_len: 0,
|
|
||||||
adapter_id: None,
|
|
||||||
};
|
};
|
||||||
let batch = Batch {
|
let batch = Batch {
|
||||||
id: u64::MAX,
|
id: u64::MAX,
|
||||||
requests: vec![liveness_request],
|
requests: vec![liveness_request],
|
||||||
size: 1,
|
size: 1,
|
||||||
max_tokens: 2,
|
max_tokens: 2,
|
||||||
max_blocks: 1,
|
|
||||||
};
|
};
|
||||||
self.clone().prefill(batch).await?;
|
self.clone().prefill(batch).await?;
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -1,11 +1,9 @@
|
|||||||
mod backend;
|
mod backend;
|
||||||
pub mod block_allocator;
|
|
||||||
mod client;
|
mod client;
|
||||||
mod queue;
|
mod queue;
|
||||||
pub mod radix;
|
|
||||||
|
|
||||||
use crate::client::{ClientError, ShardedClient};
|
use crate::client::{ClientError, ShardedClient};
|
||||||
pub(crate) use backend::BackendV3;
|
pub(crate) use backend::BackendV2;
|
||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use utoipa::ToSchema;
|
use utoipa::ToSchema;
|
||||||
@ -41,7 +39,7 @@ pub async fn connect_backend(
|
|||||||
max_batch_total_tokens: Option<u32>,
|
max_batch_total_tokens: Option<u32>,
|
||||||
max_waiting_tokens: usize,
|
max_waiting_tokens: usize,
|
||||||
max_batch_size: Option<usize>,
|
max_batch_size: Option<usize>,
|
||||||
) -> Result<(BackendV3, BackendInfo), V3Error> {
|
) -> Result<(BackendV2, BackendInfo), V2Error> {
|
||||||
// Helper function
|
// Helper function
|
||||||
let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option<u32>| {
|
let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option<u32>| {
|
||||||
match max_supported_batch_total_tokens {
|
match max_supported_batch_total_tokens {
|
||||||
@ -65,7 +63,7 @@ pub async fn connect_backend(
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
if max_total_tokens as u32 > max_supported_batch_total_tokens {
|
if max_total_tokens as u32 > max_supported_batch_total_tokens {
|
||||||
return Err(V3Error::NotEnoughMemory(max_total_tokens));
|
return Err(V2Error::NotEnoughMemory(max_total_tokens));
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(max_supported_batch_total_tokens)
|
Ok(max_supported_batch_total_tokens)
|
||||||
@ -75,16 +73,16 @@ pub async fn connect_backend(
|
|||||||
|
|
||||||
let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
|
let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
|
||||||
.await
|
.await
|
||||||
.map_err(V3Error::Connection)?;
|
.map_err(V2Error::Connection)?;
|
||||||
|
|
||||||
// server is running on v3
|
// server is running on v2
|
||||||
// Clear the cache; useful if the webserver rebooted
|
// Clear the cache; useful if the webserver rebooted
|
||||||
sharded_client
|
sharded_client
|
||||||
.clear_cache(None)
|
.clear_cache(None)
|
||||||
.await
|
.await
|
||||||
.map_err(V3Error::Cache)?;
|
.map_err(V2Error::Cache)?;
|
||||||
// Get info from the shard
|
// Get info from the shard
|
||||||
let shard_info = sharded_client.info().await.map_err(V3Error::Info)?;
|
let shard_info = sharded_client.info().await.map_err(V2Error::Info)?;
|
||||||
|
|
||||||
// Warmup model
|
// Warmup model
|
||||||
tracing::info!("Warming up model");
|
tracing::info!("Warming up model");
|
||||||
@ -97,7 +95,7 @@ pub async fn connect_backend(
|
|||||||
max_batch_size,
|
max_batch_size,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.map_err(V3Error::Warmup)?,
|
.map_err(V2Error::Warmup)?,
|
||||||
)?;
|
)?;
|
||||||
tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}");
|
tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}");
|
||||||
|
|
||||||
@ -111,7 +109,7 @@ pub async fn connect_backend(
|
|||||||
speculate: shard_info.speculate as usize,
|
speculate: shard_info.speculate as usize,
|
||||||
};
|
};
|
||||||
|
|
||||||
let backend = BackendV3::new(
|
let backend = BackendV2::new(
|
||||||
sharded_client,
|
sharded_client,
|
||||||
waiting_served_ratio,
|
waiting_served_ratio,
|
||||||
max_batch_prefill_tokens,
|
max_batch_prefill_tokens,
|
||||||
@ -129,7 +127,7 @@ pub async fn connect_backend(
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Error)]
|
#[derive(Debug, Error)]
|
||||||
pub enum V3Error {
|
pub enum V2Error {
|
||||||
#[error("Unable to clear the Python model shards cache: {0}")]
|
#[error("Unable to clear the Python model shards cache: {0}")]
|
||||||
Cache(ClientError),
|
Cache(ClientError),
|
||||||
#[error("Unable to connect to the Python model shards: {0}")]
|
#[error("Unable to connect to the Python model shards: {0}")]
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
use clap::{Parser, Subcommand};
|
use clap::{Parser, Subcommand};
|
||||||
use text_generation_router::{server, usage_stats};
|
use text_generation_router::{server, usage_stats};
|
||||||
use text_generation_router_v3::{connect_backend, V3Error};
|
use text_generation_router_v2::{connect_backend, V2Error};
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
|
||||||
/// App Configuration
|
/// App Configuration
|
||||||
@ -204,7 +204,7 @@ enum RouterError {
|
|||||||
#[error("Argument validation error: {0}")]
|
#[error("Argument validation error: {0}")]
|
||||||
ArgumentValidation(String),
|
ArgumentValidation(String),
|
||||||
#[error("Backend failed: {0}")]
|
#[error("Backend failed: {0}")]
|
||||||
Backend(#[from] V3Error),
|
Backend(#[from] V2Error),
|
||||||
#[error("WebServer error: {0}")]
|
#[error("WebServer error: {0}")]
|
||||||
WebServer(#[from] server::WebServerError),
|
WebServer(#[from] server::WebServerError),
|
||||||
#[error("Tokio runtime failed to start: {0}")]
|
#[error("Tokio runtime failed to start: {0}")]
|
||||||
|
@ -1,20 +1,17 @@
|
|||||||
use crate::block_allocator::{BlockAllocation, BlockAllocator};
|
|
||||||
use crate::client;
|
|
||||||
use crate::client::{
|
use crate::client::{
|
||||||
Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
||||||
};
|
};
|
||||||
use nohash_hasher::{BuildNoHashHasher, IntMap};
|
use nohash_hasher::{BuildNoHashHasher, IntMap};
|
||||||
use std::cmp::{max, min};
|
use std::cmp::min;
|
||||||
use std::collections::VecDeque;
|
use std::collections::VecDeque;
|
||||||
use text_generation_router::infer::InferError;
|
use text_generation_router::infer::InferError;
|
||||||
use text_generation_router::infer::InferStreamResponse;
|
use text_generation_router::infer::InferStreamResponse;
|
||||||
use text_generation_router::validation::{
|
use text_generation_router::validation::{
|
||||||
Chunk, ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters,
|
ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters,
|
||||||
ValidStoppingParameters,
|
|
||||||
};
|
};
|
||||||
use tokio::sync::{mpsc, oneshot};
|
use tokio::sync::{mpsc, oneshot};
|
||||||
use tokio::time::Instant;
|
use tokio::time::Instant;
|
||||||
use tracing::{info_span, instrument, Instrument, Span};
|
use tracing::{info_span, instrument, Span};
|
||||||
|
|
||||||
/// Queue entry
|
/// Queue entry
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
@ -31,8 +28,6 @@ pub(crate) struct Entry {
|
|||||||
pub queue_time: Instant,
|
pub queue_time: Instant,
|
||||||
/// Instant when this entry was added to a batch
|
/// Instant when this entry was added to a batch
|
||||||
pub batch_time: Option<Instant>,
|
pub batch_time: Option<Instant>,
|
||||||
/// Block Allocation
|
|
||||||
pub block_allocation: Option<BlockAllocation>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Request Queue
|
/// Request Queue
|
||||||
@ -46,10 +41,8 @@ impl Queue {
|
|||||||
pub(crate) fn new(
|
pub(crate) fn new(
|
||||||
requires_padding: bool,
|
requires_padding: bool,
|
||||||
block_size: u32,
|
block_size: u32,
|
||||||
prefix_caching: bool,
|
|
||||||
window_size: Option<u32>,
|
window_size: Option<u32>,
|
||||||
speculate: u32,
|
speculate: u32,
|
||||||
max_batch_total_tokens: u32,
|
|
||||||
) -> Self {
|
) -> Self {
|
||||||
// Create channel
|
// Create channel
|
||||||
let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
|
let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
|
||||||
@ -58,17 +51,14 @@ impl Queue {
|
|||||||
tokio::spawn(queue_task(
|
tokio::spawn(queue_task(
|
||||||
requires_padding,
|
requires_padding,
|
||||||
block_size,
|
block_size,
|
||||||
prefix_caching,
|
|
||||||
window_size,
|
window_size,
|
||||||
speculate,
|
speculate,
|
||||||
max_batch_total_tokens,
|
|
||||||
queue_receiver,
|
queue_receiver,
|
||||||
));
|
));
|
||||||
|
|
||||||
Self { queue_sender }
|
Self { queue_sender }
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Append an entry to the queue
|
|
||||||
#[instrument(skip_all)]
|
#[instrument(skip_all)]
|
||||||
pub(crate) fn append(&self, entry: Entry) {
|
pub(crate) fn append(&self, entry: Entry) {
|
||||||
// Send append command to the background task managing the state
|
// Send append command to the background task managing the state
|
||||||
@ -111,20 +101,11 @@ impl Queue {
|
|||||||
async fn queue_task(
|
async fn queue_task(
|
||||||
requires_padding: bool,
|
requires_padding: bool,
|
||||||
block_size: u32,
|
block_size: u32,
|
||||||
prefix_caching: bool,
|
|
||||||
window_size: Option<u32>,
|
window_size: Option<u32>,
|
||||||
speculate: u32,
|
speculate: u32,
|
||||||
max_batch_total_tokens: u32,
|
|
||||||
mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
|
mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
|
||||||
) {
|
) {
|
||||||
let mut state = State::new(
|
let mut state = State::new(requires_padding, block_size, window_size, speculate);
|
||||||
requires_padding,
|
|
||||||
block_size,
|
|
||||||
prefix_caching,
|
|
||||||
window_size,
|
|
||||||
speculate,
|
|
||||||
max_batch_total_tokens,
|
|
||||||
);
|
|
||||||
|
|
||||||
while let Some(cmd) = receiver.recv().await {
|
while let Some(cmd) = receiver.recv().await {
|
||||||
match cmd {
|
match cmd {
|
||||||
@ -139,14 +120,12 @@ async fn queue_task(
|
|||||||
token_budget,
|
token_budget,
|
||||||
response_sender,
|
response_sender,
|
||||||
span,
|
span,
|
||||||
} => {
|
} => span.in_scope(|| {
|
||||||
let next_batch = state
|
let next_batch =
|
||||||
.next_batch(min_size, max_size, prefill_token_budget, token_budget)
|
state.next_batch(min_size, max_size, prefill_token_budget, token_budget);
|
||||||
.instrument(span)
|
|
||||||
.await;
|
|
||||||
response_sender.send(next_batch).unwrap();
|
response_sender.send(next_batch).unwrap();
|
||||||
metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64);
|
metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64);
|
||||||
}
|
}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -163,6 +142,9 @@ struct State {
|
|||||||
/// Id of the next batch
|
/// Id of the next batch
|
||||||
next_batch_id: u64,
|
next_batch_id: u64,
|
||||||
|
|
||||||
|
/// Whether the model is using padding
|
||||||
|
requires_padding: bool,
|
||||||
|
|
||||||
/// Paged Attention block size
|
/// Paged Attention block size
|
||||||
block_size: u32,
|
block_size: u32,
|
||||||
|
|
||||||
@ -171,37 +153,23 @@ struct State {
|
|||||||
|
|
||||||
/// Speculation amount
|
/// Speculation amount
|
||||||
speculate: u32,
|
speculate: u32,
|
||||||
|
|
||||||
/// Paged Attention Block Allocation
|
|
||||||
block_allocator: Option<BlockAllocator>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl State {
|
impl State {
|
||||||
fn new(
|
fn new(
|
||||||
requires_padding: bool,
|
requires_padding: bool,
|
||||||
block_size: u32,
|
block_size: u32,
|
||||||
prefix_caching: bool,
|
|
||||||
window_size: Option<u32>,
|
window_size: Option<u32>,
|
||||||
speculate: u32,
|
speculate: u32,
|
||||||
max_batch_total_tokens: u32,
|
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let block_allocator = (!requires_padding).then(|| {
|
|
||||||
BlockAllocator::new(
|
|
||||||
max_batch_total_tokens,
|
|
||||||
block_size,
|
|
||||||
prefix_caching,
|
|
||||||
window_size,
|
|
||||||
)
|
|
||||||
});
|
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
entries: VecDeque::with_capacity(128),
|
entries: VecDeque::with_capacity(128),
|
||||||
next_id: 0,
|
next_id: 0,
|
||||||
next_batch_id: 0,
|
next_batch_id: 0,
|
||||||
|
requires_padding,
|
||||||
block_size,
|
block_size,
|
||||||
window_size,
|
window_size,
|
||||||
speculate,
|
speculate,
|
||||||
block_allocator,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -217,7 +185,7 @@ impl State {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Get the next batch
|
// Get the next batch
|
||||||
async fn next_batch(
|
fn next_batch(
|
||||||
&mut self,
|
&mut self,
|
||||||
min_size: Option<usize>,
|
min_size: Option<usize>,
|
||||||
max_size: Option<usize>,
|
max_size: Option<usize>,
|
||||||
@ -252,14 +220,16 @@ impl State {
|
|||||||
let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
|
let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
|
||||||
next_batch_span.follows_from(Span::current());
|
next_batch_span.follows_from(Span::current());
|
||||||
|
|
||||||
let mut batch = Vec::with_capacity(self.entries.len());
|
let mut batch_requests = Vec::with_capacity(self.entries.len());
|
||||||
|
let mut batch_entries =
|
||||||
|
IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());
|
||||||
|
|
||||||
let mut max_input_length = 0;
|
let mut max_input_length = 0;
|
||||||
let mut prefill_tokens: u32 = 0;
|
let mut prefill_tokens: u32 = 0;
|
||||||
let mut decode_tokens: u32 = 0;
|
let mut decode_tokens: u32 = 0;
|
||||||
let mut max_blocks = 0;
|
|
||||||
|
|
||||||
// Pop entries starting from the front of the queue
|
// Pop entries starting from the front of the queue
|
||||||
'entry_loop: while let Some((id, entry)) = self.entries.pop_front() {
|
while let Some((id, mut entry)) = self.entries.pop_front() {
|
||||||
// Filter entries where the response receiver was dropped (== entries where the request
|
// Filter entries where the response receiver was dropped (== entries where the request
|
||||||
// was dropped by the client)
|
// was dropped by the client)
|
||||||
if entry.response_tx.is_closed() {
|
if entry.response_tx.is_closed() {
|
||||||
@ -268,113 +238,44 @@ impl State {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
let block_allocation = match &self.block_allocator {
|
if self.requires_padding {
|
||||||
None => {
|
// We pad to max input length in the Python shards
|
||||||
// We pad to max input length in the Python shards
|
// We need to take these padding tokens into the equation
|
||||||
// We need to take these padding tokens into the equation
|
max_input_length = max_input_length.max(entry.request.input_length);
|
||||||
max_input_length = max_input_length.max(entry.request.input_length);
|
prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length
|
||||||
prefill_tokens = (batch.len() + 1) as u32 * max_input_length;
|
} else {
|
||||||
|
// pad to block size
|
||||||
|
prefill_tokens += ((entry.request.input_length + self.block_size - 1)
|
||||||
|
/ self.block_size)
|
||||||
|
* self.block_size;
|
||||||
|
}
|
||||||
|
|
||||||
decode_tokens += entry.request.stopping_parameters.max_new_tokens;
|
if self.requires_padding {
|
||||||
let total_tokens = prefill_tokens + decode_tokens + self.speculate;
|
decode_tokens += entry.request.stopping_parameters.max_new_tokens;
|
||||||
|
} else {
|
||||||
|
let max_new_tokens = match self.window_size {
|
||||||
|
None => entry.request.stopping_parameters.max_new_tokens,
|
||||||
|
Some(window_size) => min(
|
||||||
|
window_size.saturating_sub(entry.request.input_length),
|
||||||
|
entry.request.stopping_parameters.max_new_tokens,
|
||||||
|
),
|
||||||
|
};
|
||||||
|
|
||||||
if prefill_tokens > prefill_token_budget || total_tokens > token_budget {
|
// pad to block size
|
||||||
// Entry is over budget
|
decode_tokens +=
|
||||||
// Add it back to the front
|
((max_new_tokens + self.block_size - 1) / self.block_size) * self.block_size;
|
||||||
tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
|
}
|
||||||
self.entries.push_front((id, entry));
|
|
||||||
break 'entry_loop;
|
|
||||||
}
|
|
||||||
None
|
|
||||||
}
|
|
||||||
Some(_block_allocator) => {
|
|
||||||
prefill_tokens += entry.request.input_length;
|
|
||||||
let max_new_tokens = match self.window_size {
|
|
||||||
None => entry.request.stopping_parameters.max_new_tokens,
|
|
||||||
Some(window_size) => min(
|
|
||||||
window_size.saturating_sub(entry.request.input_length),
|
|
||||||
entry.request.stopping_parameters.max_new_tokens,
|
|
||||||
),
|
|
||||||
};
|
|
||||||
decode_tokens += max_new_tokens;
|
|
||||||
|
|
||||||
if prefill_tokens > prefill_token_budget
|
if prefill_tokens > prefill_token_budget
|
||||||
|| (prefill_tokens + decode_tokens + self.speculate) > token_budget
|
|| (prefill_tokens + decode_tokens + self.speculate) > token_budget
|
||||||
{
|
{
|
||||||
// Entry is over budget
|
// Entry is over budget
|
||||||
// Add it back to the front
|
// Add it back to the front
|
||||||
tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
|
tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
|
||||||
self.entries.push_front((id, entry));
|
self.entries.push_front((id, entry));
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
let tokens = entry.request.input_length
|
|
||||||
+ entry.request.stopping_parameters.max_new_tokens
|
|
||||||
+ self.speculate
|
|
||||||
- 1;
|
|
||||||
|
|
||||||
// If users wants the prefill logprobs, we cannot reuse the cache.
|
|
||||||
// So no input_ids for the radix tree.
|
|
||||||
let input_ids = if entry.request.decoder_input_details {
|
|
||||||
None
|
|
||||||
} else {
|
|
||||||
entry.request.input_ids.clone()
|
|
||||||
};
|
|
||||||
|
|
||||||
Some((tokens, input_ids))
|
|
||||||
}
|
|
||||||
};
|
|
||||||
batch.push((id, entry, block_allocation));
|
|
||||||
if Some(batch.len()) == max_size {
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// Empty batch
|
|
||||||
if batch.is_empty() {
|
|
||||||
tracing::debug!("Filterered out all entries");
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
|
|
||||||
// XXX We haven't allocated yet, so we're allowed to ditch the results.
|
|
||||||
// Check if our batch is big enough
|
|
||||||
if let Some(min_size) = min_size {
|
|
||||||
// Batch is too small
|
|
||||||
if batch.len() < min_size {
|
|
||||||
// Add back entries to the queue in the correct order
|
|
||||||
for (id, entry, _) in batch.into_iter().rev() {
|
|
||||||
self.entries.push_front((id, entry));
|
|
||||||
}
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut batch_requests = Vec::with_capacity(self.entries.len());
|
|
||||||
let mut batch_entries =
|
|
||||||
IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());
|
|
||||||
|
|
||||||
for (id, mut entry, block_allocation) in batch {
|
|
||||||
let block_allocation = if let (Some((tokens, input_ids)), Some(block_allocator)) =
|
|
||||||
(block_allocation, &self.block_allocator)
|
|
||||||
{
|
|
||||||
tracing::debug!("Allocating {tokens} with {input_ids:?}");
|
|
||||||
match block_allocator.allocate(tokens, input_ids).await {
|
|
||||||
None => {
|
|
||||||
// Entry is over budget
|
|
||||||
// Add it back to the front
|
|
||||||
tracing::debug!("Over budget: not enough free blocks");
|
|
||||||
self.entries.push_front((id, entry));
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
Some(block_allocation) => {
|
|
||||||
tracing::debug!("Allocation: {block_allocation:?}");
|
|
||||||
max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
|
|
||||||
Some(block_allocation)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
tracing::debug!("Accepting entry");
|
tracing::debug!("Accepting entry");
|
||||||
// Create a new span to link the batch back to this entry
|
// Create a new span to link the batch back to this entry
|
||||||
let entry_batch_span = info_span!(parent: &entry.span, "infer");
|
let entry_batch_span = info_span!(parent: &entry.span, "infer");
|
||||||
@ -384,40 +285,11 @@ impl State {
|
|||||||
// Update entry
|
// Update entry
|
||||||
entry.temp_span = Some(entry_batch_span);
|
entry.temp_span = Some(entry_batch_span);
|
||||||
|
|
||||||
let (blocks, slots, prefix_len) = match &block_allocation {
|
|
||||||
None => (Vec::new(), Vec::new(), 0),
|
|
||||||
Some(block_allocation) => (
|
|
||||||
block_allocation.blocks.clone(),
|
|
||||||
block_allocation.slots.clone(),
|
|
||||||
block_allocation.prefix_len,
|
|
||||||
),
|
|
||||||
};
|
|
||||||
|
|
||||||
entry.block_allocation = block_allocation;
|
|
||||||
|
|
||||||
batch_requests.push(Request {
|
batch_requests.push(Request {
|
||||||
id,
|
id,
|
||||||
prefill_logprobs: entry.request.decoder_input_details,
|
prefill_logprobs: entry.request.decoder_input_details,
|
||||||
input_chunks: Some(client::Input {
|
|
||||||
chunks: entry
|
|
||||||
.request
|
|
||||||
.inputs
|
|
||||||
.clone()
|
|
||||||
.into_iter()
|
|
||||||
.map(|c| client::InputChunk {
|
|
||||||
chunk: Some(match c {
|
|
||||||
Chunk::Text(text) => client::Chunk::Text(text),
|
|
||||||
Chunk::Image(image) => client::Chunk::Image(client::Image {
|
|
||||||
data: image.data,
|
|
||||||
mimetype: image.mimetype,
|
|
||||||
}),
|
|
||||||
}),
|
|
||||||
})
|
|
||||||
.collect(),
|
|
||||||
}),
|
|
||||||
inputs: entry.request.inputs.chunks_to_string(),
|
inputs: entry.request.inputs.chunks_to_string(),
|
||||||
truncate: entry.request.truncate,
|
truncate: entry.request.truncate,
|
||||||
add_special_tokens: entry.request.add_special_tokens,
|
|
||||||
parameters: Some(NextTokenChooserParameters::from(
|
parameters: Some(NextTokenChooserParameters::from(
|
||||||
entry.request.parameters.clone(),
|
entry.request.parameters.clone(),
|
||||||
)),
|
)),
|
||||||
@ -425,23 +297,39 @@ impl State {
|
|||||||
entry.request.stopping_parameters.clone(),
|
entry.request.stopping_parameters.clone(),
|
||||||
)),
|
)),
|
||||||
top_n_tokens: entry.request.top_n_tokens,
|
top_n_tokens: entry.request.top_n_tokens,
|
||||||
blocks,
|
|
||||||
slots,
|
|
||||||
prefix_len,
|
|
||||||
adapter_id: entry.request.adapter_id.clone(),
|
|
||||||
});
|
});
|
||||||
// Set batch_time
|
// Set batch_time
|
||||||
entry.batch_time = Some(Instant::now());
|
entry.batch_time = Some(Instant::now());
|
||||||
// Insert in batch_entries IntMap
|
// Insert in batch_entries IntMap
|
||||||
batch_entries.insert(id, entry);
|
batch_entries.insert(id, entry);
|
||||||
|
|
||||||
|
// Check if max_size
|
||||||
|
if Some(batch_requests.len()) == max_size {
|
||||||
|
break;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Empty batch
|
// Empty batch
|
||||||
if batch_requests.is_empty() {
|
if batch_requests.is_empty() {
|
||||||
tracing::debug!("Filterered out all entries");
|
tracing::debug!("Filtered out all entries");
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check if our batch is big enough
|
||||||
|
if let Some(min_size) = min_size {
|
||||||
|
// Batch is too small
|
||||||
|
if batch_requests.len() < min_size {
|
||||||
|
// Add back entries to the queue in the correct order
|
||||||
|
for r in batch_requests.into_iter().rev() {
|
||||||
|
let id = r.id;
|
||||||
|
let entry = batch_entries.remove(&id).unwrap();
|
||||||
|
self.entries.push_front((id, entry));
|
||||||
|
}
|
||||||
|
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Final batch size
|
// Final batch size
|
||||||
let size = batch_requests.len() as u32;
|
let size = batch_requests.len() as u32;
|
||||||
next_batch_span.record("batch_size", size);
|
next_batch_span.record("batch_size", size);
|
||||||
@ -451,7 +339,6 @@ impl State {
|
|||||||
requests: batch_requests,
|
requests: batch_requests,
|
||||||
size,
|
size,
|
||||||
max_tokens: (prefill_tokens + decode_tokens),
|
max_tokens: (prefill_tokens + decode_tokens),
|
||||||
max_blocks,
|
|
||||||
};
|
};
|
||||||
// Increment batch id
|
// Increment batch id
|
||||||
self.next_batch_id += 1;
|
self.next_batch_id += 1;
|
||||||
@ -516,9 +403,8 @@ impl From<ValidStoppingParameters> for StoppingCriteriaParameters {
|
|||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use std::sync::Arc;
|
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use std::sync::Arc;
|
||||||
use tracing::info_span;
|
use tracing::info_span;
|
||||||
|
|
||||||
fn default_entry() -> (
|
fn default_entry() -> (
|
||||||
@ -560,14 +446,13 @@ mod tests {
|
|||||||
temp_span: None,
|
temp_span: None,
|
||||||
queue_time: Instant::now(),
|
queue_time: Instant::now(),
|
||||||
batch_time: None,
|
batch_time: None,
|
||||||
block_allocation: None,
|
|
||||||
};
|
};
|
||||||
(entry, receiver_tx)
|
(entry, receiver_tx)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[test]
|
||||||
async fn test_append() {
|
fn test_append() {
|
||||||
let mut state = State::new(false, 1, false, None, 0, 16);
|
let mut state = State::new(false, 1, None, 0);
|
||||||
let (entry, _guard) = default_entry();
|
let (entry, _guard) = default_entry();
|
||||||
|
|
||||||
assert_eq!(state.next_id, 0);
|
assert_eq!(state.next_id, 0);
|
||||||
@ -581,23 +466,23 @@ mod tests {
|
|||||||
assert_eq!(id, 0);
|
assert_eq!(id, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[test]
|
||||||
async fn test_next_batch_empty() {
|
fn test_next_batch_empty() {
|
||||||
let mut state = State::new(false, 1, false, None, 0, 16);
|
let mut state = State::new(false, 1, None, 0);
|
||||||
|
|
||||||
assert!(state.next_batch(None, None, 1, 1).await.is_none());
|
assert!(state.next_batch(None, None, 1, 1).is_none());
|
||||||
assert!(state.next_batch(Some(1), None, 1, 1).await.is_none());
|
assert!(state.next_batch(Some(1), None, 1, 1).is_none());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[test]
|
||||||
async fn test_next_batch_min_size() {
|
fn test_next_batch_min_size() {
|
||||||
let mut state = State::new(false, 1, false, None, 0, 16);
|
let mut state = State::new(false, 1, None, 0);
|
||||||
let (entry1, _guard1) = default_entry();
|
let (entry1, _guard1) = default_entry();
|
||||||
let (entry2, _guard2) = default_entry();
|
let (entry2, _guard2) = default_entry();
|
||||||
state.append(entry1);
|
state.append(entry1);
|
||||||
state.append(entry2);
|
state.append(entry2);
|
||||||
|
|
||||||
let (entries, batch, _) = state.next_batch(None, None, 2, 2).await.unwrap();
|
let (entries, batch, _) = state.next_batch(None, None, 2, 2).unwrap();
|
||||||
assert_eq!(entries.len(), 2);
|
assert_eq!(entries.len(), 2);
|
||||||
assert!(entries.contains_key(&0));
|
assert!(entries.contains_key(&0));
|
||||||
assert!(entries.contains_key(&1));
|
assert!(entries.contains_key(&1));
|
||||||
@ -613,7 +498,7 @@ mod tests {
|
|||||||
let (entry3, _guard3) = default_entry();
|
let (entry3, _guard3) = default_entry();
|
||||||
state.append(entry3);
|
state.append(entry3);
|
||||||
|
|
||||||
assert!(state.next_batch(Some(2), None, 2, 2).await.is_none());
|
assert!(state.next_batch(Some(2), None, 2, 2).is_none());
|
||||||
|
|
||||||
assert_eq!(state.next_id, 3);
|
assert_eq!(state.next_id, 3);
|
||||||
assert_eq!(state.entries.len(), 1);
|
assert_eq!(state.entries.len(), 1);
|
||||||
@ -621,15 +506,15 @@ mod tests {
|
|||||||
assert_eq!(id, 2);
|
assert_eq!(id, 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[test]
|
||||||
async fn test_next_batch_max_size() {
|
fn test_next_batch_max_size() {
|
||||||
let mut state = State::new(false, 1, false, None, 0, 16);
|
let mut state = State::new(false, 1, None, 0);
|
||||||
let (entry1, _guard1) = default_entry();
|
let (entry1, _guard1) = default_entry();
|
||||||
let (entry2, _guard2) = default_entry();
|
let (entry2, _guard2) = default_entry();
|
||||||
state.append(entry1);
|
state.append(entry1);
|
||||||
state.append(entry2);
|
state.append(entry2);
|
||||||
|
|
||||||
let (entries, batch, _) = state.next_batch(None, Some(1), 2, 2).await.unwrap();
|
let (entries, batch, _) = state.next_batch(None, Some(1), 2, 2).unwrap();
|
||||||
assert_eq!(entries.len(), 1);
|
assert_eq!(entries.len(), 1);
|
||||||
assert!(entries.contains_key(&0));
|
assert!(entries.contains_key(&0));
|
||||||
assert!(entries.get(&0).unwrap().batch_time.is_some());
|
assert!(entries.get(&0).unwrap().batch_time.is_some());
|
||||||
@ -641,15 +526,15 @@ mod tests {
|
|||||||
assert_eq!(state.next_batch_id, 1);
|
assert_eq!(state.next_batch_id, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[test]
|
||||||
async fn test_next_batch_token_budget() {
|
fn test_next_batch_token_budget() {
|
||||||
let mut state = State::new(false, 1, false, None, 0, 2);
|
let mut state = State::new(false, 1, None, 0);
|
||||||
let (entry1, _guard1) = default_entry();
|
let (entry1, _guard1) = default_entry();
|
||||||
let (entry2, _guard2) = default_entry();
|
let (entry2, _guard2) = default_entry();
|
||||||
state.append(entry1);
|
state.append(entry1);
|
||||||
state.append(entry2);
|
state.append(entry2);
|
||||||
|
|
||||||
let (entries, batch, _) = state.next_batch(None, None, 1, 1).await.unwrap();
|
let (entries, batch, _) = state.next_batch(None, None, 1, 1).unwrap();
|
||||||
assert_eq!(entries.len(), 1);
|
assert_eq!(entries.len(), 1);
|
||||||
assert!(entries.contains_key(&0));
|
assert!(entries.contains_key(&0));
|
||||||
assert_eq!(batch.id, 0);
|
assert_eq!(batch.id, 0);
|
||||||
@ -662,7 +547,7 @@ mod tests {
|
|||||||
let (entry3, _guard3) = default_entry();
|
let (entry3, _guard3) = default_entry();
|
||||||
state.append(entry3);
|
state.append(entry3);
|
||||||
|
|
||||||
let (entries, batch, _) = state.next_batch(None, None, 3, 3).await.unwrap();
|
let (entries, batch, _) = state.next_batch(None, None, 3, 3).unwrap();
|
||||||
assert_eq!(entries.len(), 2);
|
assert_eq!(entries.len(), 2);
|
||||||
assert!(entries.contains_key(&1));
|
assert!(entries.contains_key(&1));
|
||||||
assert!(entries.contains_key(&2));
|
assert!(entries.contains_key(&2));
|
||||||
@ -676,14 +561,14 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_queue_append() {
|
async fn test_queue_append() {
|
||||||
let queue = Queue::new(false, 1, false, None, 0, 16);
|
let queue = Queue::new(false, 1, None, 0);
|
||||||
let (entry, _guard) = default_entry();
|
let (entry, _guard) = default_entry();
|
||||||
queue.append(entry);
|
queue.append(entry);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_queue_next_batch_empty() {
|
async fn test_queue_next_batch_empty() {
|
||||||
let queue = Queue::new(false, 1, false, None, 0, 16);
|
let queue = Queue::new(false, 1, None, 0);
|
||||||
|
|
||||||
assert!(queue.next_batch(None, None, 1, 1).await.is_none());
|
assert!(queue.next_batch(None, None, 1, 1).await.is_none());
|
||||||
assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
|
assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
|
||||||
@ -691,7 +576,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_queue_next_batch_min_size() {
|
async fn test_queue_next_batch_min_size() {
|
||||||
let queue = Queue::new(false, 1, false, None, 0, 16);
|
let queue = Queue::new(false, 1, None, 0);
|
||||||
let (entry1, _guard1) = default_entry();
|
let (entry1, _guard1) = default_entry();
|
||||||
let (entry2, _guard2) = default_entry();
|
let (entry2, _guard2) = default_entry();
|
||||||
queue.append(entry1);
|
queue.append(entry1);
|
||||||
@ -724,7 +609,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_queue_next_batch_max_size() {
|
async fn test_queue_next_batch_max_size() {
|
||||||
let queue = Queue::new(false, 1, false, None, 0, 16);
|
let queue = Queue::new(false, 1, None, 0);
|
||||||
let (entry1, _guard1) = default_entry();
|
let (entry1, _guard1) = default_entry();
|
||||||
let (entry2, _guard2) = default_entry();
|
let (entry2, _guard2) = default_entry();
|
||||||
queue.append(entry1);
|
queue.append(entry1);
|
||||||
@ -740,7 +625,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_queue_next_batch_token_budget() {
|
async fn test_queue_next_batch_token_budget() {
|
||||||
let queue = Queue::new(false, 1, false, None, 0, 16);
|
let queue = Queue::new(false, 1, None, 0);
|
||||||
let (entry1, _guard1) = default_entry();
|
let (entry1, _guard1) = default_entry();
|
||||||
let (entry2, _guard2) = default_entry();
|
let (entry2, _guard2) = default_entry();
|
||||||
queue.append(entry1);
|
queue.append(entry1);
|
||||||
@ -765,7 +650,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_queue_next_batch_token_speculate() {
|
async fn test_queue_next_batch_token_speculate() {
|
||||||
let queue = Queue::new(false, 1, false, None, 2, 16);
|
let queue = Queue::new(false, 1, None, 2);
|
||||||
let (entry1, _guard1) = default_entry();
|
let (entry1, _guard1) = default_entry();
|
||||||
let (entry2, _guard2) = default_entry();
|
let (entry2, _guard2) = default_entry();
|
||||||
queue.append(entry1);
|
queue.append(entry1);
|
||||||
@ -784,7 +669,7 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_queue_next_batch_dropped_receiver() {
|
async fn test_queue_next_batch_dropped_receiver() {
|
||||||
let queue = Queue::new(false, 1, false, None, 0, 16);
|
let queue = Queue::new(false, 1, None, 0);
|
||||||
let (entry, _) = default_entry();
|
let (entry, _) = default_entry();
|
||||||
queue.append(entry);
|
queue.append(entry);
|
||||||
|
|
||||||
|
@ -1,876 +0,0 @@
|
|||||||
use crate::block_allocator::{Allocator, BlockAllocation};
|
|
||||||
use slotmap::{DefaultKey, SlotMap};
|
|
||||||
use std::hash::{Hash, Hasher};
|
|
||||||
use std::{
|
|
||||||
collections::{BTreeSet, HashMap},
|
|
||||||
sync::Arc,
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Hashes a non-empty run of token ids into a 64-bit key.
///
/// A slice holding exactly one token maps to that token's value, widened to
/// `u64`. Longer slices are fed through the standard library's default hasher.
///
/// # Panics
///
/// Panics if `slice` is empty.
fn hash(slice: &[u32]) -> u64 {
    assert!(!slice.is_empty());
    match slice {
        // Single token: use the id itself as the key.
        [token] => *token as u64,
        tokens => {
            let mut hasher = std::hash::DefaultHasher::new();
            tokens.hash(&mut hasher);
            hasher.finish()
        }
    }
}
|
|
||||||
|
|
||||||
/// Block allocator that pairs a free list with a radix trie of cached
/// prefill blocks.
pub struct RadixAllocator {
    /// Counter bumped on every successful allocation; the new value is the
    /// key under which the allocation is recorded in `allocations`.
    allocation_id: u64,

    /// Bookkeeping for live allocations, keyed by allocation id. Entries are
    /// inserted by `allocate` and removed by `free`.
    allocations: HashMap<u64, RadixAllocation>,

    /// Trie consulted for an already-cached prefix on allocation and asked to
    /// evict blocks when the free list runs short.
    cache_blocks: RadixTrie,

    /// Blocks that are immediately available for allocation.
    free_blocks: Vec<u32>,

    #[allow(dead_code)]
    // This isn't used because the prefix needs to match without the windowing
    // mechanism. At worst this over-allocates; it is not necessarily wrong.
    window_size: Option<u32>,

    /// Number of token slots per block.
    block_size: u32,
}
|
|
||||||
|
|
||||||
impl RadixAllocator {
|
|
||||||
pub fn new(block_size: u32, n_blocks: u32, window_size: Option<u32>) -> Self {
|
|
||||||
RadixAllocator {
|
|
||||||
allocation_id: 0,
|
|
||||||
allocations: HashMap::new(),
|
|
||||||
cache_blocks: RadixTrie::new(block_size as usize),
|
|
||||||
|
|
||||||
// Block 0 is reserved for health checks.
|
|
||||||
free_blocks: (1..n_blocks).collect(),
|
|
||||||
window_size,
|
|
||||||
block_size,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn alloc_or_reclaim(&mut self, n_blocks_needed: usize) -> Option<Vec<u32>> {
|
|
||||||
if self.free_blocks.len() < n_blocks_needed {
|
|
||||||
// This is a bit annoying, we first extend the free list and then
|
|
||||||
// split it off again below. This is because we need to put it on
|
|
||||||
// the free list if we cannot allocate enough blocks. This is only
|
|
||||||
// temporary, the trie needs to be able to report whether it can
|
|
||||||
// allocate the requested amount. Just not implemented yet.
|
|
||||||
tracing::debug!(
|
|
||||||
"Free blocks {} need {n_blocks_needed}",
|
|
||||||
self.free_blocks.len()
|
|
||||||
);
|
|
||||||
self.free_blocks.extend(
|
|
||||||
self.cache_blocks
|
|
||||||
.evict(n_blocks_needed - self.free_blocks.len()),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if self.free_blocks.len() >= n_blocks_needed {
|
|
||||||
Some(
|
|
||||||
self.free_blocks
|
|
||||||
.split_off(self.free_blocks.len() - n_blocks_needed),
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Allocator trait
|
|
||||||
impl Allocator for RadixAllocator {
|
|
||||||
fn allocate(
|
|
||||||
&mut self,
|
|
||||||
tokens: u32,
|
|
||||||
prefill_tokens: Option<Arc<Vec<u32>>>,
|
|
||||||
) -> Option<BlockAllocation> {
|
|
||||||
let mut blocks = vec![];
|
|
||||||
let prefix_node = if let Some(prefill_tokens) = prefill_tokens.as_ref() {
|
|
||||||
let node_id = self
|
|
||||||
.cache_blocks
|
|
||||||
.find(prefill_tokens.as_slice(), &mut blocks);
|
|
||||||
node_id
|
|
||||||
} else {
|
|
||||||
self.cache_blocks.root_id()
|
|
||||||
};
|
|
||||||
|
|
||||||
// Even if this allocation fails below, we need to increase he
|
|
||||||
// refcount to ensure that the prefix that was found is not evicted.
|
|
||||||
self.cache_blocks
|
|
||||||
.incref(prefix_node)
|
|
||||||
.expect("Failed to increment refcount");
|
|
||||||
|
|
||||||
let prefix_len = blocks.len() * self.block_size as usize;
|
|
||||||
let suffix_len = tokens - prefix_len as u32;
|
|
||||||
|
|
||||||
let suffix_blocks = (suffix_len + self.block_size - 1) / self.block_size;
|
|
||||||
|
|
||||||
tracing::info!("Prefix {prefix_len} - Suffix {suffix_len}");
|
|
||||||
|
|
||||||
match self.alloc_or_reclaim(suffix_blocks as usize) {
|
|
||||||
Some(suffix_blocks) => blocks.extend(suffix_blocks),
|
|
||||||
None => {
|
|
||||||
tracing::debug!("Cannot allocate {:?}", self.cache_blocks);
|
|
||||||
tracing::debug!("Found {prefix_len} prefix tokens need {suffix_blocks} suffix blocks for {tokens} tokens");
|
|
||||||
tracing::debug!("Block size {}", self.block_size);
|
|
||||||
self.cache_blocks
|
|
||||||
.decref(prefix_node)
|
|
||||||
.expect("Failed to decrement refcount");
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 1:1 mapping of blocks and slots.
|
|
||||||
let slots = if self.block_size == 1 {
|
|
||||||
blocks.clone()
|
|
||||||
} else {
|
|
||||||
let mut slots = Vec::with_capacity(blocks.len() * self.block_size as usize);
|
|
||||||
'slots: for block_id in &blocks {
|
|
||||||
for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) {
|
|
||||||
slots.push(s);
|
|
||||||
if slots.len() as u32 == tokens {
|
|
||||||
break 'slots;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
slots
|
|
||||||
};
|
|
||||||
|
|
||||||
let allocation = RadixAllocation {
|
|
||||||
prefix_node,
|
|
||||||
cached_prefix_len: prefix_len,
|
|
||||||
prefill_tokens: prefill_tokens.clone(),
|
|
||||||
};
|
|
||||||
|
|
||||||
self.allocation_id += 1;
|
|
||||||
self.allocations.insert(self.allocation_id, allocation);
|
|
||||||
|
|
||||||
Some(BlockAllocation {
|
|
||||||
allocation_id: self.allocation_id,
|
|
||||||
block_allocator: None,
|
|
||||||
blocks,
|
|
||||||
slots,
|
|
||||||
prefix_len: prefix_len as u32,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn free(&mut self, blocks: Vec<u32>, allocation_id: u64) {
|
|
||||||
let allocation = match self.allocations.remove(&allocation_id) {
|
|
||||||
Some(allocation) => allocation,
|
|
||||||
None => unreachable!("Tried to free an unknown allocation."),
|
|
||||||
};
|
|
||||||
|
|
||||||
self.cache_blocks
|
|
||||||
.decref(allocation.prefix_node)
|
|
||||||
.expect("Failed to decrement refcount");
|
|
||||||
|
|
||||||
if let Some(prefill_tokens) = allocation.prefill_tokens {
|
|
||||||
let prefill_tokens = prefill_tokens.as_slice();
|
|
||||||
|
|
||||||
// If there are prefill tokens that did not come from the cache,
|
|
||||||
// add them to the cache.
|
|
||||||
if prefill_tokens.len() > allocation.cached_prefix_len {
|
|
||||||
let aligned =
|
|
||||||
(prefill_tokens.len() / self.block_size as usize) * self.block_size as usize;
|
|
||||||
if aligned > 0 {
|
|
||||||
let prefix_len = self
|
|
||||||
.cache_blocks
|
|
||||||
.insert(
|
|
||||||
&prefill_tokens[..aligned],
|
|
||||||
&blocks[..aligned / self.block_size as usize],
|
|
||||||
)
|
|
||||||
// Unwrap, failing is a programming error.
|
|
||||||
.expect("Failed to store prefill tokens");
|
|
||||||
// We can have a prefill with the following structure:
|
|
||||||
//
|
|
||||||
// |---| From the prefix cache.
|
|
||||||
// A B C D E F G
|
|
||||||
//|--------| Found in the trie during insertion.
|
|
||||||
//
|
|
||||||
// This means that while processing this request there was a
|
|
||||||
// partially overlapping request that had A..=E in its
|
|
||||||
// prefill. In this case we need to free the blocks D E.
|
|
||||||
if prefix_len > allocation.cached_prefix_len {
|
|
||||||
self.free_blocks.extend(
|
|
||||||
&blocks[allocation.cached_prefix_len / self.block_size as usize
|
|
||||||
..prefix_len / self.block_size as usize],
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Free non-prefill blocks.
|
|
||||||
self.free_blocks
|
|
||||||
.extend(&blocks[prefill_tokens.len() / self.block_size as usize..]);
|
|
||||||
} else {
|
|
||||||
self.free_blocks.extend(blocks);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Bookkeeping kept for every live allocation until it is freed.
struct RadixAllocation {
    /// Trie node whose reference count is decremented when the allocation
    /// is freed.
    prefix_node: NodeId,

    /// Number of prefill tokens that were served from the prefix cache.
    cached_prefix_len: usize,

    /// Prefill tokens of the request, if any. On free, the block-aligned
    /// part is inserted into the prefix cache.
    prefill_tokens: Option<Arc<Vec<u32>>>,
}
|
|
||||||
|
|
||||||
// Radix trie that is heavily inspired by radix attention from sglang.
|
|
||||||
//
|
|
||||||
// The trie is optimized for prefix caching:
|
|
||||||
//
|
|
||||||
// - A normal radix trie stores discrete values. In this radix trie,
|
|
||||||
// inserting *abc* with value *xyz* will also enable lookup for
|
|
||||||
// *a* (*x*) and *ab* (*xy*).
|
|
||||||
// - As a result, every value is required to have the same length as
|
|
||||||
// the key.
|
|
||||||
// - We store additional information in each node, such as last access
|
|
||||||
// time and a reference count.
|
|
||||||
|
|
||||||
/// Errors reported by the reference-counting operations of the trie.
#[derive(Debug)]
pub enum TrieError {
    /// The node identifier is not present in the trie.
    InvalidNodeId,
    /// A reference count of zero was asked to be decremented.
    RefCountUnderflow,
}
|
|
||||||
|
|
||||||
/// Identifier of a trie node; used as the key of the trie's node slot map.
pub type NodeId = DefaultKey;
|
|
||||||
|
|
||||||
/// Radix trie mapping token prefixes to the cache blocks that hold them.
#[derive(Debug)]
pub struct RadixTrie {
    /// Identifier of the root node.
    root: DefaultKey,

    /// Leaf node identifiers ordered by increasing recency.
    leaves: BTreeSet<(u64, NodeId)>,

    /// All trie nodes.
    nodes: SlotMap<NodeId, TrieNode>,

    /// Time as a monotonically increasing counter to avoid the system
    /// call that a real time lookup would require.
    time: u64,

    /// All blocks need to be aligned with this
    block_size: usize,
}
|
|
||||||
|
|
||||||
impl RadixTrie {
|
|
||||||
/// Construct a new radix trie.
|
|
||||||
pub fn new(block_size: usize) -> Self {
|
|
||||||
let root = TrieNode::new(vec![], vec![], 0, None);
|
|
||||||
let mut nodes = SlotMap::new();
|
|
||||||
let root = nodes.insert(root);
|
|
||||||
RadixTrie {
|
|
||||||
leaves: BTreeSet::new(),
|
|
||||||
nodes,
|
|
||||||
root,
|
|
||||||
time: 0,
|
|
||||||
block_size,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Find the prefix of the given tokens.
|
|
||||||
///
|
|
||||||
/// The blocks corresponding to the part of the prefix that could be found
|
|
||||||
/// are written to `blocks`. The number of blocks is in `0..=tokens.len()`.
|
|
||||||
/// Returns the identifier of the trie node that contains the longest
|
|
||||||
/// prefix. The node identifier can be used by callers to e.g. increase its
|
|
||||||
/// reference count.
|
|
||||||
///
|
|
||||||
/// Using this method will update the access time of the traversed nodes.
|
|
||||||
pub fn find(&mut self, key: &[u32], blocks: &mut Vec<u32>) -> NodeId {
|
|
||||||
self.time += 1;
|
|
||||||
self.find_(self.root, key, blocks)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Find worker.
|
|
||||||
fn find_(&mut self, mut node_id: NodeId, key: &[u32], blocks: &mut Vec<u32>) -> NodeId {
|
|
||||||
let node = &self.nodes[node_id];
|
|
||||||
|
|
||||||
if key.len() >= self.block_size {
|
|
||||||
let node_key = hash(&key[..self.block_size]);
|
|
||||||
if let Some(&child_id) = node.children.get(&node_key) {
|
|
||||||
self.update_access_time(child_id);
|
|
||||||
let child = self.nodes.get(child_id).expect("Invalid child identifier");
|
|
||||||
let shared_prefix_len = shared_prefix(&child.key, key, self.block_size);
|
|
||||||
assert_eq!(shared_prefix_len % self.block_size, 0);
|
|
||||||
blocks.extend(&child.blocks[..shared_prefix_len / self.block_size]);
|
|
||||||
|
|
||||||
let key = &key[shared_prefix_len..];
|
|
||||||
if !key.is_empty() {
|
|
||||||
node_id = self.find_(child_id, key, blocks);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
node_id
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Decrease the reference count of a node.
|
|
||||||
pub fn decref(&mut self, node_id: NodeId) -> Result<(), TrieError> {
|
|
||||||
// We don't care about refcounting for root, since it will never
|
|
||||||
// be evicted.
|
|
||||||
if node_id == self.root {
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
|
|
||||||
let node = self
|
|
||||||
.nodes
|
|
||||||
.get_mut(node_id)
|
|
||||||
.ok_or(TrieError::InvalidNodeId)?;
|
|
||||||
if node.ref_count == 0 {
|
|
||||||
return Err(TrieError::RefCountUnderflow);
|
|
||||||
}
|
|
||||||
|
|
||||||
node.ref_count -= 1;
|
|
||||||
if node.ref_count == 0 {
|
|
||||||
assert!(
|
|
||||||
node.children.is_empty(),
|
|
||||||
"Nodes with children must have refcount > 0"
|
|
||||||
);
|
|
||||||
|
|
||||||
self.leaves.insert((node.last_accessed, node_id));
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Increase the reference count of a node.
|
|
||||||
pub fn incref(&mut self, node_id: NodeId) -> Result<(), TrieError> {
|
|
||||||
if node_id == self.root {
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
|
|
||||||
let node = self
|
|
||||||
.nodes
|
|
||||||
.get_mut(node_id)
|
|
||||||
.ok_or(TrieError::InvalidNodeId)?;
|
|
||||||
if node.ref_count == 0 {
|
|
||||||
self.leaves.remove(&(node.last_accessed, node_id));
|
|
||||||
}
|
|
||||||
node.ref_count += 1;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Evict `n_blocks` from the trie.
///
/// Leaves are visited starting with the least recently accessed one. A
/// leaf is removed entirely when all of its blocks are needed; otherwise
/// only its trailing blocks are taken and the shortened leaf is kept.
///
/// Returns the evicted blocks. When the length is less than `n_blocks`,
/// not enough blocks could be evicted.
pub fn evict(&mut self, n_blocks: usize) -> Vec<u32> {
    // NOTE: we don't return Result here. If any of the unwrapping fails,
    // it's a programming error in the trie implementation, not a user
    // error caused by e.g. an invalid argument.

    // TODO: add some bookkeeping in the future to check whether we can
    // evict n_blocks and return `None` if we can't. We are now needlessly
    // evicting prefixes from the cache in such a case.
    let mut evicted = Vec::new();
    tracing::debug!("Evicting in search of {n_blocks}");

    // `pop_first` yields the leaf with the smallest access time.
    while let Some((last_access, node_id)) = self.leaves.pop_first() {
        let blocks_needed = n_blocks.saturating_sub(evicted.len());
        tracing::debug!("Evicting node {node_id:?} ");

        let node = self.nodes.get(node_id).expect("Leave does not exist");
        assert_eq!(
            node.ref_count, 0,
            "Leaf must have refcount of 0, got {}",
            node.ref_count
        );

        if blocks_needed >= node.blocks.len() {
            // We need to evict the whole node if we need more blocks than it has.
            // Removing it may turn its parent into a leaf, which the loop
            // can then pick up.
            let node = self.remove_node(node_id);
            evicted.extend(node.blocks);

            if evicted.len() >= n_blocks {
                break;
            }
        } else {
            // The node has more blocks than needed, so we'll just remove
            // the required number of blocks and leave the remaining blocks
            // untouched.
            let node = self.nodes.get_mut(node_id).expect("Leave does not exist");

            // Number of blocks (and matching tokens) that stay in the node.
            let truncate_blocks = node.blocks.len() - blocks_needed;
            let truncate_tokens = truncate_blocks * self.block_size;
            node.key.truncate(truncate_tokens);
            evicted.extend(node.blocks.split_off(truncate_blocks));
            // The shortened node is still a leaf; put it back with its
            // original access time.
            self.leaves.insert((last_access, node_id));
            break;
        }
    }

    evicted
}
|
|
||||||
|
|
||||||
/// Insert a prefill along with its blocks.
|
|
||||||
///
|
|
||||||
/// This method returns the length of the prefix that was already
|
|
||||||
/// in the trie. E.g. if the length is 10, this means that for
|
|
||||||
/// the first 10 elements of the tree **the blocks are not updated**.
|
|
||||||
pub fn insert(&mut self, tokens: &[u32], blocks: &[u32]) -> Result<usize, TrieError> {
|
|
||||||
self.time += 1;
|
|
||||||
let common = self.insert_(self.root, tokens, blocks)?;
|
|
||||||
Ok(common)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Insertion worker.
///
/// Inserts `tokens` with their `blocks` below `node_id` and returns how
/// many leading tokens were already present in the trie. `tokens` must
/// hold exactly `block_size` tokens for every entry of `blocks`.
fn insert_(
    &mut self,
    node_id: NodeId,
    tokens: &[u32],
    blocks: &[u32],
) -> Result<usize, TrieError> {
    // TODO: in the future we may want to check that the blocks match for
    // the part of the prefix that is already in the trie to detect
    // mismatches.

    assert_eq!(tokens.len(), blocks.len() * self.block_size);

    // Children are indexed by the hash of their first block.
    let node_key = hash(&tokens[..self.block_size]);
    if let Some(&child_id) = self.nodes[node_id].children.get(&node_key) {
        self.update_access_time(child_id);
        let child = self
            .nodes
            .get_mut(child_id)
            // Unwrap here, since failure is a bug.
            .expect("Child node does not exist");
        let shared_prefix_len = shared_prefix(&child.key, tokens, self.block_size);

        // We are done, the prefix is already in the trie.
        // NOTE(review): a shared length of 0 means the child selected by
        // hash does not share a whole block with `tokens` (presumably only
        // possible on a hash collision); nothing is inserted then.
        if shared_prefix_len == tokens.len() || shared_prefix_len == 0 {
            return Ok(shared_prefix_len);
        }

        // The node's prefix is a prefix of the insertion prefix.
        if shared_prefix_len == child.key.len() {
            return Ok(shared_prefix_len
                + self.insert_(
                    child_id,
                    &tokens[shared_prefix_len..],
                    &blocks[shared_prefix_len / self.block_size..],
                )?);
        }

        // The node's prefix and the insertion prefix only match partially,
        // split the node to just contain the matching part. Then insert the
        // remainder of the prefix into the node again
        let child_id = self.split_node(child_id, shared_prefix_len);
        let key = &tokens[shared_prefix_len..];
        let blocks = &blocks[shared_prefix_len / self.block_size..];
        Ok(shared_prefix_len + self.insert_(child_id, key, blocks)?)
    } else {
        // No child starts with this block: store everything in a new node.
        self.add_node(node_id, tokens, blocks);
        Ok(0)
    }
}
|
|
||||||
|
|
||||||
/// Split `node_id` after `prefix_len` tokens.
///
/// A new node holding the first `prefix_len` tokens (and their blocks) is
/// created in the place of the original node, and the original node keeps
/// the remainder and becomes the new node's child. Returns the identifier
/// of the new node.
fn split_node(&mut self, node_id: NodeId, prefix_len: usize) -> NodeId {
    // We have to make the current node a child to ensure that its
    // properties and node id stay the same.

    // This function unwraps, an invalid node_id is a programming error.

    let node = self
        .nodes
        .get_mut(node_id)
        .expect("Node to-be split does not exist");
    // At this point `node` holds the prefix and `parent_*` the remainder.
    let mut parent_key = node.key.split_off(prefix_len);
    let prefix_blocks = prefix_len / self.block_size;
    let mut parent_blocks = node.blocks.split_off(prefix_blocks);

    // Move first part of the prefix to the parent. We swap to avoid
    // an allocation + copy for both splits of the key/blocks.
    std::mem::swap(&mut node.key, &mut parent_key);
    std::mem::swap(&mut node.blocks, &mut parent_blocks);

    // Hash of the first block of the remainder, under which the original
    // node is registered in the new parent.
    let node_key = hash(&node.key[..self.block_size]);

    let grandparent_id = node.parent.expect("Node does not have a parent");
    let parent_id = self.add_node(grandparent_id, parent_key, parent_blocks);
    self.add_node_to_parent(parent_id, node_key, node_id);

    // Reborrow to make the borrow checker happy.
    let node = self
        .nodes
        .get_mut(node_id)
        .expect("Node to-be split does not exist");
    node.parent = Some(parent_id);

    parent_id
}
|
|
||||||
|
|
||||||
/// Create a node and add it to the parent.
|
|
||||||
fn add_node(
|
|
||||||
&mut self,
|
|
||||||
parent_id: NodeId,
|
|
||||||
key: impl Into<Vec<u32>>,
|
|
||||||
blocks: impl Into<Vec<u32>>,
|
|
||||||
) -> NodeId {
|
|
||||||
let key = key.into();
|
|
||||||
let blocks = blocks.into();
|
|
||||||
let first = hash(&key[..self.block_size]);
|
|
||||||
|
|
||||||
let child = TrieNode::new(key, blocks, self.time, Some(parent_id));
|
|
||||||
let child_id = self.nodes.insert(child);
|
|
||||||
|
|
||||||
self.add_node_to_parent(parent_id, first, child_id);
|
|
||||||
self.leaves.insert((self.time, child_id));
|
|
||||||
|
|
||||||
child_id
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Add a node to the parent.
|
|
||||||
fn add_node_to_parent(&mut self, parent_id: NodeId, hash: u64, child_id: NodeId) {
|
|
||||||
// Unwrap here, passing in an unknown id is a programming error.
|
|
||||||
let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node");
|
|
||||||
if parent.children.insert(hash, child_id).is_none() {
|
|
||||||
// Only increase reference count if child does not replace another child.
|
|
||||||
self.incref(parent_id)
|
|
||||||
.expect("Failed to increase parent refcount");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Remove a node from the trie.
|
|
||||||
fn remove_node(&mut self, node_id: NodeId) -> TrieNode {
|
|
||||||
// Unwrap here, passing in an unknown id is a programming error.
|
|
||||||
let node = self.nodes.remove(node_id).expect("Unknown node");
|
|
||||||
assert!(
|
|
||||||
node.children.is_empty(),
|
|
||||||
"Tried to remove a node with {} children",
|
|
||||||
node.children.len()
|
|
||||||
);
|
|
||||||
let parent_id = node.parent.expect("Attempted to remove root node");
|
|
||||||
let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node");
|
|
||||||
|
|
||||||
let node_key = hash(&node.key[..self.block_size]);
|
|
||||||
parent.children.remove(&node_key);
|
|
||||||
self.decref(parent_id)
|
|
||||||
.expect("Failed to decrease parent refcount");
|
|
||||||
node
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Stamp `node_id` with the current time, keeping the recency-ordered
/// leaf set in sync when the node is a leaf.
fn update_access_time(&mut self, node_id: NodeId) {
    let now = self.time;

    // Unwrap here, passing in an unknown id is a programming error.
    let node = self.nodes.get_mut(node_id).expect("Unknown node");
    let previous = std::mem::replace(&mut node.last_accessed, now);

    // Leaves are keyed by access time, so a leaf has to be re-inserted
    // under its new time.
    if self.leaves.remove(&(previous, node_id)) {
        self.leaves.insert((now, node_id));
    }
}
|
|
||||||
|
|
||||||
#[allow(dead_code)]
#[doc(hidden)]
/// Print debugging output for the trie to standard error.
///
/// In contrast to `Debug`, the output is nicely formatted: one line per
/// node, indented by depth, starting at the root.
pub fn print_debug(&self) {
    self.print_debug_(self.root, 0);
}
|
|
||||||
|
|
||||||
/// Print `node_id` at the given indentation, followed by its children
/// indented two more columns.
fn print_debug_(&self, node_id: NodeId, indent: usize) {
    let node = &self.nodes[node_id];
    eprintln!(
        "{}{:?}, key: {:?}, blocks: {:?}, ref_count: {}, last_accessed: {}, parent: {:?}, children: {:?}",
        " ".repeat(indent),
        node_id,
        node.key,
        node.blocks,
        node.ref_count,
        node.last_accessed,
        node.parent,
        node.children
    );

    let child_indent = indent + 2;
    for &child_id in node.children.values() {
        self.print_debug_(child_id, child_indent);
    }
}
|
|
||||||
|
|
||||||
/// Identifier of the root node.
pub(crate) fn root_id(&self) -> DefaultKey {
    self.root
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Trie node.
#[derive(Debug)]
struct TrieNode {
    /// Cache blocks for `key`; there is one block per `block_size` tokens.
    blocks: Vec<u32>,
    /// Child nodes, indexed by the hash of the first block of their key.
    children: HashMap<u64, NodeId>,
    /// Tokens stored in this node.
    key: Vec<u32>,
    /// Value of the trie's clock when the node was last accessed.
    last_accessed: u64,
    /// Parent node; `None` for the root.
    parent: Option<NodeId>,
    /// Number of references to the node. Nodes with a count of zero are
    /// tracked as evictable leaves.
    ref_count: usize,
}
|
|
||||||
|
|
||||||
impl TrieNode {
|
|
||||||
fn new(key: Vec<u32>, blocks: Vec<u32>, last_accessed: u64, parent: Option<NodeId>) -> Self {
|
|
||||||
TrieNode {
|
|
||||||
children: HashMap::new(),
|
|
||||||
key,
|
|
||||||
blocks,
|
|
||||||
last_accessed,
|
|
||||||
parent,
|
|
||||||
ref_count: 0,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Length of the common prefix of `left` and `right`, rounded down to a
/// multiple of `block_size`.
///
/// Panics when the slices do not share at least their first token.
fn shared_prefix(left: &[u32], right: &[u32], block_size: usize) -> usize {
    let mut matching = 0;
    for (l, r) in left.iter().zip(right) {
        if l != r {
            break;
        }
        matching += 1;
    }

    // NOTE: this is the case because the child node was chosen based on
    // matching the first character of the key/prefix.
    assert!(matching > 0, "Prefixes must at least share 1 token");

    let whole_blocks = matching / block_size;
    whole_blocks * block_size
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use std::sync::Arc;
|
|
||||||
|
|
||||||
use super::*;
|
|
||||||
|
|
||||||
// With two slots per block, every allocated block expands to two
// consecutive slots, and a freed prefill is reused by an identical request.
// NOTE(review): the first constructor argument is presumably the block
// size and the second the number of blocks; confirm against
// `RadixAllocator::new`.
#[test]
fn allocator_block_size() {
    let mut cache = RadixAllocator::new(2, 12, None);
    let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
    assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
    assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22, 23]);
    assert_eq!(allocation.prefix_len, 0);
    cache.free(allocation.blocks.clone(), allocation.allocation_id);

    // The same request now finds its whole 4-token prefill in the cache.
    let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
    assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
    assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22, 23]);
    assert_eq!(allocation.prefix_len, 4);
}
|
|
||||||
|
|
||||||
// A request whose token count and prefill length are not multiples of the
// block size: the slots are cut off at the token count, and only the
// block-aligned part of the prefill is reused afterwards.
#[test]
fn allocator_block_size_non_aligned() {
    let mut cache = RadixAllocator::new(2, 12, None);
    let allocation = cache.allocate(7, Some(Arc::new(vec![0, 1, 2]))).unwrap();
    assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
    assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22]);
    assert_eq!(allocation.prefix_len, 0);
    cache.free(allocation.blocks.clone(), allocation.allocation_id);

    // Only 2 of the 3 prefill tokens are found in the cache.
    let allocation = cache.allocate(7, Some(Arc::new(vec![0, 1, 2]))).unwrap();
    assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
    assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22]);
    assert_eq!(allocation.prefix_len, 2);
}
|
|
||||||
|
|
||||||
// A freed prefill is found again by a later request with the same prefill.
#[test]
fn allocator_reuses_prefixes() {
    let mut cache = RadixAllocator::new(1, 12, None);
    let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
    assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]);
    // Blocks and slots coincide here.
    assert_eq!(allocation.blocks, allocation.slots);
    assert_eq!(allocation.prefix_len, 0);
    cache.free(allocation.blocks.clone(), allocation.allocation_id);

    let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
    assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]);
    assert_eq!(allocation.prefix_len, 4);
}
|
|
||||||
|
|
||||||
// When blocks have to be reclaimed from the cache, the prefix that was
// stored first is given up first.
#[test]
fn allocator_collects_older_prefixes_first() {
    let mut cache = RadixAllocator::new(1, 7, None);
    let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
    assert_eq!(allocation1.blocks, vec![3, 4, 5, 6]);
    assert_eq!(allocation1.prefix_len, 0);

    let allocation2 = cache.allocate(2, Some(Arc::new(vec![4, 5]))).unwrap();
    assert_eq!(allocation2.blocks, vec![1, 2]);
    assert_eq!(allocation2.prefix_len, 0);

    cache.free(allocation1.blocks.clone(), allocation1.allocation_id);
    cache.free(allocation2.blocks.clone(), allocation2.allocation_id);

    // We should get the blocks of the first allocation, since its prefix
    // was cached before that of the second allocation and is therefore
    // the older one.
    let allocation3 = cache.allocate(4, Some(Arc::new(vec![6, 7, 8, 9]))).unwrap();
    assert_eq!(allocation3.blocks, vec![3, 4, 5, 6]);
    assert_eq!(allocation3.prefix_len, 0);
}
|
|
||||||
|
|
||||||
// Two live allocations with the same prefill: after both are freed, the
// prefill is cached only once and the duplicate blocks are free again.
#[test]
fn allocator_frees_fully_overlapping_prefills() {
    let mut cache = RadixAllocator::new(1, 10, None);
    let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
    let allocation2 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();

    cache.free(allocation2.blocks.clone(), allocation2.allocation_id);
    cache.free(allocation1.blocks.clone(), allocation1.allocation_id);

    let allocation3 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
    assert_eq!(allocation3.prefix_len, 4);

    // 10 blocks, of which 1 reserved for health checks, 4 for the cached blocks.
    assert_eq!(cache.free_blocks.len(), 5);
}
|
|
||||||
|
|
||||||
// Requests whose prefills share only a leading part: the shared part is
// stored once, and later requests reuse the blocks of the longest cached
// prefix.
#[test]
fn allocator_frees_partially_overlapping_prefills() {
    let mut cache = RadixAllocator::new(1, 20, None);
    let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1]))).unwrap();
    assert_eq!(allocation1.blocks, vec![16, 17, 18, 19]);
    assert_eq!(allocation1.prefix_len, 0);

    cache.free(allocation1.blocks.clone(), allocation1.allocation_id);

    // Both of the following requests start with the cached prefix `0, 1`.
    let allocation2 = cache
        .allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 4, 5])))
        .unwrap();
    assert_eq!(allocation2.blocks, vec![16, 17, 12, 13, 14, 15, 18, 19]);
    assert_eq!(allocation2.prefix_len, 2);

    let allocation3 = cache
        .allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 6, 7])))
        .unwrap();
    assert_eq!(allocation3.blocks, vec![16, 17, 6, 7, 8, 9, 10, 11]);
    assert_eq!(allocation3.prefix_len, 2);

    cache.free(allocation3.blocks.clone(), allocation3.allocation_id);
    cache.free(allocation2.blocks.clone(), allocation2.allocation_id);

    // 20 blocks, of which 1 reserved for health checks, 6 for allocation3, 2 for allocation2.
    assert_eq!(cache.free_blocks.len(), 11);

    // Requests that are fully covered by the cache take no free blocks.
    let allocation4 = cache
        .allocate(6, Some(Arc::new(vec![0, 1, 2, 3, 4, 5])))
        .unwrap();
    assert_eq!(allocation4.blocks, vec![16, 17, 6, 7, 14, 15]);
    assert_eq!(allocation4.prefix_len, 6);
    assert_eq!(cache.free_blocks.len(), 11);

    let allocation5 = cache
        .allocate(6, Some(Arc::new(vec![0, 1, 2, 3, 6, 7])))
        .unwrap();
    assert_eq!(allocation5.blocks, vec![16, 17, 6, 7, 8, 9]);
    assert_eq!(allocation5.prefix_len, 6);
    assert_eq!(cache.free_blocks.len(), 11);
}
|
|
||||||
|
|
||||||
// `insert` reports how many leading tokens were already in the trie
// (block size 1).
#[test]
fn trie_insertions_have_correct_prefix_len() {
    let mut trie = RadixTrie::new(1);

    assert_eq!(trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(), 0);

    // Already exists.
    assert_eq!(trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(), 3);

    // Completely new at root-level
    assert_eq!(trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap(), 0);

    // Contains full prefix, but longer.
    assert_eq!(trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap(), 3);

    // Shares partial prefix, we need a split.
    assert_eq!(
        trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7])
            .unwrap(),
        4
    );
}
|
|
||||||
|
|
||||||
// Same as the block size 1 insertion test, but with two tokens per block:
// reported prefix lengths are multiples of the block size.
#[test]
fn trie_insertions_block_size() {
    let mut trie = RadixTrie::new(2);

    assert_eq!(trie.insert(&[0, 1, 2, 3], &[0, 1]).unwrap(), 0);

    // Already exists.
    // But needs to be block_size aligned
    assert_eq!(trie.insert(&[0, 1, 2, 3], &[0, 1]).unwrap(), 4);

    // Completely new at root-level
    assert_eq!(trie.insert(&[1, 2, 3, 4], &[1, 2]).unwrap(), 0);

    // Contains full prefix, but longer.
    assert_eq!(trie.insert(&[0, 1, 2, 3, 4, 5], &[0, 1, 2]).unwrap(), 4);

    // Shares partial prefix, we need a split.
    assert_eq!(
        trie.insert(&[0, 1, 3, 4, 5, 6, 7, 8], &[0, 1, 2, 3])
            .unwrap(),
        2
    );
}
|
|
||||||
|
|
||||||
// `find` returns the blocks of the longest stored prefix, also when the
// prefix ends inside a node or spans several nodes.
#[test]
fn trie_get_returns_correct_blocks() {
    let mut trie = RadixTrie::new(1);
    trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap();
    trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap();
    trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap();
    trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7])
        .unwrap();

    let mut blocks = Vec::new();
    trie.find(&[0], &mut blocks);
    assert_eq!(blocks, vec![0]);

    blocks.clear();
    trie.find(&[0, 1, 2], &mut blocks);
    assert_eq!(blocks, vec![0, 1, 2]);

    blocks.clear();
    trie.find(&[1, 2, 3], &mut blocks);
    assert_eq!(blocks, vec![1, 2, 3]);

    blocks.clear();
    trie.find(&[0, 1, 2, 3], &mut blocks);
    assert_eq!(blocks, vec![0, 1, 2, 3]);

    blocks.clear();
    trie.find(&[0, 1, 2, 3, 4], &mut blocks);
    assert_eq!(blocks, vec![0, 1, 2, 3, 4]);

    blocks.clear();
    trie.find(&[0, 1, 2, 3, 5], &mut blocks);
    assert_eq!(blocks, vec![0, 1, 2, 3, 5]);
}
|
|
||||||
|
|
||||||
// `evict` takes blocks from the least recently used leaves, shortening a
// leaf when only part of it is needed and removing it otherwise.
#[test]
fn trie_evict_removes_correct_blocks() {
    let mut trie = RadixTrie::new(1);
    trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap();
    trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7])
        .unwrap();
    trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap();
    trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap();

    let mut blocks = Vec::new();

    // Remove fewer blocks than the leaf has.
    assert_eq!(trie.evict(1), vec![7]);
    trie.find(&[0, 1, 2, 3, 5, 6, 7], &mut blocks);
    assert_eq!(blocks, vec![0, 1, 2, 3, 5, 6]);

    // Refresh other leaf.
    trie.find(&[0, 1, 2, 3, 4], &mut blocks);
    trie.find(&[1, 2, 3], &mut blocks);

    // Remove exactly the blocks of the leaf.
    assert_eq!(trie.evict(2), vec![5, 6]);
    blocks.clear();
    trie.find(&[0, 1, 2, 3, 5, 6, 7], &mut blocks);
    assert_eq!(blocks, vec![0, 1, 2, 3]);

    trie.find(&[1, 2, 3], &mut blocks);

    // Remove more blocks than the leaf has.
    assert_eq!(trie.evict(3), vec![4, 3, 2]);
    blocks.clear();
    trie.find(&[0, 1, 2, 3, 4], &mut blocks);
    assert_eq!(blocks, vec![0, 1]);

    // Clear out the whole trie.
    assert_eq!(trie.evict(10), vec![1, 2, 3, 0, 1]);
}
|
|
||||||
}
|
|
@ -8,9 +8,11 @@ use crate::{
|
|||||||
ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, HubTokenizerConfig,
|
ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, HubTokenizerConfig,
|
||||||
Message, PrefillToken, Token,
|
Message, PrefillToken, Token,
|
||||||
};
|
};
|
||||||
|
use async_stream::stream;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use chat_template::ChatTemplate;
|
use chat_template::ChatTemplate;
|
||||||
use futures::future::try_join_all;
|
use futures::future::try_join_all;
|
||||||
|
use futures::Stream;
|
||||||
use minijinja::ErrorKind;
|
use minijinja::ErrorKind;
|
||||||
use std::sync::atomic::{AtomicBool, Ordering};
|
use std::sync::atomic::{AtomicBool, Ordering};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
@ -87,7 +89,14 @@ impl Infer {
|
|||||||
pub(crate) async fn generate_stream<'a>(
|
pub(crate) async fn generate_stream<'a>(
|
||||||
&'a self,
|
&'a self,
|
||||||
request: GenerateRequest,
|
request: GenerateRequest,
|
||||||
) -> Result<GenerateStreamResponse, InferError> {
|
) -> Result<
|
||||||
|
(
|
||||||
|
OwnedSemaphorePermit,
|
||||||
|
u32, // input_length
|
||||||
|
impl Stream<Item = Result<InferStreamResponse, InferError>> + 'a,
|
||||||
|
),
|
||||||
|
InferError,
|
||||||
|
> {
|
||||||
// Limit concurrent requests by acquiring a permit from the semaphore
|
// Limit concurrent requests by acquiring a permit from the semaphore
|
||||||
let permit = self
|
let permit = self
|
||||||
.clone()
|
.clone()
|
||||||
@ -107,9 +116,18 @@ impl Infer {
|
|||||||
})?;
|
})?;
|
||||||
|
|
||||||
let input_length = valid_request.input_length;
|
let input_length = valid_request.input_length;
|
||||||
let generation_stream = self.backend.schedule(valid_request)?;
|
let mut generation_stream = self.backend.schedule(valid_request)?;
|
||||||
|
|
||||||
Ok((permit, input_length, generation_stream))
|
// Wrap generation stream to update the backend health if the stream contains an error
|
||||||
|
let final_stream = stream! {
|
||||||
|
while let Some(response) = generation_stream.next().await {
|
||||||
|
yield response.inspect_err(|_err| {
|
||||||
|
self.backend_health.store(false, Ordering::SeqCst);
|
||||||
|
})
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok((permit, input_length, final_stream))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Tokenizer the input
|
/// Tokenizer the input
|
||||||
@ -278,13 +296,6 @@ impl Infer {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Type alias for generation responses
|
|
||||||
pub(crate) type GenerateStreamResponse = (
|
|
||||||
OwnedSemaphorePermit,
|
|
||||||
u32, // input_length
|
|
||||||
UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
|
|
||||||
);
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct GeneratedText {
|
pub struct GeneratedText {
|
||||||
pub text: String,
|
pub text: String,
|
||||||
|
@ -1,4 +0,0 @@
|
|||||||
mod queue;
|
|
||||||
mod scheduler;
|
|
||||||
|
|
||||||
pub(crate) use scheduler::BackendV2;
|
|
@ -1,675 +0,0 @@
|
|||||||
use crate::infer::{InferError, InferStreamResponse};
|
|
||||||
use crate::validation::{
|
|
||||||
ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters,
|
|
||||||
};
|
|
||||||
use nohash_hasher::{BuildNoHashHasher, IntMap};
|
|
||||||
use std::cmp::min;
|
|
||||||
use std::collections::VecDeque;
|
|
||||||
use text_generation_client::v2::{
|
|
||||||
Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
|
||||||
};
|
|
||||||
use text_generation_client::ChunksToString;
|
|
||||||
use tokio::sync::{mpsc, oneshot};
|
|
||||||
use tokio::time::Instant;
|
|
||||||
use tracing::{info_span, instrument, Span};
|
|
||||||
|
|
||||||
/// Queue entry
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub(crate) struct Entry {
|
|
||||||
/// Request
|
|
||||||
pub request: ValidGenerateRequest,
|
|
||||||
/// Response sender to communicate between the Infer struct and the batching_task
|
|
||||||
pub response_tx: mpsc::UnboundedSender<Result<InferStreamResponse, InferError>>,
|
|
||||||
/// Span that will live as long as entry
|
|
||||||
pub span: Span,
|
|
||||||
/// Temporary span used as a guard when logging inference, wait times...
|
|
||||||
pub temp_span: Option<Span>,
|
|
||||||
/// Instant when this entry was queued
|
|
||||||
pub queue_time: Instant,
|
|
||||||
/// Instant when this entry was added to a batch
|
|
||||||
pub batch_time: Option<Instant>,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Request Queue
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub(crate) struct Queue {
|
|
||||||
/// Channel to communicate with the background queue task
|
|
||||||
queue_sender: mpsc::UnboundedSender<QueueCommand>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Queue {
|
|
||||||
pub(crate) fn new(
|
|
||||||
requires_padding: bool,
|
|
||||||
block_size: u32,
|
|
||||||
window_size: Option<u32>,
|
|
||||||
speculate: u32,
|
|
||||||
) -> Self {
|
|
||||||
// Create channel
|
|
||||||
let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
|
|
||||||
|
|
||||||
// Launch background queue task
|
|
||||||
tokio::spawn(queue_task(
|
|
||||||
requires_padding,
|
|
||||||
block_size,
|
|
||||||
window_size,
|
|
||||||
speculate,
|
|
||||||
queue_receiver,
|
|
||||||
));
|
|
||||||
|
|
||||||
Self { queue_sender }
|
|
||||||
}
|
|
||||||
|
|
||||||
#[instrument(skip_all)]
|
|
||||||
pub(crate) fn append(&self, entry: Entry) {
|
|
||||||
// Send append command to the background task managing the state
|
|
||||||
// Unwrap is safe here
|
|
||||||
self.queue_sender
|
|
||||||
.send(QueueCommand::Append(Box::new(entry), Span::current()))
|
|
||||||
.unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the next batch
|
|
||||||
#[instrument(skip(self))]
|
|
||||||
pub(crate) async fn next_batch(
|
|
||||||
&self,
|
|
||||||
min_size: Option<usize>,
|
|
||||||
max_size: Option<usize>,
|
|
||||||
prefill_token_budget: u32,
|
|
||||||
token_budget: u32,
|
|
||||||
) -> Option<NextBatch> {
|
|
||||||
// Create response channel
|
|
||||||
let (response_sender, response_receiver) = oneshot::channel();
|
|
||||||
// Send next batch command to the background task managing the state
|
|
||||||
// Unwrap is safe here
|
|
||||||
self.queue_sender
|
|
||||||
.send(QueueCommand::NextBatch {
|
|
||||||
min_size,
|
|
||||||
max_size,
|
|
||||||
prefill_token_budget,
|
|
||||||
token_budget,
|
|
||||||
response_sender,
|
|
||||||
span: Span::current(),
|
|
||||||
})
|
|
||||||
.unwrap();
|
|
||||||
// Await on response channel
|
|
||||||
// Unwrap is safe here
|
|
||||||
response_receiver.await.unwrap()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Background task responsible of the queue state
|
|
||||||
async fn queue_task(
|
|
||||||
requires_padding: bool,
|
|
||||||
block_size: u32,
|
|
||||||
window_size: Option<u32>,
|
|
||||||
speculate: u32,
|
|
||||||
mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
|
|
||||||
) {
|
|
||||||
let mut state = State::new(requires_padding, block_size, window_size, speculate);
|
|
||||||
|
|
||||||
while let Some(cmd) = receiver.recv().await {
|
|
||||||
match cmd {
|
|
||||||
QueueCommand::Append(entry, span) => {
|
|
||||||
span.in_scope(|| state.append(*entry));
|
|
||||||
metrics::gauge!("tgi_queue_size").increment(1.0);
|
|
||||||
}
|
|
||||||
QueueCommand::NextBatch {
|
|
||||||
min_size,
|
|
||||||
max_size,
|
|
||||||
prefill_token_budget,
|
|
||||||
token_budget,
|
|
||||||
response_sender,
|
|
||||||
span,
|
|
||||||
} => span.in_scope(|| {
|
|
||||||
let next_batch =
|
|
||||||
state.next_batch(min_size, max_size, prefill_token_budget, token_budget);
|
|
||||||
response_sender.send(next_batch).unwrap();
|
|
||||||
metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64);
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Queue State
|
|
||||||
#[derive(Debug)]
|
|
||||||
struct State {
|
|
||||||
/// Queue entries organized in a Vec
|
|
||||||
entries: VecDeque<(u64, Entry)>,
|
|
||||||
|
|
||||||
/// Id of the next entry
|
|
||||||
next_id: u64,
|
|
||||||
|
|
||||||
/// Id of the next batch
|
|
||||||
next_batch_id: u64,
|
|
||||||
|
|
||||||
/// Whether the model is using padding
|
|
||||||
requires_padding: bool,
|
|
||||||
|
|
||||||
/// Paged Attention block size
|
|
||||||
block_size: u32,
|
|
||||||
|
|
||||||
/// Sliding window
|
|
||||||
window_size: Option<u32>,
|
|
||||||
|
|
||||||
/// Speculation amount
|
|
||||||
speculate: u32,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl State {
|
|
||||||
fn new(
|
|
||||||
requires_padding: bool,
|
|
||||||
block_size: u32,
|
|
||||||
window_size: Option<u32>,
|
|
||||||
speculate: u32,
|
|
||||||
) -> Self {
|
|
||||||
Self {
|
|
||||||
entries: VecDeque::with_capacity(128),
|
|
||||||
next_id: 0,
|
|
||||||
next_batch_id: 0,
|
|
||||||
requires_padding,
|
|
||||||
block_size,
|
|
||||||
window_size,
|
|
||||||
speculate,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Append an entry to the queue
|
|
||||||
fn append(&mut self, mut entry: Entry) {
|
|
||||||
// Create a span that will live as long as the entry is in the queue waiting to be batched
|
|
||||||
let queue_span = info_span!(parent: &entry.span, "queued");
|
|
||||||
entry.temp_span = Some(queue_span);
|
|
||||||
|
|
||||||
// Push entry in the queue
|
|
||||||
self.entries.push_back((self.next_id, entry));
|
|
||||||
self.next_id += 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the next batch
|
|
||||||
fn next_batch(
|
|
||||||
&mut self,
|
|
||||||
min_size: Option<usize>,
|
|
||||||
max_size: Option<usize>,
|
|
||||||
prefill_token_budget: u32,
|
|
||||||
token_budget: u32,
|
|
||||||
) -> Option<NextBatch> {
|
|
||||||
if self.entries.is_empty() {
|
|
||||||
tracing::debug!("No queue");
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if we have enough entries
|
|
||||||
if let Some(min_size) = min_size {
|
|
||||||
if self.entries.len() < min_size {
|
|
||||||
tracing::debug!("Not enough entries");
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(max_size) = max_size {
|
|
||||||
if max_size == 0 {
|
|
||||||
tracing::debug!("No capacity");
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Pad prefill_token_budget to be a multiple of block size
|
|
||||||
let prefill_token_budget =
|
|
||||||
((prefill_token_budget + self.block_size - 1) / self.block_size) * self.block_size;
|
|
||||||
|
|
||||||
// Create span for this batch to add context to inference calls
|
|
||||||
let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
|
|
||||||
next_batch_span.follows_from(&Span::current());
|
|
||||||
|
|
||||||
let mut batch_requests = Vec::with_capacity(self.entries.len());
|
|
||||||
let mut batch_entries =
|
|
||||||
IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());
|
|
||||||
|
|
||||||
let mut max_input_length = 0;
|
|
||||||
let mut prefill_tokens: u32 = 0;
|
|
||||||
let mut decode_tokens: u32 = 0;
|
|
||||||
|
|
||||||
// Pop entries starting from the front of the queue
|
|
||||||
while let Some((id, mut entry)) = self.entries.pop_front() {
|
|
||||||
// Filter entries where the response receiver was dropped (== entries where the request
|
|
||||||
// was dropped by the client)
|
|
||||||
if entry.response_tx.is_closed() {
|
|
||||||
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
|
|
||||||
tracing::debug!("Dropping entry");
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if self.requires_padding {
|
|
||||||
// We pad to max input length in the Python shards
|
|
||||||
// We need to take these padding tokens into the equation
|
|
||||||
max_input_length = max_input_length.max(entry.request.input_length);
|
|
||||||
prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length
|
|
||||||
} else {
|
|
||||||
// pad to block size
|
|
||||||
prefill_tokens += ((entry.request.input_length + self.block_size - 1)
|
|
||||||
/ self.block_size)
|
|
||||||
* self.block_size;
|
|
||||||
}
|
|
||||||
|
|
||||||
if self.requires_padding {
|
|
||||||
decode_tokens += entry.request.stopping_parameters.max_new_tokens;
|
|
||||||
} else {
|
|
||||||
let max_new_tokens = match self.window_size {
|
|
||||||
None => entry.request.stopping_parameters.max_new_tokens,
|
|
||||||
Some(window_size) => min(
|
|
||||||
window_size.saturating_sub(entry.request.input_length),
|
|
||||||
entry.request.stopping_parameters.max_new_tokens,
|
|
||||||
),
|
|
||||||
};
|
|
||||||
|
|
||||||
// pad to block size
|
|
||||||
decode_tokens +=
|
|
||||||
((max_new_tokens + self.block_size - 1) / self.block_size) * self.block_size;
|
|
||||||
}
|
|
||||||
|
|
||||||
if prefill_tokens > prefill_token_budget
|
|
||||||
|| (prefill_tokens + decode_tokens + self.speculate) > token_budget
|
|
||||||
{
|
|
||||||
// Entry is over budget
|
|
||||||
// Add it back to the front
|
|
||||||
tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
|
|
||||||
self.entries.push_front((id, entry));
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
tracing::debug!("Accepting entry");
|
|
||||||
// Create a new span to link the batch back to this entry
|
|
||||||
let entry_batch_span = info_span!(parent: &entry.span, "infer");
|
|
||||||
// Add relationships
|
|
||||||
next_batch_span.follows_from(&entry_batch_span);
|
|
||||||
entry_batch_span.follows_from(&next_batch_span);
|
|
||||||
// Update entry
|
|
||||||
entry.temp_span = Some(entry_batch_span);
|
|
||||||
|
|
||||||
batch_requests.push(Request {
|
|
||||||
id,
|
|
||||||
prefill_logprobs: entry.request.decoder_input_details,
|
|
||||||
inputs: entry.request.inputs.chunks_to_string(),
|
|
||||||
truncate: entry.request.truncate,
|
|
||||||
parameters: Some(NextTokenChooserParameters::from(
|
|
||||||
entry.request.parameters.clone(),
|
|
||||||
)),
|
|
||||||
stopping_parameters: Some(StoppingCriteriaParameters::from(
|
|
||||||
entry.request.stopping_parameters.clone(),
|
|
||||||
)),
|
|
||||||
top_n_tokens: entry.request.top_n_tokens,
|
|
||||||
});
|
|
||||||
// Set batch_time
|
|
||||||
entry.batch_time = Some(Instant::now());
|
|
||||||
// Insert in batch_entries IntMap
|
|
||||||
batch_entries.insert(id, entry);
|
|
||||||
|
|
||||||
// Check if max_size
|
|
||||||
if Some(batch_requests.len()) == max_size {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Empty batch
|
|
||||||
if batch_requests.is_empty() {
|
|
||||||
tracing::debug!("Filtered out all entries");
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if our batch is big enough
|
|
||||||
if let Some(min_size) = min_size {
|
|
||||||
// Batch is too small
|
|
||||||
if batch_requests.len() < min_size {
|
|
||||||
// Add back entries to the queue in the correct order
|
|
||||||
for r in batch_requests.into_iter().rev() {
|
|
||||||
let id = r.id;
|
|
||||||
let entry = batch_entries.remove(&id).unwrap();
|
|
||||||
self.entries.push_front((id, entry));
|
|
||||||
}
|
|
||||||
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Final batch size
|
|
||||||
let size = batch_requests.len() as u32;
|
|
||||||
next_batch_span.record("batch_size", size);
|
|
||||||
|
|
||||||
let batch = Batch {
|
|
||||||
id: self.next_batch_id,
|
|
||||||
requests: batch_requests,
|
|
||||||
size,
|
|
||||||
max_tokens: (prefill_tokens + decode_tokens),
|
|
||||||
};
|
|
||||||
// Increment batch id
|
|
||||||
self.next_batch_id += 1;
|
|
||||||
|
|
||||||
metrics::histogram!("tgi_batch_next_size").record(batch.size as f64);
|
|
||||||
|
|
||||||
Some((batch_entries, batch, next_batch_span))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Result of a successful batch formation: the batched entries keyed by entry id, the
/// `Batch` built from them, and the tracing span created for the batch.
type NextBatch = (IntMap<u64, Entry>, Batch, Span);
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
enum QueueCommand {
|
|
||||||
Append(Box<Entry>, Span),
|
|
||||||
NextBatch {
|
|
||||||
min_size: Option<usize>,
|
|
||||||
max_size: Option<usize>,
|
|
||||||
prefill_token_budget: u32,
|
|
||||||
token_budget: u32,
|
|
||||||
response_sender: oneshot::Sender<Option<NextBatch>>,
|
|
||||||
span: Span,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<ValidParameters> for NextTokenChooserParameters {
|
|
||||||
fn from(value: ValidParameters) -> Self {
|
|
||||||
let (grammar, grammar_type) = match value.grammar {
|
|
||||||
None => (String::new(), GrammarType::None),
|
|
||||||
|
|
||||||
Some(grammar) => match grammar {
|
|
||||||
ValidGrammar::Json(grammar_string) => (grammar_string, GrammarType::Json),
|
|
||||||
ValidGrammar::Regex(grammar_string) => (grammar_string, GrammarType::Regex),
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
Self {
|
|
||||||
temperature: value.temperature,
|
|
||||||
top_k: value.top_k,
|
|
||||||
top_p: value.top_p,
|
|
||||||
typical_p: value.typical_p,
|
|
||||||
do_sample: value.do_sample,
|
|
||||||
seed: value.seed,
|
|
||||||
repetition_penalty: value.repetition_penalty,
|
|
||||||
frequency_penalty: value.frequency_penalty,
|
|
||||||
watermark: value.watermark,
|
|
||||||
grammar,
|
|
||||||
grammar_type: grammar_type.into(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<ValidStoppingParameters> for StoppingCriteriaParameters {
|
|
||||||
fn from(value: ValidStoppingParameters) -> Self {
|
|
||||||
Self {
|
|
||||||
max_new_tokens: value.max_new_tokens,
|
|
||||||
stop_sequences: value.stop_sequences,
|
|
||||||
ignore_eos_token: value.ignore_eos_token,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
mod tests {
    use super::*;
    use tracing::info_span;

    /// Builds a minimal entry (no input, `max_new_tokens` of 1) and returns it together
    /// with the receiving half of its response channel.
    ///
    /// Callers must keep the receiver alive for as long as the entry should count as live:
    /// the queue discards entries whose response channel is closed.
    fn default_entry() -> (
        Entry,
        mpsc::UnboundedReceiver<Result<InferStreamResponse, InferError>>,
    ) {
        let (response_tx, receiver_tx) = mpsc::unbounded_channel();

        let parameters = ValidParameters {
            temperature: 0.0,
            top_k: 0,
            top_p: 0.0,
            typical_p: 0.0,
            do_sample: false,
            seed: 0,
            repetition_penalty: 0.0,
            frequency_penalty: 0.0,
            watermark: false,
            grammar: None,
        };
        let stopping_parameters = ValidStoppingParameters {
            ignore_eos_token: false,
            max_new_tokens: 1,
            stop_sequences: vec![],
        };
        let request = ValidGenerateRequest {
            inputs: vec![],
            input_length: 0,
            truncate: 0,
            decoder_input_details: false,
            parameters,
            stopping_parameters,
            top_n_tokens: 0,
            adapter_id: None,
        };

        let entry = Entry {
            request,
            response_tx,
            span: info_span!("entry"),
            temp_span: None,
            queue_time: Instant::now(),
            batch_time: None,
        };
        (entry, receiver_tx)
    }

    /// Appending stores the entry and bumps the next entry id.
    #[test]
    fn test_append() {
        let mut state = State::new(false, 1, None, 0);
        let (entry, _guard) = default_entry();

        assert_eq!(state.next_id, 0);
        assert_eq!(state.entries.len(), 0);

        state.append(entry);

        assert_eq!(state.next_id, 1);
        assert_eq!(state.entries.len(), 1);
        let (first_id, _) = state.entries.pop_front().unwrap();
        assert_eq!(first_id, 0);
    }

    /// An empty state never yields a batch.
    #[test]
    fn test_next_batch_empty() {
        let mut state = State::new(false, 1, None, 0);

        assert!(state.next_batch(None, None, 1, 1).is_none());
        assert!(state.next_batch(Some(1), None, 1, 1).is_none());
    }

    /// A minimum size larger than the queue leaves the queue untouched.
    #[test]
    fn test_next_batch_min_size() {
        let mut state = State::new(false, 1, None, 0);
        let (first, _guard1) = default_entry();
        let (second, _guard2) = default_entry();
        state.append(first);
        state.append(second);

        let (batched, batch, _) = state.next_batch(None, None, 2, 2).unwrap();
        assert_eq!(batched.len(), 2);
        assert!(batched.contains_key(&0));
        assert!(batched.contains_key(&1));
        assert!(batched[&0].batch_time.is_some());
        assert!(batched[&1].batch_time.is_some());
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 2);

        assert_eq!(state.next_id, 2);
        assert_eq!(state.entries.len(), 0);
        assert_eq!(state.next_batch_id, 1);

        let (third, _guard3) = default_entry();
        state.append(third);

        assert!(state.next_batch(Some(2), None, 2, 2).is_none());

        assert_eq!(state.next_id, 3);
        assert_eq!(state.entries.len(), 1);
        let (remaining_id, _) = state.entries.pop_front().unwrap();
        assert_eq!(remaining_id, 2);
    }

    /// A maximum size caps the number of batched entries.
    #[test]
    fn test_next_batch_max_size() {
        let mut state = State::new(false, 1, None, 0);
        let (first, _guard1) = default_entry();
        let (second, _guard2) = default_entry();
        state.append(first);
        state.append(second);

        let (batched, batch, _) = state.next_batch(None, Some(1), 2, 2).unwrap();
        assert_eq!(batched.len(), 1);
        assert!(batched.contains_key(&0));
        assert!(batched[&0].batch_time.is_some());
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 1);

        assert_eq!(state.next_id, 2);
        assert_eq!(state.entries.len(), 1);
        assert_eq!(state.next_batch_id, 1);
    }

    /// The token budget limits how many entries fit in a batch.
    #[test]
    fn test_next_batch_token_budget() {
        let mut state = State::new(false, 1, None, 0);
        let (first, _guard1) = default_entry();
        let (second, _guard2) = default_entry();
        state.append(first);
        state.append(second);

        let (batched, batch, _) = state.next_batch(None, None, 1, 1).unwrap();
        assert_eq!(batched.len(), 1);
        assert!(batched.contains_key(&0));
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 1);

        assert_eq!(state.next_id, 2);
        assert_eq!(state.entries.len(), 1);
        assert_eq!(state.next_batch_id, 1);

        let (third, _guard3) = default_entry();
        state.append(third);

        let (batched, batch, _) = state.next_batch(None, None, 3, 3).unwrap();
        assert_eq!(batched.len(), 2);
        assert!(batched.contains_key(&1));
        assert!(batched.contains_key(&2));
        assert_eq!(batch.id, 1);
        assert_eq!(batch.size, 2);

        assert_eq!(state.next_id, 3);
        assert_eq!(state.entries.len(), 0);
        assert_eq!(state.next_batch_id, 2);
    }

    /// Appending through the handle does not panic.
    #[tokio::test]
    async fn test_queue_append() {
        let queue = Queue::new(false, 1, None, 0);
        let (entry, _guard) = default_entry();
        queue.append(entry);
    }

    /// An empty queue never yields a batch.
    #[tokio::test]
    async fn test_queue_next_batch_empty() {
        let queue = Queue::new(false, 1, None, 0);

        assert!(queue.next_batch(None, None, 1, 1).await.is_none());
        assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
    }

    /// Minimum size and budget are both enforced through the handle.
    #[tokio::test]
    async fn test_queue_next_batch_min_size() {
        let queue = Queue::new(false, 1, None, 0);
        let (first, _guard1) = default_entry();
        let (second, _guard2) = default_entry();
        queue.append(first);
        queue.append(second);

        let (batched, batch, _) = queue.next_batch(None, None, 2, 2).await.unwrap();
        assert_eq!(batched.len(), 2);
        assert!(batched.contains_key(&0));
        assert!(batched.contains_key(&1));
        assert!(batched[&0].batch_time.is_some());
        assert!(batched[&1].batch_time.is_some());
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 2);

        let (third, _guard3) = default_entry();
        queue.append(third);

        // Not enough requests pending
        assert!(queue.next_batch(Some(2), None, 2, 2).await.is_none());
        // Not enough token budget
        assert!(queue.next_batch(Some(1), None, 0, 0).await.is_none());
        // Ok
        let (batched2, batch2, _) = queue.next_batch(Some(1), None, 2, 2).await.unwrap();
        assert_eq!(batched2.len(), 1);
        assert!(batched2.contains_key(&2));
        assert!(batched2[&2].batch_time.is_some());
        assert_eq!(batch2.id, 1);
        assert_eq!(batch2.size, 1);
    }

    /// Maximum size is enforced through the handle.
    #[tokio::test]
    async fn test_queue_next_batch_max_size() {
        let queue = Queue::new(false, 1, None, 0);
        let (first, _guard1) = default_entry();
        let (second, _guard2) = default_entry();
        queue.append(first);
        queue.append(second);

        let (batched, batch, _) = queue.next_batch(None, Some(1), 2, 2).await.unwrap();
        assert_eq!(batched.len(), 1);
        assert!(batched.contains_key(&0));
        assert!(batched[&0].batch_time.is_some());
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 1);
    }

    /// The token budget is enforced through the handle.
    #[tokio::test]
    async fn test_queue_next_batch_token_budget() {
        let queue = Queue::new(false, 1, None, 0);
        let (first, _guard1) = default_entry();
        let (second, _guard2) = default_entry();
        queue.append(first);
        queue.append(second);

        let (batched, batch, _) = queue.next_batch(None, None, 1, 1).await.unwrap();
        assert_eq!(batched.len(), 1);
        assert!(batched.contains_key(&0));
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 1);

        let (third, _guard3) = default_entry();
        queue.append(third);

        let (batched, batch, _) = queue.next_batch(None, None, 3, 3).await.unwrap();
        assert_eq!(batched.len(), 2);
        assert!(batched.contains_key(&1));
        assert!(batched.contains_key(&2));
        assert_eq!(batch.id, 1);
        assert_eq!(batch.size, 2);
    }

    /// The speculation amount counts against the token budget.
    #[tokio::test]
    async fn test_queue_next_batch_token_speculate() {
        let queue = Queue::new(false, 1, None, 2);
        let (first, _guard1) = default_entry();
        let (second, _guard2) = default_entry();
        queue.append(first);
        queue.append(second);

        // Budget of 1 is not enough
        assert!(queue.next_batch(None, None, 1, 1).await.is_none());

        let (batched, batch, _) = queue.next_batch(None, None, 6, 6).await.unwrap();
        assert_eq!(batched.len(), 2);
        assert!(batched.contains_key(&0));
        assert!(batched.contains_key(&1));
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 2);
    }

    /// An entry whose receiver was dropped is filtered out of the batch.
    #[tokio::test]
    async fn test_queue_next_batch_dropped_receiver() {
        let queue = Queue::new(false, 1, None, 0);
        // The receiver is dropped immediately, closing the response channel
        let (entry, _) = default_entry();
        queue.append(entry);

        assert!(queue.next_batch(None, None, 1, 1).await.is_none());
    }
}
|
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user