mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00

wip

This commit is contained in:
parent 9263817c71
commit 6e105c8eb8

Cargo.lock (generated, 61 lines changed)
@@ -4175,7 +4175,7 @@ dependencies = [
 
 [[package]]
 name = "text-generation-backends-trtllm"
-version = "2.2.1-dev0"
+version = "2.3.1-dev0"
 dependencies = [
  "async-stream",
  "async-trait",
@@ -4198,7 +4198,7 @@ dependencies = [
 
 [[package]]
 name = "text-generation-benchmark"
-version = "2.2.1-dev0"
+version = "2.3.1-dev0"
 dependencies = [
  "average",
  "clap 4.5.17",
@@ -4219,7 +4219,7 @@ dependencies = [
 
 [[package]]
 name = "text-generation-client"
-version = "2.2.1-dev0"
+version = "2.3.1-dev0"
 dependencies = [
  "async-trait",
  "base64 0.22.1",
@@ -4237,7 +4237,7 @@ dependencies = [
 
 [[package]]
 name = "text-generation-launcher"
-version = "2.2.1-dev0"
+version = "2.3.1-dev0"
 dependencies = [
  "clap 4.5.17",
  "ctrlc",
@@ -4256,7 +4256,7 @@ dependencies = [
 
 [[package]]
 name = "text-generation-router"
-version = "2.2.1-dev0"
+version = "2.3.1-dev0"
 dependencies = [
  "async-stream",
  "async-trait",
@@ -4303,9 +4303,58 @@ dependencies = [
  "vergen",
 ]
 
+[[package]]
+name = "text-generation-router-v2"
+version = "2.3.1-dev0"
+dependencies = [
+ "async-stream",
+ "async-trait",
+ "axum 0.7.5",
+ "axum-tracing-opentelemetry",
+ "base64 0.22.1",
+ "clap 4.5.17",
+ "futures",
+ "futures-util",
+ "grpc-metadata",
+ "hf-hub",
+ "image",
+ "init-tracing-opentelemetry",
+ "jsonschema",
+ "metrics",
+ "metrics-exporter-prometheus",
+ "minijinja",
+ "minijinja-contrib",
+ "nohash-hasher",
+ "once_cell",
+ "opentelemetry 0.20.0",
+ "opentelemetry-otlp",
+ "prost 0.12.6",
+ "prost-build",
+ "rand",
+ "regex",
+ "reqwest",
+ "serde",
+ "serde_json",
+ "slotmap",
+ "text-generation-router",
+ "thiserror",
+ "tokenizers 0.20.0",
+ "tokio",
+ "tokio-stream",
+ "tonic 0.10.2",
+ "tonic-build",
+ "tower",
+ "tower-http",
+ "tracing",
+ "tracing-opentelemetry 0.21.0",
+ "tracing-subscriber",
+ "utoipa",
+ "utoipa-swagger-ui",
+]
+
 [[package]]
 name = "text-generation-router-v3"
-version = "2.2.1-dev0"
+version = "2.3.1-dev0"
 dependencies = [
  "async-stream",
  "async-trait",

Cargo.toml
@@ -1,19 +1,19 @@
 [workspace]
 members = [
     "benchmark",
+    "backends/v2",
     "backends/v3",
     "backends/grpc-metadata",
     "backends/trtllm",
-    "backends/client",
     "launcher",
     "router"
 ]
 default-members = [
     "benchmark",
+    "backends/v2",
     "backends/v3",
     "backends/grpc-metadata",
     # "backends/trtllm",
-    "backends/client",
     "launcher",
     "router"
 ]

backends/v2/Cargo.toml (new file, 75 lines)
@@ -0,0 +1,75 @@
[package]
name = "text-generation-router-v2"
description = "Text Generation Webserver"
version.workspace = true
edition.workspace = true
authors.workspace = true
homepage.workspace = true

[lib]
path = "src/lib.rs"

[[bin]]
name = "text-generation-router"
path = "src/main.rs"

[dependencies]
async-trait = "0.1.74"
async-stream = "0.3.5"
axum = { version = "0.7", features = ["json"] }
axum-tracing-opentelemetry = "0.16"
text-generation-router = { path = "../../router" }
clap = { version = "4.4.5", features = ["derive", "env"] }
grpc-metadata = { path = "../grpc-metadata" }
futures = "0.3.28"
hf-hub = { workspace = true }
jsonschema = { version = "0.17.1", features = ["draft202012"] }
metrics = { workspace = true }
metrics-exporter-prometheus = { workspace = true }
nohash-hasher = "0.2.0"
opentelemetry = { version = "0.20.0", features = ["rt-tokio"] }
opentelemetry-otlp = "0.13.0"
rand = "0.8.5"
reqwest = { version = "0.11.20", features = [] }
serde = "1.0.188"
serde_json = "1.0.107"
slotmap = "1.0.7"
thiserror = "1.0.48"
tokenizers = { workspace = true }
tokio = { version = "1.32.0", features = [
    "rt",
    "rt-multi-thread",
    "parking_lot",
    "signal",
    "sync",
] }
tokio-stream = "0.1.14"
tower-http = { version = "0.5.1", features = ["cors"] }
tracing = "0.1.37"
tracing-opentelemetry = "0.21.0"
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
utoipa = { version = "4.2.0", features = ["axum_extras"] }
utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] }
init-tracing-opentelemetry = { version = "0.14.1", features = [
    "opentelemetry-otlp",
] }
minijinja = { workspace = true }
minijinja-contrib = { workspace = true }
futures-util = "0.3.30"
regex = "1.10.3"
once_cell = "1.19.0"
image = "0.25.1"
base64 = { workspace = true }
prost = "^0.12"
tonic = "^0.10"
tower = "^0.4"

[build-dependencies]
tonic-build = "0.10.1"
prost-build = "0.12.1"

[features]
default = ["ngrok"]
ngrok = ["text-generation-router/ngrok"]
google = ["text-generation-router/google"]
kserve = ["text-generation-router/kserve"]
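
The manifest points [lib] at src/lib.rs and declares a text-generation-router binary at src/main.rs, but neither source file appears in this commit. A minimal sketch of what the binary target could look like, purely hypothetical since the real entry point is not shown here (it presumably wires up clap, the axum server, and the gRPC client from the dependencies above):

// src/main.rs, hypothetical sketch only; the actual file is not part of this diff.
// The [[bin]] target stays thin and reuses the library crate.
use clap::Parser;

#[derive(Parser)]
struct Args {
    // Illustrative flag; the real CLI surface is not in this excerpt.
    #[arg(long, default_value_t = 3000)]
    port: u16,
}

fn main() {
    let args = Args::parse();
    println!("text-generation-router-v2 would listen on port {}", args.port);
}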

backends/v2/build.rs (new file, 19 lines)
@@ -0,0 +1,19 @@
use std::fs;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("cargo:rerun-if-changed=../../proto/");

    fs::create_dir_all("src/client/pb").unwrap_or(());
    let mut config = prost_build::Config::new();
    config.protoc_arg("--experimental_allow_proto3_optional");

    tonic_build::configure()
        .build_client(true)
        .build_server(false)
        .out_dir("src/client/pb")
        .include_file("mod.rs")
        .compile_with_config(config, &["../../proto/generate.proto"], &["../../proto"])
        .unwrap_or_else(|e| panic!("protobuf compilation failed: {e}"));

    Ok(())
}
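
The build script compiles ../../proto/generate.proto with tonic-build and writes the client stubs plus a generated mod.rs into src/client/pb. Because of include_file("mod.rs"), the crate can pull the whole generated tree in with one plain module declaration instead of the usual OUT_DIR include_proto! indirection; that declaration is exactly what backends/v2/src/client/mod.rs does further down:

// How the generated tree is consumed (mirrors src/client/mod.rs below).
// tonic-build wrote src/client/pb/mod.rs, so a plain module declaration suffices.
#[allow(clippy::derive_partial_eq_without_eq)]
mod pb;

// Generated items then live under pb::generate::v3, for example:
// use pb::generate::v3::text_generation_service_client::TextGenerationServiceClient;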

backends/v2/src/backend.rs (new file, 514 lines)
@@ -0,0 +1,514 @@
use crate::client::{Batch, CachedBatch, ClientError, Generation, Health, ShardedClient};
/// Batching and inference logic
use crate::queue::{Entry, Queue};
use async_trait::async_trait;
use nohash_hasher::IntMap;
use std::sync::Arc;
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
use text_generation_router::validation::ValidGenerateRequest;
use text_generation_router::{Attention, FinishReason, PrefillToken, Token};
use tokio::sync::mpsc::error::SendError;
use tokio::sync::{mpsc, Notify};
use tokio::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{info_span, instrument, Instrument, Span};

pub struct BackendV3 {
    /// Request queue
    queue: Queue,
    /// Notify batcher on queue appends
    batching_task_notifier: Arc<Notify>,
    /// Client clone, used for health checks to skip the queue
    client: ShardedClient,
}

impl BackendV3 {
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn new(
        client: ShardedClient,
        waiting_served_ratio: f32,
        max_batch_prefill_tokens: u32,
        max_batch_total_tokens: u32,
        max_waiting_tokens: usize,
        max_batch_size: Option<usize>,
        requires_padding: bool,
        window_size: Option<u32>,
        speculate: u32,
    ) -> Self {
        let prefix_caching =
            std::env::var("USE_PREFIX_CACHING").expect("Expect prefix caching env var");
        let prefix_caching = matches!(prefix_caching.as_str(), "true" | "1");
        let attention: String = std::env::var("ATTENTION").expect("attention env var");

        let attention: Attention = attention
            .parse()
            .unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"));
        let block_size = attention.block_size();

        let queue = Queue::new(
            requires_padding,
            block_size,
            prefix_caching,
            window_size,
            speculate,
            max_batch_total_tokens,
        );
        let batching_task_notifier = Arc::new(Notify::new());

        // Spawn batching background task that contains all the inference logic
        tokio::spawn(batching_task(
            client.clone(),
            waiting_served_ratio,
            max_batch_prefill_tokens,
            max_batch_total_tokens,
            max_waiting_tokens,
            max_batch_size,
            queue.clone(),
            batching_task_notifier.clone(),
        ));

        Self {
            queue,
            batching_task_notifier,
            client,
        }
    }
}

#[async_trait]
impl Backend for BackendV3 {
    #[instrument(skip_all)]
    fn schedule(
        &self,
        request: ValidGenerateRequest,
    ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
        // MPSC channel to communicate with the background batching task
        let (response_tx, response_rx) = mpsc::unbounded_channel();

        // Append the request to the queue
        self.queue.append(Entry {
            request,
            response_tx,
            span: Span::current(),
            temp_span: None,
            queue_time: Instant::now(),
            batch_time: None,
            block_allocation: None,
        });

        // Notify the background task that we have a new entry in the queue that needs
        // to be batched
        self.batching_task_notifier.notify_one();

        // Return stream
        Ok(UnboundedReceiverStream::new(response_rx))
    }

    async fn health(&self, current_health: bool) -> bool {
        if current_health {
            // Generation is healthy, we only check that the shards can allocate on device
            self.client.device_health().await
        } else {
            self.client.model_health().await
        }
        .is_ok()
    }
}

/// Batching logic
/// Will be launched in a background Tokio task
///
/// Batches requests and sends them to the inference server
#[allow(clippy::too_many_arguments)]
pub(crate) async fn batching_task(
    mut client: ShardedClient,
    waiting_served_ratio: f32,
    max_batch_prefill_tokens: u32,
    max_batch_total_tokens: u32,
    max_waiting_tokens: usize,
    max_batch_size: Option<usize>,
    queue: Queue,
    notifier: Arc<Notify>,
) {
    // Infinite loop
    loop {
        // Wait for a notification from the Infer struct
        notifier.notified().await;

        // Get the next batch from the queue
        // This batch might be smaller than the maximum batch size if there are not enough requests
        // waiting in the queue
        while let Some((mut entries, batch, span)) = queue
            .next_batch(
                None,
                max_batch_size,
                max_batch_prefill_tokens,
                max_batch_total_tokens,
            )
            .await
        {
            let mut cached_batch = prefill(&mut client, batch, &mut entries)
                .instrument(span)
                .await;
            let mut waiting_tokens = 1;

            // We loop until we do not receive any cached batch from the inference server (== until
            // all requests have met their stopping criteria)
            while let Some(batch) = cached_batch {
                // Get current batch info
                let batch_size = batch.size;
                let batch_max_tokens = batch.max_tokens;
                let mut batches = vec![batch];
                metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
                metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);

                let min_size = if waiting_tokens >= max_waiting_tokens {
                    // If we didn't onboard any new requests since >= max_waiting_tokens, we try
                    // to add a new batch even though its size might be small
                    None
                } else {
                    // Minimum batch size
                    // TODO: temporarily disable to avoid incorrect deallocation +
                    // reallocation when using prefix caching.
                    Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
                };

                let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
                let max_size =
                    max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize));

                // Try to get a new batch
                if let Some((mut new_entries, new_batch, span)) = queue
                    .next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget)
                    .await
                {
                    // Tracking metrics
                    if min_size.is_some() {
                        metrics::counter!("tgi_batch_concat", "reason" => "backpressure")
                            .increment(1);
                    } else {
                        metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded")
                            .increment(1);
                    }

                    entries.iter_mut().for_each(|(_, entry)| {
                        // Create a new span to add the info that this entry is waiting
                        // because a new batch is being computed
                        let entry_waiting_span = info_span!(parent: &entry.span, "waiting");
                        // Add relationships
                        span.follows_from(&entry_waiting_span);
                        entry_waiting_span.follows_from(&span);
                        // Update entry
                        entry.temp_span = Some(entry_waiting_span);
                    });

                    // Generate one token for this new batch to have the attention past in cache
                    let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
                        .instrument(span)
                        .await;
                    // Reset waiting counter
                    waiting_tokens = 1;
                    // Extend current batch with the new batch
                    if let Some(new_cached_batch) = new_cached_batch {
                        entries.extend(new_entries);
                        batches.push(new_cached_batch);
                    }
                }

                // Create span for this batch to add context to inference calls
                let next_batch_size = entries.len();
                let next_batch_span =
                    info_span!(parent: None, "batch", batch_size = next_batch_size);
                entries.iter_mut().for_each(|(_, entry)| {
                    // Create a new span to link the batch back to this entry
                    let entry_batch_span = info_span!(parent: &entry.span, "infer");
                    // Add relationships
                    next_batch_span.follows_from(&entry_batch_span);
                    entry_batch_span.follows_from(&next_batch_span);
                    // Update entry
                    entry.temp_span = Some(entry_batch_span);
                });

                cached_batch = decode(&mut client, batches, &mut entries)
                    .instrument(next_batch_span)
                    .await;
                waiting_tokens += 1;
            }
            metrics::gauge!("tgi_batch_current_size").set(0.0);
            metrics::gauge!("tgi_batch_current_max_tokens").set(0.0);
        }
    }
}

#[instrument(skip_all)]
async fn prefill(
    client: &mut ShardedClient,
    batch: Batch,
    entries: &mut IntMap<u64, Entry>,
) -> Option<CachedBatch> {
    let start_time = Instant::now();
    let batch_id = batch.id;
    metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1);

    match client.prefill(batch).await {
        Ok((generations, next_batch, timings)) => {
            let start_filtering_time = Instant::now();
            // Send generated tokens and filter stopped entries
            filter_send_generations(generations, entries);

            // Filter next batch and remove requests that were stopped
            let next_batch = filter_batch(client, next_batch, entries).await;

            metrics::histogram!("tgi_batch_forward_duration", "method" => "prefill")
                .record(timings.forward.as_secs_f64());
            metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill")
                .record(timings.decode.as_secs_f64());
            metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill")
                .record(start_filtering_time.elapsed().as_secs_f64());
            metrics::histogram!("tgi_batch_inference_duration", "method" => "prefill")
                .record(start_time.elapsed().as_secs_f64());
            metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1);
            next_batch
        }
        // If we have an error, we discard the whole batch
        Err(err) => {
            let _ = client.clear_cache(Some(batch_id)).await;
            send_errors(err, entries);
            metrics::counter!("tgi_batch_inference_failure", "method" => "prefill").increment(1);
            None
        }
    }
}

#[instrument(skip_all)]
async fn decode(
    client: &mut ShardedClient,
    batches: Vec<CachedBatch>,
    entries: &mut IntMap<u64, Entry>,
) -> Option<CachedBatch> {
    let start_time = Instant::now();
    let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
    metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1);

    match client.decode(batches).await {
        Ok((generations, next_batch, timings)) => {
            let start_filtering_time = Instant::now();
            // Send generated tokens and filter stopped entries
            filter_send_generations(generations, entries);

            // Filter next batch and remove requests that were stopped
            let next_batch = filter_batch(client, next_batch, entries).await;

            if let Some(concat_duration) = timings.concat {
                metrics::histogram!("tgi_batch_concat_duration", "method" => "decode")
                    .record(concat_duration.as_secs_f64());
            }
            metrics::histogram!("tgi_batch_forward_duration", "method" => "decode")
                .record(timings.forward.as_secs_f64());
            metrics::histogram!("tgi_batch_decode_duration", "method" => "decode")
                .record(timings.decode.as_secs_f64());
            metrics::histogram!("tgi_batch_filter_duration", "method" => "decode")
                .record(start_filtering_time.elapsed().as_secs_f64());
            metrics::histogram!("tgi_batch_inference_duration", "method" => "decode")
                .record(start_time.elapsed().as_secs_f64());
            metrics::counter!("tgi_batch_inference_success", "method" => "decode").increment(1);
            next_batch
        }
        // If we have an error, we discard the whole batch
        Err(err) => {
            for id in batch_ids {
                let _ = client.clear_cache(Some(id)).await;
            }
            send_errors(err, entries);
            metrics::counter!("tgi_batch_inference_failure", "method" => "decode").increment(1);
            None
        }
    }
}

/// Filter a `batch` and remove all requests not present in `entries`
#[instrument(skip_all)]
async fn filter_batch(
    client: &mut ShardedClient,
    next_batch: Option<CachedBatch>,
    entries: &IntMap<u64, Entry>,
) -> Option<CachedBatch> {
    let mut batch = next_batch?;

    // No need to filter
    if batch.size as usize == entries.len() {
        return Some(batch);
    }

    let id = batch.id;

    // Retain only requests that are still in entries
    batch.request_ids.retain(|id| entries.contains_key(id));

    if batch.request_ids.is_empty() {
        // All requests have been filtered out
        // Next batch is now empty
        // Clear it from the Python shards cache
        // We unwrap here as we need to panic since we cannot recover if this method fails
        client.clear_cache(Some(id)).await.unwrap();
        None
    } else {
        // Filter Python shard cache
        // We unwrap here as we need to panic since we cannot recover if this method fails
        client.filter_batch(id, batch.request_ids).await.unwrap()
    }
}

/// Send one or multiple `InferStreamResponse` to Infer for all `entries`
/// and filter entries
#[instrument(skip_all)]
fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
    generations.into_iter().for_each(|generation| {
        let id = generation.request_id;
        // Get entry
        // We can `expect` here as the request id should always be in the entries
        let entry = entries
            .get(&id)
            .expect("ID not found in entries. This is a bug.");

        // Create and enter a span to link this function back to the entry
        let _span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered();
        // Send generation responses back to the infer task
        // If we receive an error from the channel, it means that the client dropped the
        // request and we need to stop generating, hence the unwrap_or(true)
        let stopped = send_responses(generation, entry)
            .inspect_err(|_err| {
                tracing::error!("Entry response channel error.");
                metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
            })
            .unwrap_or(true);
        if stopped {
            entries
                .remove(&id)
                .expect("ID not found in entries. This is a bug.");
        }
    });
}

/// Send responses through the `entry` response channel
fn send_responses(
    generation: Generation,
    entry: &Entry,
) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> {
    // Return directly if the channel is disconnected
    if entry.response_tx.is_closed() {
        metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
        return Ok(true);
    }

    let mut stopped = false;

    if let Some(prefill_tokens) = generation.prefill_tokens {
        // Create Token objects
        // We do that here instead of in the Python code as Rust for loops are faster
        let prefill_tokens = prefill_tokens
            .ids
            .into_iter()
            .zip(prefill_tokens.logprobs)
            .zip(prefill_tokens.texts)
            .map(|((id, logprob), text)| PrefillToken { id, text, logprob })
            .collect();

        // Send message
        entry
            .response_tx
            .send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?;
    }

    // Create last Token
    let tokens_ = generation.tokens.expect("Non empty tokens in generation");
    let n = tokens_.ids.len();
    metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64);
    let mut iterator = tokens_
        .ids
        .into_iter()
        .zip(tokens_.logprobs)
        .zip(tokens_.texts)
        .zip(tokens_.is_special)
        .enumerate()
        .peekable();
    while let Some((i, (((id, logprob), text), special))) = iterator.next() {
        let token = Token {
            id,
            text,
            logprob,
            special,
        };
        let top_tokens = if let Some(top_tokens_) = generation.top_tokens.get(i) {
            top_tokens_
                .ids
                .iter()
                .zip(top_tokens_.logprobs.iter())
                .zip(top_tokens_.texts.iter())
                .zip(top_tokens_.is_special.iter())
                .map(|(((&id, &logprob), text), &special)| Token {
                    id,
                    text: text.to_string(),
                    logprob,
                    special,
                })
                .collect()
        } else {
            vec![]
        };
        match (&generation.generated_text, iterator.peek()) {
            (Some(generated_text), None) => {
                // Generation has ended
                stopped = true;
                // Send message
                entry.response_tx.send(Ok(InferStreamResponse::End {
                    token,
                    top_tokens,
                    generated_text: GeneratedText::from(generated_text.clone()),
                    queued: entry.queue_time,
                    start: entry.batch_time.unwrap(),
                }))?;
            }
            _ => {
                // Send message
                entry
                    .response_tx
                    .send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?;
            }
        }
    }

    Ok(stopped)
}

/// Send errors to Infer for all `entries`
#[instrument(skip_all)]
fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
    entries.drain().for_each(|(_, entry)| {
        // Create and enter a span to link this function back to the entry
        let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
        let err = InferError::GenerationError(error.to_string());
        metrics::counter!("tgi_request_failure", "err" => "generation").increment(1);
        tracing::error!("{err}");

        // unwrap_or is valid here as we don't care if the receiver is gone.
        entry
            .response_tx
            .send(Err(err))
            .unwrap_or(());
    });
}

impl From<crate::client::GeneratedText> for GeneratedText {
    fn from(value: crate::client::GeneratedText) -> Self {
        let v3_finish_reason = crate::client::FinishReason::try_from(value.finish_reason).unwrap();
        let finish_reason = match v3_finish_reason {
            crate::client::FinishReason::Length => FinishReason::Length,
            crate::client::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
            crate::client::FinishReason::StopSequence => FinishReason::StopSequence,
        };

        Self {
            text: value.text,
            generated_tokens: value.generated_tokens,
            finish_reason,
            seed: value.seed,
        }
    }
}
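
schedule returns immediately with an unbounded receiver stream; the background batching task later pushes Prefill, Intermediate, and End messages into it. A minimal consumer sketch, assuming only the types used in this file (the real router layers validation and HTTP/SSE streaming on top of this):

// Hypothetical consumer of BackendV3::schedule, for illustration only.
use text_generation_router::infer::{Backend, InferError, InferStreamResponse};
use text_generation_router::validation::ValidGenerateRequest;
use tokio_stream::StreamExt;

async fn consume(backend: &BackendV3, request: ValidGenerateRequest) -> Result<(), InferError> {
    let mut stream = backend.schedule(request)?;
    while let Some(response) = stream.next().await {
        match response? {
            InferStreamResponse::Prefill(_) => { /* prompt tokens and logprobs */ }
            InferStreamResponse::Intermediate { token, .. } => {
                tracing::info!("token: {}", token.text);
            }
            InferStreamResponse::End { generated_text, .. } => {
                tracing::info!("finished: {}", generated_text.text);
            }
        }
    }
    Ok(())
}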

backends/v2/src/block_allocator.rs (new file, 209 lines)
@@ -0,0 +1,209 @@
use std::sync::Arc;
use tokio::sync::{mpsc, oneshot};

use crate::radix::RadixAllocator;

#[derive(Debug, Clone)]
pub struct BlockAllocation {
    pub allocation_id: u64,
    pub blocks: Vec<u32>,
    pub slots: Vec<u32>,

    /// Prefix that was cached and for which the KV does not have to
    /// be recomputed.
    pub prefix_len: u32,

    pub(crate) block_allocator: Option<BlockAllocator>,
}

impl Drop for BlockAllocation {
    fn drop(&mut self) {
        if let Some(block_allocator) = self.block_allocator.as_mut() {
            block_allocator.free(self.blocks.clone(), self.allocation_id)
        }
    }
}

#[derive(Debug, Clone)]
pub struct BlockAllocator {
    /// Channel to communicate with the background task
    block_allocator: mpsc::UnboundedSender<BlockAllocatorCommand>,
}

impl BlockAllocator {
    pub(crate) fn new(
        max_batch_total_tokens: u32,
        block_size: u32,
        prefix_caching: bool,
        window_size: Option<u32>,
    ) -> Self {
        // Create channel
        let (sender, receiver) = mpsc::unbounded_channel();

        // Launch background queue task
        tokio::spawn(block_allocator_task(
            max_batch_total_tokens / block_size,
            block_size,
            prefix_caching,
            window_size,
            receiver,
        ));

        Self {
            block_allocator: sender,
        }
    }

    pub(crate) async fn allocate(
        &self,
        tokens: u32,
        prefill_tokens: Option<Arc<Vec<u32>>>,
    ) -> Option<BlockAllocation> {
        let (response_sender, response_receiver) = oneshot::channel();
        self.block_allocator
            .send(BlockAllocatorCommand::Allocate {
                tokens,
                prefill_tokens,
                response_sender,
            })
            .unwrap();

        response_receiver.await.unwrap().map(|mut allocation| {
            allocation.block_allocator = Some(self.clone());
            allocation
        })
    }

    pub(crate) fn free(&self, blocks: Vec<u32>, allocation_id: u64) {
        self.block_allocator
            .send(BlockAllocatorCommand::Free {
                allocation_id,
                blocks,
            })
            .unwrap();
    }
}

async fn block_allocator_task(
    blocks: u32,
    block_size: u32,
    prefix_caching: bool,
    window_size: Option<u32>,
    mut receiver: mpsc::UnboundedReceiver<BlockAllocatorCommand>,
) {
    let mut allocator: Box<dyn Allocator + Send> = if prefix_caching {
        Box::new(RadixAllocator::new(block_size, blocks, window_size))
    } else {
        Box::new(SimpleAllocator::new(blocks, block_size, window_size))
    };
    while let Some(cmd) = receiver.recv().await {
        match cmd {
            BlockAllocatorCommand::Free {
                blocks,
                allocation_id,
            } => allocator.free(blocks, allocation_id),
            BlockAllocatorCommand::Allocate {
                tokens,
                prefill_tokens,
                response_sender,
            } => {
                response_sender
                    .send(allocator.allocate(tokens, prefill_tokens))
                    .unwrap();
            }
        }
    }
}

#[derive(Debug)]
enum BlockAllocatorCommand {
    Free {
        blocks: Vec<u32>,
        allocation_id: u64,
    },
    Allocate {
        tokens: u32,
        prefill_tokens: Option<Arc<Vec<u32>>>,
        response_sender: oneshot::Sender<Option<BlockAllocation>>,
    },
}

pub trait Allocator {
    fn allocate(
        &mut self,
        tokens: u32,
        prefill_tokens: Option<Arc<Vec<u32>>>,
    ) -> Option<BlockAllocation>;

    fn free(&mut self, blocks: Vec<u32>, allocation_id: u64);
}

pub struct SimpleAllocator {
    free_blocks: Vec<u32>,
    block_size: u32,
    window_size: Option<u32>,
}

impl SimpleAllocator {
    fn new(blocks: u32, block_size: u32, window_size: Option<u32>) -> Self {
        SimpleAllocator {
            block_size,
            // Block 0 is reserved for health checks
            free_blocks: (1..blocks).collect(),
            window_size,
        }
    }
}

impl Allocator for SimpleAllocator {
    fn allocate(
        &mut self,
        tokens: u32,
        _prefill_tokens: Option<Arc<Vec<u32>>>,
    ) -> Option<BlockAllocation> {
        // Apply window size
        let (required_blocks, repeats) = {
            let (tokens, repeats) = match self.window_size {
                None => (tokens, 1),
                Some(window_size) => {
                    let repeats = (tokens + window_size - 1) / window_size;
                    let tokens = core::cmp::min(tokens, window_size);
                    (tokens, repeats as usize)
                }
            };
            // Pad to a multiple of block size
            let required_blocks = (tokens + self.block_size - 1) / self.block_size;
            (required_blocks, repeats)
        };

        let tokens = tokens as usize;
        if required_blocks > self.free_blocks.len() as u32 {
            None
        } else {
            let blocks = self
                .free_blocks
                .split_off(self.free_blocks.len() - required_blocks as usize);
            let mut slots =
                Vec::with_capacity((required_blocks * self.block_size * repeats as u32) as usize);

            'slots: for block_id in blocks.repeat(repeats).iter() {
                for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) {
                    slots.push(s);
                    if slots.len() == tokens {
                        break 'slots;
                    }
                }
            }
            Some(BlockAllocation {
                allocation_id: 0,
                blocks,
                slots,
                prefix_len: 0,
                block_allocator: None,
            })
        }
    }

    fn free(&mut self, blocks: Vec<u32>, _allocation_id: u64) {
        self.free_blocks.extend(blocks)
    }
}
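
SimpleAllocator's sizing is plain ceiling division: with block_size = 16, a 30-token request needs ceil(30 / 16) = 2 blocks and uses the first 30 of their 32 slots; with a sliding window of 16 it needs ceil(16 / 16) = 1 block, repeated ceil(30 / 16) = 2 times so the slot list still covers all 30 positions. A standalone sketch of that arithmetic (illustration, not library code):

// Mirrors SimpleAllocator::allocate's sizing logic in isolation.
fn required_blocks(tokens: u32, block_size: u32, window_size: Option<u32>) -> (u32, usize) {
    let (tokens, repeats) = match window_size {
        None => (tokens, 1),
        Some(window_size) => {
            // Repeat the window's blocks enough times to cover every position.
            let repeats = (tokens + window_size - 1) / window_size;
            (tokens.min(window_size), repeats as usize)
        }
    };
    // Pad the token count up to a whole number of blocks.
    ((tokens + block_size - 1) / block_size, repeats)
}

fn main() {
    assert_eq!(required_blocks(30, 16, None), (2, 1));
    assert_eq!(required_blocks(30, 16, Some(16)), (1, 2));
}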

backends/v2/src/client/grpc_client.rs (new file, 286 lines)
@@ -0,0 +1,286 @@
/// Single shard Client
use crate::client::{pb, Chunk};
use crate::client::{ClientError, Result, WARMUP_IMAGE_BASE64};
use base64::engine::general_purpose::STANDARD;
use base64::Engine;
use grpc_metadata::InjectTelemetryContext;
use pb::generate::v3::text_generation_service_client::TextGenerationServiceClient;
use pb::generate::v3::*;
use std::cmp::min;
use std::time::Duration;
use tonic::transport::{Channel, Uri};
use tracing::instrument;

/// Text Generation Inference gRPC client
#[derive(Debug, Clone)]
pub struct Client {
    stub: TextGenerationServiceClient<Channel>,
}

impl Client {
    /// Returns a client connected to the given url
    #[allow(dead_code)]
    pub async fn connect(uri: Uri) -> Result<Self> {
        let channel = Channel::builder(uri).connect().await?;

        Ok(Self {
            stub: TextGenerationServiceClient::new(channel),
        })
    }

    /// Returns a client connected to the given unix socket
    pub async fn connect_uds(path: String) -> Result<Self> {
        let channel = Channel::from_shared("http://[::]:50051".to_string())
            .unwrap()
            .connect_with_connector(tower::service_fn(move |_: Uri| {
                tokio::net::UnixStream::connect(path.clone())
            }))
            .await?;

        Ok(Self {
            stub: TextGenerationServiceClient::new(channel),
        })
    }

    /// Returns a list of uris or unix sockets of all shards
    #[instrument(skip(self))]
    pub async fn service_discovery(&mut self) -> Result<Vec<String>> {
        let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context();
        let response = self.stub.service_discovery(request).await.map_err(|_| {
            ClientError::Connection("Server does not support v3 interface".to_string())
        })?;
        let urls = response
            .into_inner()
            .urls
            .into_iter()
            // Remove unix socket prefix
            .map(|url| match url.strip_prefix("unix://") {
                None => url,
                Some(stripped_url) => stripped_url.to_string(),
            })
            .collect();
        Ok(urls)
    }

    /// Get model info
    #[instrument(skip(self))]
    pub async fn info(&mut self) -> Result<InfoResponse> {
        let request = tonic::Request::new(InfoRequest {}).inject_context();
        let response = self.stub.info(request).await?.into_inner();
        Ok(response)
    }

    /// Get model health
    #[instrument(skip(self))]
    pub async fn health(&mut self) -> Result<HealthResponse> {
        let request = tonic::Request::new(HealthRequest {}).inject_context();
        let response = self.stub.health(request).await?.into_inner();
        Ok(response)
    }

    /// Clear the past generations cache
    #[instrument(skip(self))]
    pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
        let request = tonic::Request::new(ClearCacheRequest { id: batch_id }).inject_context();
        self.stub.clear_cache(request).await?;
        Ok(())
    }

    /// Filter a cached batch
    #[instrument(skip(self))]
    pub async fn filter_batch(
        &mut self,
        batch_id: u64,
        request_ids: Vec<u64>,
    ) -> Result<Option<CachedBatch>> {
        let request = tonic::Request::new(FilterBatchRequest {
            batch_id,
            request_ids,
        })
        .inject_context();
        let filtered_batch = self.stub.filter_batch(request).await?.into_inner();
        Ok(filtered_batch.batch)
    }

    /// Warmup on a max size batch
    ///
    /// Returns the maximum amount of tokens supported by the hardware
    #[instrument(skip_all)]
    pub async fn warmup(
        &mut self,
        max_input_length: u32,
        max_prefill_tokens: u32,
        max_total_tokens: u32,
        max_batch_size: Option<usize>,
    ) -> Result<Option<u32>> {
        let mut n_tokens = 0;
        let mut requests = Vec::new();
        // Create requests
        while n_tokens < max_prefill_tokens {
            let truncate = min(max_input_length, max_prefill_tokens - n_tokens);

            let mut input_chunks = Vec::new();
            input_chunks
                .push(Chunk::Text("_test ".to_string().repeat(max_input_length as usize)).into());
            if n_tokens == 0 {
                input_chunks.push(
                    Chunk::Image(Image {
                        // Safe unwrap, because we control the data.
                        data: STANDARD.decode(WARMUP_IMAGE_BASE64).unwrap(),
                        mimetype: "image/jpeg;base64".to_string(),
                    })
                    .into(),
                );
            }

            // Send stringly-typed inputs for compatibility for backends that haven't
            // been updated to support chunks.

            let mut inputs = String::new();
            inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
            if n_tokens == 0 {
                // 1 request is enough to test vision heads.
                // Sending images on other queries messes up easily with truncation.
                inputs.push_str(&format!(
                    "![](data:image/jpeg;base64,{WARMUP_IMAGE_BASE64})",
                ));
            }

            requests.push(Request {
                id: 0,
                inputs,
                add_special_tokens: true,
                input_chunks: Some(Input {
                    chunks: input_chunks,
                }),
                // We truncate the input on the server side to be sure that it has the correct size
                truncate,
                // Blocks and slots will be set on the server side if we use paged attention
                blocks: vec![],
                slots: vec![],
                prefix_len: 0,
                // Set sampling parameters to also take these ops into account in the max memory
                parameters: Some(NextTokenChooserParameters {
                    temperature: 0.9,
                    top_k: 10,
                    top_p: 0.9,
                    typical_p: 0.9,
                    do_sample: false,
                    seed: 0,
                    repetition_penalty: 1.2,
                    frequency_penalty: 0.1,
                    watermark: true,
                    grammar: String::new(),
                    grammar_type: GrammarType::None as i32,
                }),
                stopping_parameters: Some(StoppingCriteriaParameters {
                    max_new_tokens: max_total_tokens - truncate,
                    stop_sequences: vec![],
                    ignore_eos_token: true,
                }),
                prefill_logprobs: true,
                top_n_tokens: 20,
                adapter_id: None,
            });
            n_tokens += max_input_length;

            // Check max_batch_size
            if Some(requests.len()) == max_batch_size {
                break;
            }
        }

        let batch = Batch {
            id: 0,
            size: requests.len() as u32,
            requests,
            max_tokens: max_input_length,
            max_blocks: 0,
        };

        let request = tonic::Request::new(WarmupRequest {
            batch: Some(batch),
            max_input_length,
            max_prefill_tokens,
            max_total_tokens,
        })
        .inject_context();
        let response = self.stub.warmup(request).await?.into_inner();
        Ok(response.max_supported_total_tokens)
    }

    /// Generate one token for each request in the given batch
    ///
    /// Returns Generation for each request in batch
    /// and the next cached batch
    #[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
    pub async fn prefill(
        &mut self,
        batch: Batch,
    ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
        let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
        let response = self.stub.prefill(request).await?.into_inner();
        Ok((
            response.generations,
            response.batch,
            PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns),
        ))
    }

    /// Generate one token for each request in the given cached batches
    ///
    /// Returns Generation for each request in batches
    /// and the next cached batch
    #[instrument(skip_all, fields(size = batches.iter().map(|batch| {batch.size}).sum::<u32>()))]
    pub async fn decode(
        &mut self,
        batches: Vec<CachedBatch>,
    ) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
        let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
        let response = self.stub.decode(request).await?.into_inner();
        Ok((
            response.generations,
            response.batch,
            DecodeTimings::new(
                response.concat_ns,
                response.forward_ns,
                response.decode_ns,
                response.total_ns,
            ),
        ))
    }
}

pub struct PrefillTimings {
    pub forward: Duration,
    pub decode: Duration,
    pub total: Duration,
}

impl PrefillTimings {
    fn new(forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
        Self {
            forward: Duration::from_nanos(forward_ns),
            decode: Duration::from_nanos(decode_ns),
            total: Duration::from_nanos(total_ns),
        }
    }
}

pub struct DecodeTimings {
    pub concat: Option<Duration>,
    pub forward: Duration,
    pub decode: Duration,
    pub total: Duration,
}

impl DecodeTimings {
    fn new(concat_ns: Option<u64>, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
        Self {
            concat: concat_ns.map(Duration::from_nanos),
            forward: Duration::from_nanos(forward_ns),
            decode: Duration::from_nanos(decode_ns),
            total: Duration::from_nanos(total_ns),
        }
    }
}
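
Typical single-shard usage chains these calls: connect over a unix socket, probe health, then warm up to discover the supported token budget. A sketch assuming a shard listening at an illustrative path (the path and token sizes are made up):

// Hypothetical driver for Client; path and sizes are illustrative only.
async fn check_shard() -> Result<()> {
    let mut client = Client::connect_uds("/tmp/text-generation-server-0".to_string()).await?;
    // Cheap probe: a single gRPC round trip.
    client.health().await?;
    // warmup(max_input_length, max_prefill_tokens, max_total_tokens, max_batch_size)
    let max_supported = client.warmup(1024, 4096, 2048, None).await?;
    tracing::info!("max supported total tokens: {max_supported:?}");
    Ok(())
}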

backends/v2/src/client/mod.rs (new file, 76 lines)
@@ -0,0 +1,76 @@
//! Text Generation gRPC client library

use async_trait::async_trait;
use thiserror::Error;
use tonic::transport;
use tonic::Status;

#[allow(clippy::derive_partial_eq_without_eq)]
mod pb;

mod grpc_client;
mod sharded_client;

pub use grpc_client::Client;
pub use pb::generate::v3::{
    input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
    HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request,
    StoppingCriteriaParameters,
};
pub use sharded_client::ShardedClient;

#[async_trait]
pub trait Health {
    /// Check if a generate server is healthy by asking it to allocate a tensor on device
    async fn device_health(&self) -> Result<()>;

    /// Check if a generate server is healthy by doing a forward pass.
    /// EXPENSIVE
    async fn model_health(&self) -> Result<()>;
}

#[derive(Debug)]
pub struct ShardInfo {
    pub requires_padding: bool,
    pub dtype: String,
    pub device_type: String,
    pub window_size: Option<u32>,
    pub speculate: u32,
}

#[derive(Error, Debug, Clone)]
pub enum ClientError {
    #[error("Could not connect to Text Generation server: {0}")]
    Connection(String),
    #[error("Server error: {0}")]
    Generation(String),
    #[error("Sharded results are empty")]
    EmptyResults,
}

impl From<Status> for ClientError {
    fn from(err: Status) -> Self {
        let err = Self::Generation(err.message().to_string());
        tracing::error!("{err}");
        err
    }
}

impl From<transport::Error> for ClientError {
    fn from(err: transport::Error) -> Self {
        let err = Self::Connection(err.to_string());
        tracing::error!("{err}");
        err
    }
}

// Small convenience re-wrapping of `Chunk`.
impl From<Chunk> for InputChunk {
    fn from(chunk: Chunk) -> Self {
        InputChunk { chunk: Some(chunk) }
    }
}

static WARMUP_IMAGE_BASE64: &str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=";

pub type Result<T> = std::result::Result<T, ClientError>;
262
backends/v2/src/client/sharded_client.rs
Normal file
262
backends/v2/src/client/sharded_client.rs
Normal file
@ -0,0 +1,262 @@
|
|||||||
|
use crate::client::{ClientError, Result};
|
||||||
|
/// Multi shard Client
|
||||||
|
use crate::client::{Health, ShardInfo};
|
||||||
|
|
||||||
|
use crate::client::grpc_client::{DecodeTimings, PrefillTimings};
|
||||||
|
use crate::client::{
|
||||||
|
Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse,
|
||||||
|
NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
||||||
|
};
|
||||||
|
use crate::client::{Chunk, InfoResponse, Input};
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use futures::future::join_all;
|
||||||
|
use tonic::transport::Uri;
|
||||||
|
use tracing::instrument;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
/// Text Generation Inference gRPC multi client
|
||||||
|
pub struct ShardedClient {
|
||||||
|
clients: Vec<Client>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ShardedClient {
|
||||||
|
fn new(clients: Vec<Client>) -> Self {
|
||||||
|
Self { clients }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a new ShardedClient from a master client. The master client will communicate with
|
||||||
|
/// the other shards and returns all uris/unix sockets with the `service_discovery` gRPC method.
|
||||||
|
async fn from_master_client(mut master_client: Client) -> Result<Self> {
|
||||||
|
// Get all uris/unix sockets from the master client
|
||||||
|
let uris = master_client.service_discovery().await?;
|
||||||
|
let futures = uris.into_iter().map(Client::connect_uds);
|
||||||
|
let clients: Result<Vec<Client>> = join_all(futures).await.into_iter().collect();
|
||||||
|
Ok(Self::new(clients?))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a client connected to the given uri
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub async fn connect(uri: Uri) -> Result<Self> {
|
||||||
|
let master_client = Client::connect(uri).await?;
|
||||||
|
Self::from_master_client(master_client).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a client connected to the given unix socket
|
||||||
|
pub async fn connect_uds(path: String) -> Result<Self> {
|
||||||
|
let master_client = Client::connect_uds(path).await?;
|
||||||
|
Self::from_master_client(master_client).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the model info
|
||||||
|
#[instrument(skip(self))]
|
||||||
|
pub async fn info(&mut self) -> Result<ShardInfo> {
|
||||||
|
let futures: Vec<_> = self
|
||||||
|
.clients
|
||||||
|
.iter_mut()
|
||||||
|
.map(|client| client.info())
|
||||||
|
.collect();
|
||||||
|
join_all(futures).await.pop().unwrap().map(ShardInfo::from)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// GRPC health check
|
||||||
|
#[instrument(skip(self))]
|
||||||
|
pub async fn health(&mut self) -> Result<HealthResponse> {
|
||||||
|
let futures: Vec<_> = self
|
||||||
|
.clients
|
||||||
|
.iter_mut()
|
||||||
|
.map(|client| client.health())
|
||||||
|
.collect();
|
||||||
|
join_all(futures).await.pop().unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Clear the past generations cache
|
||||||
|
#[instrument(skip(self))]
|
||||||
|
pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
|
||||||
|
let futures: Vec<_> = self
|
||||||
|
.clients
|
||||||
|
.iter_mut()
|
||||||
|
.map(|client| client.clear_cache(batch_id))
|
||||||
|
.collect();
|
||||||
|
join_all(futures).await.into_iter().collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Filter a cached batch
|
||||||
|
#[instrument(skip(self))]
|
||||||
|
pub async fn filter_batch(
|
||||||
|
&mut self,
|
||||||
|
batch_id: u64,
|
||||||
|
request_ids: Vec<u64>,
|
||||||
|
) -> Result<Option<CachedBatch>> {
|
||||||
|
let futures: Vec<_> = self
|
||||||
|
.clients
|
||||||
|
.iter_mut()
|
||||||
|
.map(|client| Box::pin(client.filter_batch(batch_id, request_ids.clone())))
|
||||||
|
.collect();
|
||||||
|
// all shards return the same message
|
||||||
|
join_all(futures).await.pop().unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Warmup on a max size batch
|
||||||
|
///
|
||||||
|
/// Returns the maximum amount of tokens supported by the hardware
|
||||||
|
#[instrument(skip(self))]
|
||||||
|
pub async fn warmup(
|
||||||
|
&mut self,
|
||||||
|
max_input_length: u32,
|
||||||
|
max_prefill_tokens: u32,
|
||||||
|
max_total_tokens: u32,
|
||||||
|
max_batch_size: Option<usize>,
|
||||||
|
) -> Result<Option<u32>> {
|
||||||
|
let futures: Vec<_> = self
|
||||||
|
.clients
|
||||||
|
.iter_mut()
|
||||||
|
.map(|client| {
|
||||||
|
Box::pin(client.warmup(
|
||||||
|
max_input_length,
|
||||||
|
max_prefill_tokens,
|
||||||
|
max_total_tokens,
|
||||||
|
max_batch_size,
|
||||||
|
))
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
// Take the minimum value
|
||||||
|
let results = join_all(futures)
|
||||||
|
.await
|
||||||
|
.into_iter()
|
||||||
|
.collect::<Result<Vec<Option<u32>>>>()?;
|
||||||
|
Ok(results.into_iter().flatten().min())
|
||||||
|
}

    /// Generate one token for each request in the given batch
    ///
    /// Returns Generation for each request in batch
    /// and the next cached batch
    #[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
    pub async fn prefill(
        &mut self,
        batch: Batch,
    ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
        let futures: Vec<_> = self
            .clients
            .iter_mut()
            .map(|client| Box::pin(client.prefill(batch.clone())))
            .collect();
        #[allow(clippy::type_complexity)]
        let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
            join_all(futures).await.into_iter().collect();
        let mut results = results?;

        let (mut generations, next_batch, mut timings) =
            results.pop().ok_or(ClientError::EmptyResults)?;

        // Merge generations from different model shards
        for (mut shard_generations, _, shard_timings) in results.into_iter() {
            generations.append(&mut shard_generations);
            // Return the timings of the slowest shard
            if shard_timings.total > timings.total {
                timings = shard_timings;
            }
        }
        Ok((generations, next_batch, timings))
    }

    /// Generate one token for each request in the given cached batches
    ///
    /// Returns Generation for each request in batches
    /// and the next cached batch
    #[instrument(skip_all, fields(size = batches.iter().map(|batch| batch.size).sum::<u32>()))]
    pub async fn decode(
        &mut self,
        batches: Vec<CachedBatch>,
    ) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
        let futures: Vec<_> = self
            .clients
            .iter_mut()
            .map(|client| Box::pin(client.decode(batches.clone())))
            .collect();
        #[allow(clippy::type_complexity)]
        let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)>> =
            join_all(futures).await.into_iter().collect();
        let mut results = results?;

        let (mut generations, next_batch, mut timings) =
            results.pop().ok_or(ClientError::EmptyResults)?;

        // Merge generations from different model shards
        for (mut shard_generations, _, shard_timings) in results.into_iter() {
            generations.append(&mut shard_generations);
            // Return the timings of the slowest shard
            if shard_timings.total > timings.total {
                timings = shard_timings;
            }
        }
        Ok((generations, next_batch, timings))
    }
}

impl From<InfoResponse> for ShardInfo {
    fn from(value: InfoResponse) -> Self {
        Self {
            requires_padding: value.requires_padding,
            dtype: value.dtype,
            device_type: value.device_type,
            window_size: value.window_size,
            speculate: value.speculate,
        }
    }
}

#[async_trait]
impl Health for ShardedClient {
    async fn device_health(&self) -> Result<()> {
        self.clone().health().await?;
        Ok(())
    }

    async fn model_health(&self) -> Result<()> {
        // Dummy batch of 1 token and 1 generated token
        let liveness_request = Request {
            id: u64::MAX,
            inputs: "liveness".to_string(),
            input_chunks: Some(Input {
                chunks: vec![Chunk::Text("liveness".into()).into()],
            }),
            truncate: 10,
            add_special_tokens: true,
            prefill_logprobs: false,
            parameters: Some(NextTokenChooserParameters {
                temperature: 1.0,
                top_k: 0,
                top_p: 1.0,
                typical_p: 1.0,
                do_sample: false,
                seed: 0,
                repetition_penalty: 1.0,
                frequency_penalty: 0.0,
                watermark: false,
                grammar: String::new(),
                grammar_type: GrammarType::None as i32,
            }),
            stopping_parameters: Some(StoppingCriteriaParameters {
                max_new_tokens: 1,
                stop_sequences: vec![],
                ignore_eos_token: false,
            }),
            top_n_tokens: 0,
            // Block 0 is reserved for health checks
            blocks: vec![0],
            slots: (0..16).collect(),
            prefix_len: 0,
            adapter_id: None,
        };
        let batch = Batch {
            id: u64::MAX,
            requests: vec![liveness_request],
            size: 1,
            max_tokens: 2,
            max_blocks: 1,
        };
        self.clone().prefill(batch).await?;
        Ok(())
    }
}
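
// A hedged usage sketch (illustrative only, not part of the diff); it assumes
// a `uds_path` variable and an async context with `?` error propagation:
//
//     let mut client = ShardedClient::connect_uds(uds_path).await?;
//     client.clear_cache(None).await?; // reset any stale KV-cache state
//     let info = client.info().await?; // ShardInfo converted from shard 0's InfoResponse
//     client.health().await?;          // gRPC liveness across all shards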
143 backends/v2/src/lib.rs Normal file
@@ -0,0 +1,143 @@
mod backend;
pub mod block_allocator;
mod client;
mod queue;
pub mod radix;

use crate::client::{ClientError, ShardedClient};
pub(crate) use backend::BackendV3;
use serde::Serialize;
use thiserror::Error;
use utoipa::ToSchema;

#[derive(Clone, Debug, Serialize, ToSchema)]
pub struct BackendInfo {
    /// Mandatory
    #[schema(example = "cuda")]
    pub model_device_type: String,
    #[schema(example = "torch.float16")]
    pub model_dtype: String,

    /// Backend parameters
    #[schema(example = "1")]
    pub speculate: usize,
    #[schema(example = "1.2")]
    pub waiting_served_ratio: f32,
    #[schema(example = "32000")]
    pub max_batch_total_tokens: u32,
    #[schema(example = "20")]
    pub max_waiting_tokens: usize,
    #[schema(nullable = true, example = "null")]
    pub max_batch_size: Option<usize>,
}

#[allow(clippy::too_many_arguments)]
pub async fn connect_backend(
    max_input_tokens: usize,
    max_total_tokens: usize,
    master_shard_uds_path: String,
    waiting_served_ratio: f32,
    max_batch_prefill_tokens: u32,
    max_batch_total_tokens: Option<u32>,
    max_waiting_tokens: usize,
    max_batch_size: Option<usize>,
) -> Result<(BackendV3, BackendInfo), V3Error> {
    // Helper function
    let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option<u32>| {
        match max_supported_batch_total_tokens {
            // Older models do not support automatic max-batch-total-tokens
            None => {
                let max_batch_total_tokens = max_batch_total_tokens
                    .unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)));
                tracing::warn!("Model does not support automatic max batch total tokens");
                Ok(max_batch_total_tokens)
            }
            // Flash attention models return their max supported total tokens
            Some(max_supported_batch_total_tokens) => {
                // Warn if the user supplied their own max-batch-total-tokens, as we will ignore it
                if max_batch_total_tokens.is_some() {
                    tracing::warn!(
                        "`--max-batch-total-tokens` is deprecated for Flash \
                        Attention models."
                    );
                    tracing::warn!(
                        "Inferred max batch total tokens: {max_supported_batch_total_tokens}"
                    );
                }
                if max_total_tokens as u32 > max_supported_batch_total_tokens {
                    return Err(V3Error::NotEnoughMemory(max_total_tokens));
                }

                Ok(max_supported_batch_total_tokens)
            }
        }
    };
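    // Worked example (illustrative): with max_total_tokens = 2048,
    // max_batch_prefill_tokens = 4096 and no user override, the `None` arm
    // above yields 16000.max(2048.max(4096)) = 16000.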

    let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
        .await
        .map_err(V3Error::Connection)?;

    // Server is running on v3
    // Clear the cache; useful if the webserver rebooted
    sharded_client
        .clear_cache(None)
        .await
        .map_err(V3Error::Cache)?;
    // Get info from the shard
    let shard_info = sharded_client.info().await.map_err(V3Error::Info)?;

    // Warmup model
    tracing::info!("Warming up model");
    let max_batch_total_tokens = check_max_batch_total_tokens(
        sharded_client
            .warmup(
                max_input_tokens as u32,
                max_batch_prefill_tokens,
                max_total_tokens as u32,
                max_batch_size,
            )
            .await
            .map_err(V3Error::Warmup)?,
    )?;
    tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}");

    let backend_info = BackendInfo {
        waiting_served_ratio,
        max_batch_total_tokens,
        max_waiting_tokens,
        max_batch_size,
        model_device_type: shard_info.device_type.clone(),
        model_dtype: shard_info.dtype.clone(),
        speculate: shard_info.speculate as usize,
    };

    let backend = BackendV3::new(
        sharded_client,
        waiting_served_ratio,
        max_batch_prefill_tokens,
        max_batch_total_tokens,
        max_waiting_tokens,
        max_batch_size,
        shard_info.requires_padding,
        shard_info.window_size,
        shard_info.speculate,
    );

    tracing::info!("Using backend V3");

    Ok((backend, backend_info))
}
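
// A hedged call sketch (values mirror the CLI defaults in main.rs below; not
// a prescribed configuration):
//
//     let (backend, info) = connect_backend(
//         1024,                                        // max_input_tokens
//         2048,                                        // max_total_tokens
//         "/tmp/text-generation-server-0".to_string(), // master_shard_uds_path
//         1.2,                                         // waiting_served_ratio
//         4096,                                        // max_batch_prefill_tokens
//         None,                                        // max_batch_total_tokens
//         20,                                          // max_waiting_tokens
//         None,                                        // max_batch_size
//     )
//     .await?;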

#[derive(Debug, Error)]
pub enum V3Error {
    #[error("Unable to clear the Python model shards cache: {0}")]
    Cache(ClientError),
    #[error("Unable to connect to the Python model shards: {0}")]
    Connection(ClientError),
    #[error("Unable to get the Python model shards info: {0}")]
    Info(ClientError),
    #[error("Unable to warmup the Python model shards: {0}")]
    Warmup(ClientError),
    #[error("Not enough memory to handle `max_total_tokens={0}`")]
    NotEnoughMemory(usize),
}
212 backends/v2/src/main.rs Normal file
@@ -0,0 +1,212 @@
use clap::{Parser, Subcommand};
use text_generation_router::{server, usage_stats};
use text_generation_router_v3::{connect_backend, V3Error};
use thiserror::Error;

/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
    #[command(subcommand)]
    command: Option<Commands>,

    #[clap(default_value = "128", long, env)]
    max_concurrent_requests: usize,
    #[clap(default_value = "2", long, env)]
    max_best_of: usize,
    #[clap(default_value = "4", long, env)]
    max_stop_sequences: usize,
    #[clap(default_value = "5", long, env)]
    max_top_n_tokens: u32,
    #[clap(default_value = "1024", long, env)]
    max_input_tokens: usize,
    #[clap(default_value = "2048", long, env)]
    max_total_tokens: usize,
    #[clap(default_value = "1.2", long, env)]
    waiting_served_ratio: f32,
    #[clap(default_value = "4096", long, env)]
    max_batch_prefill_tokens: u32,
    #[clap(long, env)]
    max_batch_total_tokens: Option<u32>,
    #[clap(default_value = "20", long, env)]
    max_waiting_tokens: usize,
    #[clap(long, env)]
    max_batch_size: Option<usize>,
    #[clap(default_value = "0.0.0.0", long, env)]
    hostname: String,
    #[clap(default_value = "3000", long, short, env)]
    port: u16,
    #[clap(default_value = "/tmp/text-generation-server-0", long, env)]
    master_shard_uds_path: String,
    #[clap(default_value = "bigscience/bloom", long, env)]
    tokenizer_name: String,
    #[clap(long, env)]
    tokenizer_config_path: Option<String>,
    #[clap(long, env)]
    revision: Option<String>,
    #[clap(default_value = "2", long, env)]
    validation_workers: usize,
    #[clap(long, env)]
    api_key: Option<String>,
    #[clap(long, env)]
    json_output: bool,
    #[clap(long, env)]
    otlp_endpoint: Option<String>,
    #[clap(default_value = "text-generation-inference.router", long, env)]
    otlp_service_name: String,
    #[clap(long, env)]
    cors_allow_origin: Option<Vec<String>>,
    #[clap(long, env)]
    ngrok: bool,
    #[clap(long, env)]
    ngrok_authtoken: Option<String>,
    #[clap(long, env)]
    ngrok_edge: Option<String>,
    #[clap(long, env, default_value_t = false)]
    messages_api_enabled: bool,
    #[clap(long, env, default_value_t = false)]
    disable_grammar_support: bool,
    #[clap(default_value = "4", long, env)]
    max_client_batch_size: usize,
    #[clap(default_value = "on", long, env)]
    usage_stats: usage_stats::UsageStatsLevel,
}

#[derive(Debug, Subcommand)]
enum Commands {
    PrintSchema,
}

#[tokio::main]
async fn main() -> Result<(), RouterError> {
    // Get args
    let args = Args::parse();
    // Pattern match configuration
    let Args {
        command,
        max_concurrent_requests,
        max_best_of,
        max_stop_sequences,
        max_top_n_tokens,
        max_input_tokens,
        max_total_tokens,
        waiting_served_ratio,
        max_batch_prefill_tokens,
        max_batch_total_tokens,
        max_waiting_tokens,
        max_batch_size,
        hostname,
        port,
        master_shard_uds_path,
        tokenizer_name,
        tokenizer_config_path,
        revision,
        validation_workers,
        api_key,
        json_output,
        otlp_endpoint,
        otlp_service_name,
        cors_allow_origin,
        ngrok,
        ngrok_authtoken,
        ngrok_edge,
        messages_api_enabled,
        disable_grammar_support,
        max_client_batch_size,
        usage_stats,
    } = args;

    if let Some(Commands::PrintSchema) = command {
        use utoipa::OpenApi;
        let api_doc = text_generation_router::server::ApiDoc::openapi();
        let api_doc = serde_json::to_string_pretty(&api_doc).unwrap();
        println!("{}", api_doc);
        std::process::exit(0);
    };
    text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output);

    // Validate args
    if max_input_tokens >= max_total_tokens {
        return Err(RouterError::ArgumentValidation(
            "`max_input_tokens` must be < `max_total_tokens`".to_string(),
        ));
    }
    if max_input_tokens as u32 > max_batch_prefill_tokens {
        return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
    }

    if validation_workers == 0 {
        return Err(RouterError::ArgumentValidation(
            "`validation_workers` must be > 0".to_string(),
        ));
    }

    if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
        if max_batch_prefill_tokens > *max_batch_total_tokens {
            return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
        }
        if max_total_tokens as u32 > *max_batch_total_tokens {
            return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
        }
    }

    if let Some(max_batch_size) = max_batch_size {
        if max_batch_size == 0 {
            return Err(RouterError::ArgumentValidation(
                "`max_batch_size` must be > 0".to_string(),
            ));
        }
    }
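
    // Summary of the invariants enforced above (illustrative, not exhaustive):
    //   max_input_tokens < max_total_tokens <= max_batch_total_tokens (if set)
    //   max_input_tokens <= max_batch_prefill_tokens <= max_batch_total_tokens (if set)
    //   validation_workers > 0, and max_batch_size > 0 (if set)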

    let (backend, _backend_info) = connect_backend(
        max_input_tokens,
        max_total_tokens,
        master_shard_uds_path,
        waiting_served_ratio,
        max_batch_prefill_tokens,
        max_batch_total_tokens,
        max_waiting_tokens,
        max_batch_size,
    )
    .await?;

    // Run server
    server::run(
        backend,
        max_concurrent_requests,
        max_best_of,
        max_stop_sequences,
        max_top_n_tokens,
        max_input_tokens,
        max_total_tokens,
        validation_workers,
        api_key,
        tokenizer_name,
        tokenizer_config_path,
        revision,
        hostname,
        port,
        cors_allow_origin,
        ngrok,
        ngrok_authtoken,
        ngrok_edge,
        messages_api_enabled,
        disable_grammar_support,
        max_client_batch_size,
        usage_stats,
    )
    .await?;
    Ok(())
}

#[derive(Debug, Error)]
enum RouterError {
    #[error("Argument validation error: {0}")]
    ArgumentValidation(String),
    #[error("Backend failed: {0}")]
    Backend(#[from] V3Error),
    #[error("WebServer error: {0}")]
    WebServer(#[from] server::WebServerError),
    #[error("Tokio runtime failed to start: {0}")]
    Tokio(#[from] std::io::Error),
}
793 backends/v2/src/queue.rs Normal file
@@ -0,0 +1,793 @@
use crate::block_allocator::{BlockAllocation, BlockAllocator};
use crate::client;
use crate::client::{
    Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
};
use nohash_hasher::{BuildNoHashHasher, IntMap};
use std::cmp::{max, min};
use std::collections::VecDeque;
use text_generation_router::infer::InferError;
use text_generation_router::infer::InferStreamResponse;
use text_generation_router::validation::{
    Chunk, ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters,
    ValidStoppingParameters,
};
use tokio::sync::{mpsc, oneshot};
use tokio::time::Instant;
use tracing::{info_span, instrument, Instrument, Span};

/// Queue entry
#[derive(Debug)]
pub(crate) struct Entry {
    /// Request
    pub request: ValidGenerateRequest,
    /// Response sender to communicate between the Infer struct and the batching_task
    pub response_tx: mpsc::UnboundedSender<Result<InferStreamResponse, InferError>>,
    /// Span that will live as long as entry
    pub span: Span,
    /// Temporary span used as a guard when logging inference, wait times...
    pub temp_span: Option<Span>,
    /// Instant when this entry was queued
    pub queue_time: Instant,
    /// Instant when this entry was added to a batch
    pub batch_time: Option<Instant>,
    /// Block Allocation
    pub block_allocation: Option<BlockAllocation>,
}

/// Request Queue
#[derive(Debug, Clone)]
pub(crate) struct Queue {
    /// Channel to communicate with the background queue task
    queue_sender: mpsc::UnboundedSender<QueueCommand>,
}

impl Queue {
    pub(crate) fn new(
        requires_padding: bool,
        block_size: u32,
        prefix_caching: bool,
        window_size: Option<u32>,
        speculate: u32,
        max_batch_total_tokens: u32,
    ) -> Self {
        // Create channel
        let (queue_sender, queue_receiver) = mpsc::unbounded_channel();

        // Launch background queue task
        tokio::spawn(queue_task(
            requires_padding,
            block_size,
            prefix_caching,
            window_size,
            speculate,
            max_batch_total_tokens,
            queue_receiver,
        ));

        Self { queue_sender }
    }

    /// Append an entry to the queue
    #[instrument(skip_all)]
    pub(crate) fn append(&self, entry: Entry) {
        // Send append command to the background task managing the state
        // Unwrap is safe here
        self.queue_sender
            .send(QueueCommand::Append(Box::new(entry), Span::current()))
            .unwrap();
    }

    // Get the next batch
    #[instrument(skip(self))]
    pub(crate) async fn next_batch(
        &self,
        min_size: Option<usize>,
        max_size: Option<usize>,
        prefill_token_budget: u32,
        token_budget: u32,
    ) -> Option<NextBatch> {
        // Create response channel
        let (response_sender, response_receiver) = oneshot::channel();
        // Send next batch command to the background task managing the state
        // Unwrap is safe here
        self.queue_sender
            .send(QueueCommand::NextBatch {
                min_size,
                max_size,
                prefill_token_budget,
                token_budget,
                response_sender,
                span: Span::current(),
            })
            .unwrap();
        // Await on response channel
        // Unwrap is safe here
        response_receiver.await.unwrap()
    }
}

// Background task responsible for the queue state
async fn queue_task(
    requires_padding: bool,
    block_size: u32,
    prefix_caching: bool,
    window_size: Option<u32>,
    speculate: u32,
    max_batch_total_tokens: u32,
    mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
) {
    let mut state = State::new(
        requires_padding,
        block_size,
        prefix_caching,
        window_size,
        speculate,
        max_batch_total_tokens,
    );

    while let Some(cmd) = receiver.recv().await {
        match cmd {
            QueueCommand::Append(entry, span) => {
                span.in_scope(|| state.append(*entry));
                metrics::gauge!("tgi_queue_size").increment(1.0);
            }
            QueueCommand::NextBatch {
                min_size,
                max_size,
                prefill_token_budget,
                token_budget,
                response_sender,
                span,
            } => {
                let next_batch = state
                    .next_batch(min_size, max_size, prefill_token_budget, token_budget)
                    .instrument(span)
                    .await;
                response_sender.send(next_batch).unwrap();
                metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64);
            }
        }
    }
}

/// Queue State
#[derive(Debug)]
struct State {
    /// Queue entries organized in a VecDeque
    entries: VecDeque<(u64, Entry)>,

    /// Id of the next entry
    next_id: u64,

    /// Id of the next batch
    next_batch_id: u64,

    /// Paged Attention block size
    block_size: u32,

    /// Sliding window
    window_size: Option<u32>,

    /// Speculation amount
    speculate: u32,

    /// Paged Attention Block Allocation
    block_allocator: Option<BlockAllocator>,
}

impl State {
    fn new(
        requires_padding: bool,
        block_size: u32,
        prefix_caching: bool,
        window_size: Option<u32>,
        speculate: u32,
        max_batch_total_tokens: u32,
    ) -> Self {
        let block_allocator = (!requires_padding).then(|| {
            BlockAllocator::new(
                max_batch_total_tokens,
                block_size,
                prefix_caching,
                window_size,
            )
        });

        Self {
            entries: VecDeque::with_capacity(128),
            next_id: 0,
            next_batch_id: 0,
            block_size,
            window_size,
            speculate,
            block_allocator,
        }
    }

    /// Append an entry to the queue
    fn append(&mut self, mut entry: Entry) {
        // Create a span that will live as long as the entry is in the queue waiting to be batched
        let queue_span = info_span!(parent: &entry.span, "queued");
        entry.temp_span = Some(queue_span);

        // Push entry in the queue
        self.entries.push_back((self.next_id, entry));
        self.next_id += 1;
    }

    // Get the next batch
    async fn next_batch(
        &mut self,
        min_size: Option<usize>,
        max_size: Option<usize>,
        prefill_token_budget: u32,
        token_budget: u32,
    ) -> Option<NextBatch> {
        if self.entries.is_empty() {
            tracing::debug!("No queue");
            return None;
        }

        // Check if we have enough entries
        if let Some(min_size) = min_size {
            if self.entries.len() < min_size {
                tracing::debug!("Not enough entries");
                return None;
            }
        }

        if let Some(max_size) = max_size {
            if max_size == 0 {
                tracing::debug!("No capacity");
                return None;
            }
        }

        // Pad prefill_token_budget to be a multiple of block size
        let prefill_token_budget =
            ((prefill_token_budget + self.block_size - 1) / self.block_size) * self.block_size;
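        // e.g. (illustrative): with block_size = 16, a budget of 30 becomes
        // ((30 + 15) / 16) * 16 = 32.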

        // Create span for this batch to add context to inference calls
        let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
        next_batch_span.follows_from(Span::current());

        let mut batch = Vec::with_capacity(self.entries.len());
        let mut max_input_length = 0;
        let mut prefill_tokens: u32 = 0;
        let mut decode_tokens: u32 = 0;
        let mut max_blocks = 0;

        // Pop entries starting from the front of the queue
        'entry_loop: while let Some((id, entry)) = self.entries.pop_front() {
            // Filter entries where the response receiver was dropped (== entries where the request
            // was dropped by the client)
            if entry.response_tx.is_closed() {
                metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
                tracing::debug!("Dropping entry");
                continue;
            }

            let block_allocation = match &self.block_allocator {
                None => {
                    // We pad to max input length in the Python shards
                    // We need to take these padding tokens into account
                    max_input_length = max_input_length.max(entry.request.input_length);
                    prefill_tokens = (batch.len() + 1) as u32 * max_input_length;
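                    // e.g. (illustrative): an entry of length 7 joining a batch
                    // that already holds one entry of length 10 costs
                    // (1 + 1) * 10 = 20 prefill tokens once padded.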

                    decode_tokens += entry.request.stopping_parameters.max_new_tokens;
                    let total_tokens = prefill_tokens + decode_tokens + self.speculate;

                    if prefill_tokens > prefill_token_budget || total_tokens > token_budget {
                        // Entry is over budget
                        // Add it back to the front
                        tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
                        self.entries.push_front((id, entry));
                        break 'entry_loop;
                    }
                    None
                }
                Some(_block_allocator) => {
                    prefill_tokens += entry.request.input_length;
                    let max_new_tokens = match self.window_size {
                        None => entry.request.stopping_parameters.max_new_tokens,
                        Some(window_size) => min(
                            window_size.saturating_sub(entry.request.input_length),
                            entry.request.stopping_parameters.max_new_tokens,
                        ),
                    };
                    decode_tokens += max_new_tokens;

                    if prefill_tokens > prefill_token_budget
                        || (prefill_tokens + decode_tokens + self.speculate) > token_budget
                    {
                        // Entry is over budget
                        // Add it back to the front
                        tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
                        self.entries.push_front((id, entry));
                        break;
                    }

                    let tokens = entry.request.input_length
                        + entry.request.stopping_parameters.max_new_tokens
                        + self.speculate
                        - 1;

                    // If the user wants the prefill logprobs, we cannot reuse the cache.
                    // So no input_ids for the radix tree.
                    let input_ids = if entry.request.decoder_input_details {
                        None
                    } else {
                        entry.request.input_ids.clone()
                    };

                    Some((tokens, input_ids))
                }
            };
            batch.push((id, entry, block_allocation));
            if Some(batch.len()) == max_size {
                break;
            }
        }

        // Empty batch
        if batch.is_empty() {
            tracing::debug!("Filtered out all entries");
            return None;
        }

        // XXX We haven't allocated yet, so we're allowed to ditch the results.
        // Check if our batch is big enough
        if let Some(min_size) = min_size {
            // Batch is too small
            if batch.len() < min_size {
                // Add back entries to the queue in the correct order
                for (id, entry, _) in batch.into_iter().rev() {
                    self.entries.push_front((id, entry));
                }
                return None;
            }
        }

        let mut batch_requests = Vec::with_capacity(self.entries.len());
        let mut batch_entries =
            IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());

        for (id, mut entry, block_allocation) in batch {
            let block_allocation = if let (Some((tokens, input_ids)), Some(block_allocator)) =
                (block_allocation, &self.block_allocator)
            {
                tracing::debug!("Allocating {tokens} with {input_ids:?}");
                match block_allocator.allocate(tokens, input_ids).await {
                    None => {
                        // Entry is over budget
                        // Add it back to the front
                        tracing::debug!("Over budget: not enough free blocks");
                        self.entries.push_front((id, entry));
                        continue;
                    }
                    Some(block_allocation) => {
                        tracing::debug!("Allocation: {block_allocation:?}");
                        max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
                        Some(block_allocation)
                    }
                }
            } else {
                None
            };
            tracing::debug!("Accepting entry");
            // Create a new span to link the batch back to this entry
            let entry_batch_span = info_span!(parent: &entry.span, "infer");
            // Add relationships
            next_batch_span.follows_from(&entry_batch_span);
            entry_batch_span.follows_from(&next_batch_span);
            // Update entry
            entry.temp_span = Some(entry_batch_span);

            let (blocks, slots, prefix_len) = match &block_allocation {
                None => (Vec::new(), Vec::new(), 0),
                Some(block_allocation) => (
                    block_allocation.blocks.clone(),
                    block_allocation.slots.clone(),
                    block_allocation.prefix_len,
                ),
            };

            entry.block_allocation = block_allocation;

            batch_requests.push(Request {
                id,
                prefill_logprobs: entry.request.decoder_input_details,
                input_chunks: Some(client::Input {
                    chunks: entry
                        .request
                        .inputs
                        .clone()
                        .into_iter()
                        .map(|c| client::InputChunk {
                            chunk: Some(match c {
                                Chunk::Text(text) => client::Chunk::Text(text),
                                Chunk::Image(image) => client::Chunk::Image(client::Image {
                                    data: image.data,
                                    mimetype: image.mimetype,
                                }),
                            }),
                        })
                        .collect(),
                }),
                inputs: entry.request.inputs.chunks_to_string(),
                truncate: entry.request.truncate,
                add_special_tokens: entry.request.add_special_tokens,
                parameters: Some(NextTokenChooserParameters::from(
                    entry.request.parameters.clone(),
                )),
                stopping_parameters: Some(StoppingCriteriaParameters::from(
                    entry.request.stopping_parameters.clone(),
                )),
                top_n_tokens: entry.request.top_n_tokens,
                blocks,
                slots,
                prefix_len,
                adapter_id: entry.request.adapter_id.clone(),
            });
            // Set batch_time
            entry.batch_time = Some(Instant::now());
            // Insert in batch_entries IntMap
            batch_entries.insert(id, entry);
        }

        // Empty batch
        if batch_requests.is_empty() {
            tracing::debug!("Filtered out all entries");
            return None;
        }

        // Final batch size
        let size = batch_requests.len() as u32;
        next_batch_span.record("batch_size", size);

        let batch = Batch {
            id: self.next_batch_id,
            requests: batch_requests,
            size,
            max_tokens: (prefill_tokens + decode_tokens),
            max_blocks,
        };
        // Increment batch id
        self.next_batch_id += 1;

        metrics::histogram!("tgi_batch_next_size").record(batch.size as f64);

        Some((batch_entries, batch, next_batch_span))
    }
}

type NextBatch = (IntMap<u64, Entry>, Batch, Span);

#[derive(Debug)]
enum QueueCommand {
    Append(Box<Entry>, Span),
    NextBatch {
        min_size: Option<usize>,
        max_size: Option<usize>,
        prefill_token_budget: u32,
        token_budget: u32,
        response_sender: oneshot::Sender<Option<NextBatch>>,
        span: Span,
    },
}

impl From<ValidParameters> for NextTokenChooserParameters {
    fn from(value: ValidParameters) -> Self {
        let (grammar, grammar_type) = match value.grammar {
            None => (String::new(), GrammarType::None),

            Some(grammar) => match grammar {
                ValidGrammar::Json(grammar_string) => (grammar_string, GrammarType::Json),
                ValidGrammar::Regex(grammar_string) => (grammar_string, GrammarType::Regex),
            },
        };

        Self {
            temperature: value.temperature,
            top_k: value.top_k,
            top_p: value.top_p,
            typical_p: value.typical_p,
            do_sample: value.do_sample,
            seed: value.seed,
            repetition_penalty: value.repetition_penalty,
            frequency_penalty: value.frequency_penalty,
            watermark: value.watermark,
            grammar,
            grammar_type: grammar_type.into(),
        }
    }
}

impl From<ValidStoppingParameters> for StoppingCriteriaParameters {
    fn from(value: ValidStoppingParameters) -> Self {
        Self {
            max_new_tokens: value.max_new_tokens,
            stop_sequences: value.stop_sequences,
            ignore_eos_token: value.ignore_eos_token,
        }
    }
}

#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use tracing::info_span;

    fn default_entry() -> (
        Entry,
        mpsc::UnboundedReceiver<Result<InferStreamResponse, InferError>>,
    ) {
        let (response_tx, receiver_tx) = mpsc::unbounded_channel();

        let entry = Entry {
            request: ValidGenerateRequest {
                inputs: vec![],
                input_ids: Some(Arc::new(vec![])),
                input_length: 0,
                add_special_tokens: true,
                truncate: 0,
                decoder_input_details: false,
                parameters: ValidParameters {
                    temperature: 0.0,
                    top_k: 0,
                    top_p: 0.0,
                    typical_p: 0.0,
                    do_sample: false,
                    seed: 0,
                    repetition_penalty: 0.0,
                    frequency_penalty: 0.0,
                    watermark: false,
                    grammar: None,
                },
                stopping_parameters: ValidStoppingParameters {
                    ignore_eos_token: false,
                    max_new_tokens: 1,
                    stop_sequences: vec![],
                },
                top_n_tokens: 0,
                adapter_id: None,
            },
            response_tx,
            span: info_span!("entry"),
            temp_span: None,
            queue_time: Instant::now(),
            batch_time: None,
            block_allocation: None,
        };
        (entry, receiver_tx)
    }

    #[tokio::test]
    async fn test_append() {
        let mut state = State::new(false, 1, false, None, 0, 16);
        let (entry, _guard) = default_entry();

        assert_eq!(state.next_id, 0);
        assert_eq!(state.entries.len(), 0);

        state.append(entry);

        assert_eq!(state.next_id, 1);
        assert_eq!(state.entries.len(), 1);
        let (id, _) = state.entries.remove(0).unwrap();
        assert_eq!(id, 0);
    }

    #[tokio::test]
    async fn test_next_batch_empty() {
        let mut state = State::new(false, 1, false, None, 0, 16);

        assert!(state.next_batch(None, None, 1, 1).await.is_none());
        assert!(state.next_batch(Some(1), None, 1, 1).await.is_none());
    }

    #[tokio::test]
    async fn test_next_batch_min_size() {
        let mut state = State::new(false, 1, false, None, 0, 16);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        state.append(entry1);
        state.append(entry2);

        let (entries, batch, _) = state.next_batch(None, None, 2, 2).await.unwrap();
        assert_eq!(entries.len(), 2);
        assert!(entries.contains_key(&0));
        assert!(entries.contains_key(&1));
        assert!(entries.get(&0).unwrap().batch_time.is_some());
        assert!(entries.get(&1).unwrap().batch_time.is_some());
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 2);

        assert_eq!(state.next_id, 2);
        assert_eq!(state.entries.len(), 0);
        assert_eq!(state.next_batch_id, 1);

        let (entry3, _guard3) = default_entry();
        state.append(entry3);

        assert!(state.next_batch(Some(2), None, 2, 2).await.is_none());

        assert_eq!(state.next_id, 3);
        assert_eq!(state.entries.len(), 1);
        let (id, _) = state.entries.remove(0).unwrap();
        assert_eq!(id, 2);
    }

    #[tokio::test]
    async fn test_next_batch_max_size() {
        let mut state = State::new(false, 1, false, None, 0, 16);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        state.append(entry1);
        state.append(entry2);

        let (entries, batch, _) = state.next_batch(None, Some(1), 2, 2).await.unwrap();
        assert_eq!(entries.len(), 1);
        assert!(entries.contains_key(&0));
        assert!(entries.get(&0).unwrap().batch_time.is_some());
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 1);

        assert_eq!(state.next_id, 2);
        assert_eq!(state.entries.len(), 1);
        assert_eq!(state.next_batch_id, 1);
    }

    #[tokio::test]
    async fn test_next_batch_token_budget() {
        let mut state = State::new(false, 1, false, None, 0, 2);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        state.append(entry1);
        state.append(entry2);

        let (entries, batch, _) = state.next_batch(None, None, 1, 1).await.unwrap();
        assert_eq!(entries.len(), 1);
        assert!(entries.contains_key(&0));
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 1);

        assert_eq!(state.next_id, 2);
        assert_eq!(state.entries.len(), 1);
        assert_eq!(state.next_batch_id, 1);

        let (entry3, _guard3) = default_entry();
        state.append(entry3);

        let (entries, batch, _) = state.next_batch(None, None, 3, 3).await.unwrap();
        assert_eq!(entries.len(), 2);
        assert!(entries.contains_key(&1));
        assert!(entries.contains_key(&2));
        assert_eq!(batch.id, 1);
        assert_eq!(batch.size, 2);

        assert_eq!(state.next_id, 3);
        assert_eq!(state.entries.len(), 0);
        assert_eq!(state.next_batch_id, 2);
    }

    #[tokio::test]
    async fn test_queue_append() {
        let queue = Queue::new(false, 1, false, None, 0, 16);
        let (entry, _guard) = default_entry();
        queue.append(entry);
    }

    #[tokio::test]
    async fn test_queue_next_batch_empty() {
        let queue = Queue::new(false, 1, false, None, 0, 16);

        assert!(queue.next_batch(None, None, 1, 1).await.is_none());
        assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
    }

    #[tokio::test]
    async fn test_queue_next_batch_min_size() {
        let queue = Queue::new(false, 1, false, None, 0, 16);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        queue.append(entry1);
        queue.append(entry2);

        let (entries, batch, _) = queue.next_batch(None, None, 2, 2).await.unwrap();
        assert_eq!(entries.len(), 2);
        assert!(entries.contains_key(&0));
        assert!(entries.contains_key(&1));
        assert!(entries.get(&0).unwrap().batch_time.is_some());
        assert!(entries.get(&1).unwrap().batch_time.is_some());
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 2);

        let (entry3, _guard3) = default_entry();
        queue.append(entry3);

        // Not enough requests pending
        assert!(queue.next_batch(Some(2), None, 2, 2).await.is_none());
        // Not enough token budget
        assert!(queue.next_batch(Some(1), None, 0, 0).await.is_none());
        // Ok
        let (entries2, batch2, _) = queue.next_batch(Some(1), None, 2, 2).await.unwrap();
        assert_eq!(entries2.len(), 1);
        assert!(entries2.contains_key(&2));
        assert!(entries2.get(&2).unwrap().batch_time.is_some());
        assert_eq!(batch2.id, 1);
        assert_eq!(batch2.size, 1);
    }

    #[tokio::test]
    async fn test_queue_next_batch_max_size() {
        let queue = Queue::new(false, 1, false, None, 0, 16);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        queue.append(entry1);
        queue.append(entry2);

        let (entries, batch, _) = queue.next_batch(None, Some(1), 2, 2).await.unwrap();
        assert_eq!(entries.len(), 1);
        assert!(entries.contains_key(&0));
        assert!(entries.get(&0).unwrap().batch_time.is_some());
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 1);
    }

    #[tokio::test]
    async fn test_queue_next_batch_token_budget() {
        let queue = Queue::new(false, 1, false, None, 0, 16);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        queue.append(entry1);
        queue.append(entry2);

        let (entries, batch, _) = queue.next_batch(None, None, 1, 1).await.unwrap();
        assert_eq!(entries.len(), 1);
        assert!(entries.contains_key(&0));
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 1);

        let (entry3, _guard3) = default_entry();
        queue.append(entry3);

        let (entries, batch, _) = queue.next_batch(None, None, 3, 3).await.unwrap();
        assert_eq!(entries.len(), 2);
        assert!(entries.contains_key(&1));
        assert!(entries.contains_key(&2));
        assert_eq!(batch.id, 1);
        assert_eq!(batch.size, 2);
    }

    #[tokio::test]
    async fn test_queue_next_batch_token_speculate() {
        let queue = Queue::new(false, 1, false, None, 2, 16);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        queue.append(entry1);
        queue.append(entry2);

        // Budget of 1 is not enough
        assert!(queue.next_batch(None, None, 1, 1).await.is_none());

        let (entries, batch, _) = queue.next_batch(None, None, 6, 6).await.unwrap();
        assert_eq!(entries.len(), 2);
        assert!(entries.contains_key(&0));
        assert!(entries.contains_key(&1));
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 2);
    }

    #[tokio::test]
    async fn test_queue_next_batch_dropped_receiver() {
        let queue = Queue::new(false, 1, false, None, 0, 16);
        let (entry, _) = default_entry();
        queue.append(entry);

        assert!(queue.next_batch(None, None, 1, 1).await.is_none());
    }
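
    // A hedged extra case (not in the original diff): with `speculate = 1`,
    // a default entry costs 1 decode token plus 1 speculative token, so a
    // token budget of 1 admits nothing. The State::new signature and
    // default_entry() come from above; the test name is made up.
    #[tokio::test]
    async fn test_next_batch_speculate_over_budget_sketch() {
        let mut state = State::new(false, 1, false, None, 1, 16);
        let (entry, _guard) = default_entry();
        state.append(entry);

        // 0 prefill + 1 decode + 1 speculate = 2 > token_budget of 1
        assert!(state.next_batch(None, None, 1, 1).await.is_none());
    }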
}
876 backends/v2/src/radix.rs Normal file
@@ -0,0 +1,876 @@
use crate::block_allocator::{Allocator, BlockAllocation};
use slotmap::{DefaultKey, SlotMap};
use std::hash::{Hash, Hasher};
use std::{
    collections::{BTreeSet, HashMap},
    sync::Arc,
};

fn hash(slice: &[u32]) -> u64 {
    assert!(!slice.is_empty());
    if slice.len() == 1 {
        slice[0] as u64
    } else {
        let mut s = std::hash::DefaultHasher::new();
        slice.hash(&mut s);
        s.finish()
    }
}
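
// Illustrative behaviour (not stated in the diff itself): `hash(&[42])` is
// just `42`, while longer slices go through `DefaultHasher`, e.g.
// `hash(&[1, 2])` is a 64-bit digest of the pair.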

pub struct RadixAllocator {
    allocation_id: u64,

    allocations: HashMap<u64, RadixAllocation>,

    cache_blocks: RadixTrie,

    /// Blocks that are immediately available for allocation.
    free_blocks: Vec<u32>,

    #[allow(dead_code)]
    // This isn't used because the prefix needs to match without the windowing
    // mechanism. At worst this overallocates; it is not incorrect.
    window_size: Option<u32>,

    block_size: u32,
}

impl RadixAllocator {
    pub fn new(block_size: u32, n_blocks: u32, window_size: Option<u32>) -> Self {
        RadixAllocator {
            allocation_id: 0,
            allocations: HashMap::new(),
            cache_blocks: RadixTrie::new(block_size as usize),

            // Block 0 is reserved for health checks.
            free_blocks: (1..n_blocks).collect(),
            window_size,
            block_size,
        }
    }

    fn alloc_or_reclaim(&mut self, n_blocks_needed: usize) -> Option<Vec<u32>> {
        if self.free_blocks.len() < n_blocks_needed {
            // This is a bit annoying, we first extend the free list and then
            // split it off again below. This is because we need to put it on
            // the free list if we cannot allocate enough blocks. This is only
            // temporary, the trie needs to be able to report whether it can
            // allocate the requested amount. Just not implemented yet.
            tracing::debug!(
                "Free blocks {} need {n_blocks_needed}",
                self.free_blocks.len()
            );
            self.free_blocks.extend(
                self.cache_blocks
                    .evict(n_blocks_needed - self.free_blocks.len()),
            );
        }

        if self.free_blocks.len() >= n_blocks_needed {
            Some(
                self.free_blocks
                    .split_off(self.free_blocks.len() - n_blocks_needed),
            )
        } else {
            None
        }
    }
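
    // Worked example (illustrative): with 2 free blocks and 5 needed,
    // `alloc_or_reclaim` asks the trie to evict up to 3 blocks; if eviction
    // frees fewer, the free list keeps whatever was reclaimed and `None`
    // is returned to the caller.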
}

// Allocator trait
impl Allocator for RadixAllocator {
    fn allocate(
        &mut self,
        tokens: u32,
        prefill_tokens: Option<Arc<Vec<u32>>>,
    ) -> Option<BlockAllocation> {
        let mut blocks = vec![];
        let prefix_node = if let Some(prefill_tokens) = prefill_tokens.as_ref() {
            let node_id = self
                .cache_blocks
                .find(prefill_tokens.as_slice(), &mut blocks);
            node_id
        } else {
            self.cache_blocks.root_id()
        };

        // Even if this allocation fails below, we need to increase the
        // refcount to ensure that the prefix that was found is not evicted.
        self.cache_blocks
            .incref(prefix_node)
            .expect("Failed to increment refcount");

        let prefix_len = blocks.len() * self.block_size as usize;
        let suffix_len = tokens - prefix_len as u32;

        let suffix_blocks = (suffix_len + self.block_size - 1) / self.block_size;

        tracing::info!("Prefix {prefix_len} - Suffix {suffix_len}");

        match self.alloc_or_reclaim(suffix_blocks as usize) {
            Some(suffix_blocks) => blocks.extend(suffix_blocks),
            None => {
                tracing::debug!("Cannot allocate {:?}", self.cache_blocks);
                tracing::debug!("Found {prefix_len} prefix tokens need {suffix_blocks} suffix blocks for {tokens} tokens");
                tracing::debug!("Block size {}", self.block_size);
                self.cache_blocks
                    .decref(prefix_node)
                    .expect("Failed to decrement refcount");
                return None;
            }
        }

        // 1:1 mapping of blocks and slots.
        let slots = if self.block_size == 1 {
            blocks.clone()
        } else {
            let mut slots = Vec::with_capacity(blocks.len() * self.block_size as usize);
            'slots: for block_id in &blocks {
                for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) {
                    slots.push(s);
                    if slots.len() as u32 == tokens {
                        break 'slots;
                    }
                }
            }
            slots
        };
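        // e.g. (illustrative): block_size = 2, blocks = [3, 4], tokens = 3
        // maps to slots [6, 7, 8]; the inner loop stops once `tokens` slots
        // have been produced.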

        let allocation = RadixAllocation {
            prefix_node,
            cached_prefix_len: prefix_len,
            prefill_tokens: prefill_tokens.clone(),
        };

        self.allocation_id += 1;
        self.allocations.insert(self.allocation_id, allocation);

        Some(BlockAllocation {
            allocation_id: self.allocation_id,
            block_allocator: None,
            blocks,
            slots,
            prefix_len: prefix_len as u32,
        })
    }

    fn free(&mut self, blocks: Vec<u32>, allocation_id: u64) {
        let allocation = match self.allocations.remove(&allocation_id) {
            Some(allocation) => allocation,
            None => unreachable!("Tried to free an unknown allocation."),
        };

        self.cache_blocks
            .decref(allocation.prefix_node)
            .expect("Failed to decrement refcount");

        if let Some(prefill_tokens) = allocation.prefill_tokens {
            let prefill_tokens = prefill_tokens.as_slice();

            // If there are prefill tokens that did not come from the cache,
            // add them to the cache.
            if prefill_tokens.len() > allocation.cached_prefix_len {
                let aligned =
                    (prefill_tokens.len() / self.block_size as usize) * self.block_size as usize;
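                // e.g. (illustrative): 10 prefill tokens with block_size = 4
                // align down to (10 / 4) * 4 = 8 cacheable tokens.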
|
||||||
|
if aligned > 0 {
|
||||||
|
let prefix_len = self
|
||||||
|
.cache_blocks
|
||||||
|
.insert(
|
||||||
|
&prefill_tokens[..aligned],
|
||||||
|
&blocks[..aligned / self.block_size as usize],
|
||||||
|
)
|
||||||
|
// Unwrap, failing is a programming error.
|
||||||
|
.expect("Failed to store prefill tokens");
|
||||||
|
// We can have a prefill with the following structure:
|
||||||
|
//
|
||||||
|
// |---| From the prefix cache.
|
||||||
|
// A B C D E F G
|
||||||
|
//|--------| Found in the trie during insertion.
|
||||||
|
//
|
||||||
|
// This means that while processing this request there was a
|
||||||
|
// partially overlapping request that had A..=E in its
|
||||||
|
// prefill. In this case we need to free the blocks D E.
|
||||||
|
if prefix_len > allocation.cached_prefix_len {
|
||||||
|
self.free_blocks.extend(
|
||||||
|
&blocks[allocation.cached_prefix_len / self.block_size as usize
|
||||||
|
..prefix_len / self.block_size as usize],
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Free non-prefill blocks.
|
||||||
|
self.free_blocks
|
||||||
|
.extend(&blocks[prefill_tokens.len() / self.block_size as usize..]);
|
||||||
|
} else {
|
||||||
|
self.free_blocks.extend(blocks);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct RadixAllocation {
|
||||||
|
prefix_node: NodeId,
|
||||||
|
cached_prefix_len: usize,
|
||||||
|
prefill_tokens: Option<Arc<Vec<u32>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Radix trie that is heavily inspired by radix attention from sglang.
|
||||||
|
//
|
||||||
|
// The trie is optimized for prefix caching:
|
||||||
|
//
|
||||||
|
// - A normal radix trie stores discrete values. In this radix trie,
|
||||||
|
// inserting *abc* with value *xyz* will also enable lookup for
|
||||||
|
// *a* (*x*) and *ab* (*xy*).
|
||||||
|
// - As a result, every value is required to have the same length as
|
||||||
|
// the key.
|
||||||
|
// - We store additional information in each node, such as last access
|
||||||
|
// time and a reference count.

#[derive(Debug)]
pub enum TrieError {
    InvalidNodeId,
    RefCountUnderflow,
}

pub type NodeId = DefaultKey;

#[derive(Debug)]
pub struct RadixTrie {
    /// Identifier of the root node.
    root: DefaultKey,

    /// Leaf node identifiers ordered by increasing recency.
    leaves: BTreeSet<(u64, NodeId)>,

    /// All trie nodes.
    nodes: SlotMap<NodeId, TrieNode>,

    /// Time as a monotonically increasing counter to avoid the system
    /// call that a real time lookup would require.
    time: u64,

    /// All blocks need to be aligned with this.
    block_size: usize,
}

impl RadixTrie {
    /// Construct a new radix trie.
    pub fn new(block_size: usize) -> Self {
        let root = TrieNode::new(vec![], vec![], 0, None);
        let mut nodes = SlotMap::new();
        let root = nodes.insert(root);
        RadixTrie {
            leaves: BTreeSet::new(),
            nodes,
            root,
            time: 0,
            block_size,
        }
    }

    /// Find the prefix of the given tokens.
    ///
    /// The blocks corresponding to the part of the prefix that could be found
    /// are written to `blocks`. The number of blocks is in `0..=tokens.len()`.
    /// Returns the identifier of the trie node that contains the longest
    /// prefix. The node identifier can be used by callers to e.g. increase its
    /// reference count.
    ///
    /// Using this method will update the access time of the traversed nodes.
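    ///
    /// A short example, adapted from the unit tests below (block_size = 1):
    ///
    /// ```ignore
    /// let mut trie = RadixTrie::new(1);
    /// trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap();
    /// let mut blocks = Vec::new();
    /// trie.find(&[0, 1, 2, 3], &mut blocks);
    /// assert_eq!(blocks, vec![0, 1, 2]); // Only the cached prefix is found.
    /// ```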
    pub fn find(&mut self, key: &[u32], blocks: &mut Vec<u32>) -> NodeId {
        self.time += 1;
        self.find_(self.root, key, blocks)
    }

    /// Find worker.
    fn find_(&mut self, mut node_id: NodeId, key: &[u32], blocks: &mut Vec<u32>) -> NodeId {
        let node = &self.nodes[node_id];

        if key.len() >= self.block_size {
            let node_key = hash(&key[..self.block_size]);
            if let Some(&child_id) = node.children.get(&node_key) {
                self.update_access_time(child_id);
                let child = self.nodes.get(child_id).expect("Invalid child identifier");
                let shared_prefix_len = shared_prefix(&child.key, key, self.block_size);
                assert_eq!(shared_prefix_len % self.block_size, 0);
                blocks.extend(&child.blocks[..shared_prefix_len / self.block_size]);

                let key = &key[shared_prefix_len..];
                if !key.is_empty() {
                    node_id = self.find_(child_id, key, blocks);
                }
            }
        }

        node_id
    }

    /// Decrease the reference count of a node.
    pub fn decref(&mut self, node_id: NodeId) -> Result<(), TrieError> {
        // We don't care about refcounting for root, since it will never
        // be evicted.
        if node_id == self.root {
            return Ok(());
        }

        let node = self
            .nodes
            .get_mut(node_id)
            .ok_or(TrieError::InvalidNodeId)?;
        if node.ref_count == 0 {
            return Err(TrieError::RefCountUnderflow);
        }

        node.ref_count -= 1;
        if node.ref_count == 0 {
            assert!(
                node.children.is_empty(),
                "Nodes with children must have refcount > 0"
            );

            self.leaves.insert((node.last_accessed, node_id));
        }

        Ok(())
    }

    /// Increase the reference count of a node.
    pub fn incref(&mut self, node_id: NodeId) -> Result<(), TrieError> {
        if node_id == self.root {
            return Ok(());
        }

        let node = self
            .nodes
            .get_mut(node_id)
            .ok_or(TrieError::InvalidNodeId)?;
        if node.ref_count == 0 {
            self.leaves.remove(&(node.last_accessed, node_id));
        }
        node.ref_count += 1;

        Ok(())
    }

    /// Evict `n_blocks` from the trie.
    ///
    /// Returns the evicted blocks. When the length is less than `n_blocks`,
    /// not enough blocks could be evicted.
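    ///
    /// A sketch of the LRU behaviour (cf. `trie_evict_removes_correct_blocks`
    /// below): evicting fewer blocks than the least recently used leaf holds
    /// truncates that leaf rather than removing it.
    ///
    /// ```ignore
    /// let mut trie = RadixTrie::new(1);
    /// trie.insert(&[0, 1, 2, 3], &[0, 1, 2, 3]).unwrap();
    /// assert_eq!(trie.evict(1), vec![3]); // The leaf keeps blocks 0, 1, 2.
    /// ```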
    pub fn evict(&mut self, n_blocks: usize) -> Vec<u32> {
        // NOTE: we don't return Result here. If any of the unwrapping fails,
        // it's a programming error in the trie implementation, not a user
        // error caused by e.g. an invalid argument.

        // TODO: add some bookkeeping in the future to check whether we can
        // evict n_blocks and return `None` if we can't. We are now needlessly
        // evicting prefixes from the cache in such a case.
        let mut evicted = Vec::new();
        tracing::debug!("Evicting in search of {n_blocks}");

        while let Some((last_access, node_id)) = self.leaves.pop_first() {
            let blocks_needed = n_blocks.saturating_sub(evicted.len());
            tracing::debug!("Evicting node {node_id:?}");

            let node = self.nodes.get(node_id).expect("Leaf does not exist");
            assert_eq!(
                node.ref_count, 0,
                "Leaf must have refcount of 0, got {}",
                node.ref_count
            );

            if blocks_needed >= node.blocks.len() {
                // We need to evict the whole node if we need more blocks than it has.
                let node = self.remove_node(node_id);
                evicted.extend(node.blocks);

                if evicted.len() >= n_blocks {
                    break;
                }
            } else {
                // The node has more blocks than needed, so we'll just remove
                // the required number of blocks and leave the remaining blocks
                // untouched.
                let node = self.nodes.get_mut(node_id).expect("Leaf does not exist");

                let truncate_blocks = node.blocks.len() - blocks_needed;
                let truncate_tokens = truncate_blocks * self.block_size;
                node.key.truncate(truncate_tokens);
                evicted.extend(node.blocks.split_off(truncate_blocks));
                self.leaves.insert((last_access, node_id));
                break;
            }
        }

        evicted
    }

    /// Insert a prefill along with its blocks.
    ///
    /// This method returns the length of the prefix that was already
    /// in the trie. E.g. if the length is 10, this means that for
    /// the first 10 elements of the trie **the blocks are not updated**.
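    ///
    /// For instance (matching `trie_insertions_have_correct_prefix_len` below):
    ///
    /// ```ignore
    /// let mut trie = RadixTrie::new(1);
    /// assert_eq!(trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(), 0);
    /// // Inserting the same prefill again: all three tokens were cached.
    /// assert_eq!(trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(), 3);
    /// ```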
    pub fn insert(&mut self, tokens: &[u32], blocks: &[u32]) -> Result<usize, TrieError> {
        self.time += 1;
        let common = self.insert_(self.root, tokens, blocks)?;
        Ok(common)
    }

    /// Insertion worker.
    fn insert_(
        &mut self,
        node_id: NodeId,
        tokens: &[u32],
        blocks: &[u32],
    ) -> Result<usize, TrieError> {
        // TODO: in the future we may want to check that the blocks match for
        // the part of the prefix that is already in the trie to detect
        // mismatches.

        assert_eq!(tokens.len(), blocks.len() * self.block_size);

        let node_key = hash(&tokens[..self.block_size]);
        if let Some(&child_id) = self.nodes[node_id].children.get(&node_key) {
            self.update_access_time(child_id);
            let child = self
                .nodes
                .get_mut(child_id)
                // Unwrap here, since failure is a bug.
                .expect("Child node does not exist");
            let shared_prefix_len = shared_prefix(&child.key, tokens, self.block_size);

            // We are done, the prefix is already in the trie.
            if shared_prefix_len == tokens.len() || shared_prefix_len == 0 {
                return Ok(shared_prefix_len);
            }

            // The node's prefix is a prefix of the insertion prefix.
            if shared_prefix_len == child.key.len() {
                return Ok(shared_prefix_len
                    + self.insert_(
                        child_id,
                        &tokens[shared_prefix_len..],
                        &blocks[shared_prefix_len / self.block_size..],
                    )?);
            }

            // The node's prefix and the insertion prefix only match partially,
            // split the node to just contain the matching part. Then insert the
            // remainder of the prefix into the node again.
            let child_id = self.split_node(child_id, shared_prefix_len);
            let key = &tokens[shared_prefix_len..];
            let blocks = &blocks[shared_prefix_len / self.block_size..];
            Ok(shared_prefix_len + self.insert_(child_id, key, blocks)?)
        } else {
            self.add_node(node_id, tokens, blocks);
            Ok(0)
        }
    }

    fn split_node(&mut self, node_id: NodeId, prefix_len: usize) -> NodeId {
        // We have to make the current node a child to ensure that its
        // properties and node id stay the same.

        // This function unwraps; an invalid node_id is a programming error.

        let node = self
            .nodes
            .get_mut(node_id)
            .expect("Node to-be split does not exist");
        let mut parent_key = node.key.split_off(prefix_len);
        let prefix_blocks = prefix_len / self.block_size;
        let mut parent_blocks = node.blocks.split_off(prefix_blocks);

        // Move first part of the prefix to the parent. We swap to avoid
        // an allocation + copy for both splits of the key/blocks.
        std::mem::swap(&mut node.key, &mut parent_key);
        std::mem::swap(&mut node.blocks, &mut parent_blocks);

        let node_key = hash(&node.key[..self.block_size]);

        let grandparent_id = node.parent.expect("Node does not have a parent");
        let parent_id = self.add_node(grandparent_id, parent_key, parent_blocks);
        self.add_node_to_parent(parent_id, node_key, node_id);

        // Reborrow to make the borrow checker happy.
        let node = self
            .nodes
            .get_mut(node_id)
            .expect("Node to-be split does not exist");
        node.parent = Some(parent_id);

        parent_id
    }

    /// Create a node and add it to the parent.
    fn add_node(
        &mut self,
        parent_id: NodeId,
        key: impl Into<Vec<u32>>,
        blocks: impl Into<Vec<u32>>,
    ) -> NodeId {
        let key = key.into();
        let blocks = blocks.into();
        let first = hash(&key[..self.block_size]);

        let child = TrieNode::new(key, blocks, self.time, Some(parent_id));
        let child_id = self.nodes.insert(child);

        self.add_node_to_parent(parent_id, first, child_id);
        self.leaves.insert((self.time, child_id));

        child_id
    }

    /// Add a node to the parent.
    fn add_node_to_parent(&mut self, parent_id: NodeId, hash: u64, child_id: NodeId) {
        // Unwrap here, passing in an unknown id is a programming error.
        let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node");
        if parent.children.insert(hash, child_id).is_none() {
            // Only increase reference count if child does not replace another child.
            self.incref(parent_id)
                .expect("Failed to increase parent refcount");
        }
    }

    /// Remove a node from the trie.
    fn remove_node(&mut self, node_id: NodeId) -> TrieNode {
        // Unwrap here, passing in an unknown id is a programming error.
        let node = self.nodes.remove(node_id).expect("Unknown node");
        assert!(
            node.children.is_empty(),
            "Tried to remove a node with {} children",
            node.children.len()
        );
        let parent_id = node.parent.expect("Attempted to remove root node");
        let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node");

        let node_key = hash(&node.key[..self.block_size]);
        parent.children.remove(&node_key);
        self.decref(parent_id)
            .expect("Failed to decrease parent refcount");
        node
    }

    fn update_access_time(&mut self, node_id: NodeId) {
        // Unwrap here, passing in an unknown id is a programming error.
        let node = self.nodes.get_mut(node_id).expect("Unknown node");

        // Update the ordered leaves set if the node is a leaf.
        if self.leaves.remove(&(node.last_accessed, node_id)) {
            self.leaves.insert((self.time, node_id));
        }

        node.last_accessed = self.time;
    }

    #[allow(dead_code)]
    #[doc(hidden)]
    /// Print debugging output for the trie.
    ///
    /// In contrast to `Debug`, nicely formatted.
    pub fn print_debug(&self) {
        self.print_debug_(self.root, 0);
    }

    fn print_debug_(&self, node_id: NodeId, indent: usize) {
        let node = &self.nodes[node_id];
        eprintln!(
            "{}{:?}, key: {:?}, blocks: {:?}, ref_count: {}, last_accessed: {}, parent: {:?}, children: {:?}",
            " ".repeat(indent),
            node_id,
            node.key,
            node.blocks,
            node.ref_count,
            node.last_accessed,
            node.parent,
            node.children
        );
        for child_id in self.nodes[node_id].children.values() {
            self.print_debug_(*child_id, indent + 2);
        }
    }

    pub(crate) fn root_id(&self) -> DefaultKey {
        self.root
    }
}

/// Trie node.
#[derive(Debug)]
struct TrieNode {
    blocks: Vec<u32>,
    children: HashMap<u64, NodeId>,
    key: Vec<u32>,
    last_accessed: u64,
    parent: Option<NodeId>,
    ref_count: usize,
}

impl TrieNode {
    fn new(key: Vec<u32>, blocks: Vec<u32>, last_accessed: u64, parent: Option<NodeId>) -> Self {
        TrieNode {
            children: HashMap::new(),
            key,
            blocks,
            last_accessed,
            parent,
            ref_count: 0,
        }
    }
}

fn shared_prefix(left: &[u32], right: &[u32], block_size: usize) -> usize {
    let full = left.iter().zip(right).take_while(|(a, b)| a == b).count();
    // NOTE: this is the case because the child node was chosen based on
    // matching the first character of the key/prefix.
    assert!(full > 0, "Prefixes must at least share 1 token");
    (full / block_size) * block_size
}
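
// For illustration: shared_prefix(&[1, 2, 3, 4], &[1, 2, 3, 5], 2) counts
// three matching tokens, then rounds down to the block boundary and returns 2.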

#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;

    #[test]
    fn allocator_block_size() {
        let mut cache = RadixAllocator::new(2, 12, None);
        let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
        assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
        assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22, 23]);
        assert_eq!(allocation.prefix_len, 0);
        cache.free(allocation.blocks.clone(), allocation.allocation_id);

        let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
        assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
        assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22, 23]);
        assert_eq!(allocation.prefix_len, 4);
    }

    #[test]
    fn allocator_block_size_non_aligned() {
        let mut cache = RadixAllocator::new(2, 12, None);
        let allocation = cache.allocate(7, Some(Arc::new(vec![0, 1, 2]))).unwrap();
        assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
        assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22]);
        assert_eq!(allocation.prefix_len, 0);
        cache.free(allocation.blocks.clone(), allocation.allocation_id);

        let allocation = cache.allocate(7, Some(Arc::new(vec![0, 1, 2]))).unwrap();
        assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
        assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22]);
        assert_eq!(allocation.prefix_len, 2);
    }

    #[test]
    fn allocator_reuses_prefixes() {
        let mut cache = RadixAllocator::new(1, 12, None);
        let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
        assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]);
        assert_eq!(allocation.blocks, allocation.slots);
        assert_eq!(allocation.prefix_len, 0);
        cache.free(allocation.blocks.clone(), allocation.allocation_id);

        let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
        assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]);
        assert_eq!(allocation.prefix_len, 4);
    }

    #[test]
    fn allocator_collects_older_prefixes_first() {
        let mut cache = RadixAllocator::new(1, 7, None);
        let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
        assert_eq!(allocation1.blocks, vec![3, 4, 5, 6]);
        assert_eq!(allocation1.prefix_len, 0);

        let allocation2 = cache.allocate(2, Some(Arc::new(vec![4, 5]))).unwrap();
        assert_eq!(allocation2.blocks, vec![1, 2]);
        assert_eq!(allocation2.prefix_len, 0);

        cache.free(allocation1.blocks.clone(), allocation1.allocation_id);
        cache.free(allocation2.blocks.clone(), allocation2.allocation_id);

        // We should get the blocks of the first allocation, since they are more recent.
        let allocation3 = cache.allocate(4, Some(Arc::new(vec![6, 7, 8, 9]))).unwrap();
        assert_eq!(allocation3.blocks, vec![3, 4, 5, 6]);
        assert_eq!(allocation3.prefix_len, 0);
    }

    #[test]
    fn allocator_frees_fully_overlapping_prefills() {
        let mut cache = RadixAllocator::new(1, 10, None);
        let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
        let allocation2 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();

        cache.free(allocation2.blocks.clone(), allocation2.allocation_id);
        cache.free(allocation1.blocks.clone(), allocation1.allocation_id);

        let allocation3 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
        assert_eq!(allocation3.prefix_len, 4);

        // 10 blocks, of which 1 reserved for health checks, 4 for the cached blocks.
        assert_eq!(cache.free_blocks.len(), 5);
    }

    #[test]
    fn allocator_frees_partially_overlapping_prefills() {
        let mut cache = RadixAllocator::new(1, 20, None);
        let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1]))).unwrap();
        assert_eq!(allocation1.blocks, vec![16, 17, 18, 19]);
        assert_eq!(allocation1.prefix_len, 0);

        cache.free(allocation1.blocks.clone(), allocation1.allocation_id);

        let allocation2 = cache
            .allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 4, 5])))
            .unwrap();
        assert_eq!(allocation2.blocks, vec![16, 17, 12, 13, 14, 15, 18, 19]);
        assert_eq!(allocation2.prefix_len, 2);

        let allocation3 = cache
            .allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 6, 7])))
            .unwrap();
        assert_eq!(allocation3.blocks, vec![16, 17, 6, 7, 8, 9, 10, 11]);
        assert_eq!(allocation3.prefix_len, 2);

        cache.free(allocation3.blocks.clone(), allocation3.allocation_id);
        cache.free(allocation2.blocks.clone(), allocation2.allocation_id);

        // 20 blocks, of which 1 reserved for health checks, 6 for allocation3, 2 for allocation2.
        assert_eq!(cache.free_blocks.len(), 11);

        let allocation4 = cache
            .allocate(6, Some(Arc::new(vec![0, 1, 2, 3, 4, 5])))
            .unwrap();
        assert_eq!(allocation4.blocks, vec![16, 17, 6, 7, 14, 15]);
        assert_eq!(allocation4.prefix_len, 6);
        assert_eq!(cache.free_blocks.len(), 11);

        let allocation5 = cache
            .allocate(6, Some(Arc::new(vec![0, 1, 2, 3, 6, 7])))
            .unwrap();
        assert_eq!(allocation5.blocks, vec![16, 17, 6, 7, 8, 9]);
        assert_eq!(allocation5.prefix_len, 6);
        assert_eq!(cache.free_blocks.len(), 11);
    }

    #[test]
    fn trie_insertions_have_correct_prefix_len() {
        let mut trie = RadixTrie::new(1);

        assert_eq!(trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(), 0);

        // Already exists.
        assert_eq!(trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(), 3);

        // Completely new at root-level.
        assert_eq!(trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap(), 0);

        // Contains full prefix, but longer.
        assert_eq!(trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap(), 3);

        // Shares partial prefix, we need a split.
        assert_eq!(
            trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7])
                .unwrap(),
            4
        );
    }

    #[test]
    fn trie_insertions_block_size() {
        let mut trie = RadixTrie::new(2);

        assert_eq!(trie.insert(&[0, 1, 2, 3], &[0, 1]).unwrap(), 0);

        // Already exists.
        // But needs to be block_size aligned.
        assert_eq!(trie.insert(&[0, 1, 2, 3], &[0, 1]).unwrap(), 4);

        // Completely new at root-level.
        assert_eq!(trie.insert(&[1, 2, 3, 4], &[1, 2]).unwrap(), 0);

        // Contains full prefix, but longer.
        assert_eq!(trie.insert(&[0, 1, 2, 3, 4, 5], &[0, 1, 2]).unwrap(), 4);

        // Shares partial prefix, we need a split.
        assert_eq!(
            trie.insert(&[0, 1, 3, 4, 5, 6, 7, 8], &[0, 1, 2, 3])
                .unwrap(),
            2
        );
    }

    #[test]
    fn trie_get_returns_correct_blocks() {
        let mut trie = RadixTrie::new(1);
        trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap();
        trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap();
        trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap();
        trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7])
            .unwrap();

        let mut blocks = Vec::new();
        trie.find(&[0], &mut blocks);
        assert_eq!(blocks, vec![0]);

        blocks.clear();
        trie.find(&[0, 1, 2], &mut blocks);
        assert_eq!(blocks, vec![0, 1, 2]);

        blocks.clear();
        trie.find(&[1, 2, 3], &mut blocks);
        assert_eq!(blocks, vec![1, 2, 3]);

        blocks.clear();
        trie.find(&[0, 1, 2, 3], &mut blocks);
        assert_eq!(blocks, vec![0, 1, 2, 3]);

        blocks.clear();
        trie.find(&[0, 1, 2, 3, 4], &mut blocks);
        assert_eq!(blocks, vec![0, 1, 2, 3, 4]);

        blocks.clear();
        trie.find(&[0, 1, 2, 3, 5], &mut blocks);
        assert_eq!(blocks, vec![0, 1, 2, 3, 5]);
    }

    #[test]
    fn trie_evict_removes_correct_blocks() {
        let mut trie = RadixTrie::new(1);
        trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap();
        trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7])
            .unwrap();
        trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap();
        trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap();

        let mut blocks = Vec::new();

        // Remove fewer blocks than the leaf has.
        assert_eq!(trie.evict(1), vec![7]);
        trie.find(&[0, 1, 2, 3, 5, 6, 7], &mut blocks);
        assert_eq!(blocks, vec![0, 1, 2, 3, 5, 6]);

        // Refresh other leaf.
        trie.find(&[0, 1, 2, 3, 4], &mut blocks);
        trie.find(&[1, 2, 3], &mut blocks);

        // Remove the leaf's blocks exactly.
        assert_eq!(trie.evict(2), vec![5, 6]);
        blocks.clear();
        trie.find(&[0, 1, 2, 3, 5, 6, 7], &mut blocks);
        assert_eq!(blocks, vec![0, 1, 2, 3]);

        trie.find(&[1, 2, 3], &mut blocks);

        // Remove more blocks than the leaf has.
        assert_eq!(trie.evict(3), vec![4, 3, 2]);
        blocks.clear();
        trie.find(&[0, 1, 2, 3, 4], &mut blocks);
        assert_eq!(blocks, vec![0, 1]);

        // Clear out the whole trie.
        assert_eq!(trie.evict(10), vec![1, 2, 3, 0, 1]);
    }
}