feat(backend): rely on multi-consumer queue to schedule workers

Morgan Funtowicz 2024-11-22 13:32:56 +01:00
parent 84eead219a
commit 5a85661661
3 changed files with 71 additions and 44 deletions
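
The change replaces the previous dispatch scheme (a semaphore guarding a pool of per-worker channels) with a single multi-producer/multi-consumer (MPMC) queue from the `async-channel` crate: the async scheduler pushes every request into one unbounded backlog, and each worker thread holds a clone of the receiver and pulls work whenever it is free. Below is a minimal, self-contained sketch of that pattern — not the commit's code — assuming only `async-channel = "2"` as a dependency:

```rust
use std::thread;

fn main() {
    // One unbounded MPMC queue: the scheduler side sends, workers receive.
    let (submitter, backlog) = async_channel::unbounded::<u32>();

    // Each worker owns a clone of the receiver; whichever worker is idle
    // wins the race for the next job, so no semaphore bookkeeping is needed.
    let workers: Vec<_> = (0..4)
        .map(|id| {
            let backlog = backlog.clone();
            thread::spawn(move || {
                // recv_blocking parks the OS thread until a job arrives and
                // returns Err once the channel is closed and drained.
                while let Ok(job) = backlog.recv_blocking() {
                    println!("worker {id} took job {job}");
                }
            })
        })
        .collect();

    for job in 0..8 {
        // send_blocking is the sync twin of `send(..).await`; on an
        // unbounded channel it only fails once the channel is closed.
        submitter.send_blocking(job).expect("backlog closed");
    }

    drop(submitter); // close the queue so workers drain it and exit
    workers.into_iter().for_each(|w| w.join().unwrap());
}
```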

Cargo.lock (generated, 49 changed lines)

@@ -142,6 +142,18 @@ version = "0.7.6"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50"

+[[package]]
+name = "async-channel"
+version = "2.3.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "89b47800b0be77592da0afd425cc03468052844aff33b84e33cc696f64e77b6a"
+dependencies = [
+ "concurrent-queue",
+ "event-listener-strategy",
+ "futures-core",
+ "pin-project-lite",
+]
+
 [[package]]
 name = "async-rustls"
 version = "0.3.0"
@@ -758,6 +770,15 @@ dependencies = [
 "static_assertions",
 ]

+[[package]]
+name = "concurrent-queue"
+version = "2.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4ca0197aee26d1ae37445ee532fefce43251d24cc7c166799f4d46817f1d3973"
+dependencies = [
+ "crossbeam-utils",
+]
+
 [[package]]
 name = "console"
 version = "0.15.8"
@@ -1158,6 +1179,27 @@ dependencies = [
 "cc",
 ]

+[[package]]
+name = "event-listener"
+version = "5.3.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6032be9bd27023a771701cc49f9f053c751055f71efb2e0ae5c15809093675ba"
+dependencies = [
+ "concurrent-queue",
+ "parking",
+ "pin-project-lite",
+]
+
+[[package]]
+name = "event-listener-strategy"
+version = "0.5.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0f214dc438f977e6d4e3500aaa277f5ad94ca83fbbd9b1a15713ce2344ccc5a1"
+dependencies = [
+ "event-listener",
+ "pin-project-lite",
+]
+
 [[package]]
 name = "exr"
 version = "1.72.0"
@@ -2922,6 +2964,12 @@ dependencies = [
 "unicode-width",
 ]

+[[package]]
+name = "parking"
+version = "2.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f38d5652c16fde515bb1ecef450ab0f6a219d619a7274976324d5e377f7dceba"
+
 [[package]]
 name = "parking_lot"
 version = "0.12.3"
@@ -4219,6 +4267,7 @@ dependencies = [
 name = "text-generation-backend-llamacpp"
 version = "2.4.1-dev0"
 dependencies = [
+ "async-channel",
 "async-trait",
 "clap 4.5.20",
 "cmake",

Cargo.toml

@@ -7,6 +7,7 @@ homepage.workspace = true

 [dependencies]
 async-trait = "0.1"
+async-channel = "2.3"
 clap = { version = "4.5.19", features = ["derive"] }
 cxx = "1.0"
 num_cpus = "1"

backends/llamacpp/src/backend.rs

@@ -2,6 +2,7 @@ use crate::ffi::{
     create_worker_frontend, set_numactl_core_affinity, GenerationParams, LlamaCppWorkerFrontend,
     SamplingParams,
 };
+use async_channel::{unbounded as mpmc_unbounded, Receiver as MpmcReceiver, Sender as MpmcSender};
 use async_trait::async_trait;
 use cxx::UniquePtr;
 use log::warn;
@@ -19,7 +20,6 @@ use text_generation_router::{FinishReason, Token};
 use thiserror::Error;
 use tokenizers::Tokenizer;
 use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
-use tokio::sync::Semaphore;
 use tokio::task::JoinHandle;
 use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
@@ -102,18 +102,6 @@ pub enum LlamaCppBackendError {
     ModelInitializationFailed(PathBuf, String),
 }

-struct LlamaCppWorker {
-    sender: Sender<(GenerationContext, UnboundedSender<InferResult>)>,
-}
-
-impl LlamaCppWorker {
-    fn submit(&self, ctx: GenerationContext, sx: UnboundedSender<InferResult>) {
-        if let Err(err) = self.sender.send((ctx, sx)) {
-            // TODO: What do we do?
-        }
-    }
-}
-
 pub struct LlamaCppBackend {
     scheduler_sender: UnboundedSender<(GenerationContext, UnboundedSender<InferResult>)>,
     scheduler_handle: JoinHandle<()>,
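
For contrast with the hunks that follow: the deleted `LlamaCppWorker` wrapped one dedicated single-consumer channel per worker, so the scheduler itself had to decide which worker received each job. A paraphrased sketch of that removed shape (not the original code; round-robin routing is a stand-in for the scheduler's choice):

```rust
use std::sync::mpsc::{channel, Sender};
use std::thread;

// Pre-commit shape: one single-consumer channel per worker, so the
// dispatcher has to route each job to a specific worker itself.
fn main() {
    let mut senders: Vec<Sender<u32>> = Vec::new();
    let mut handles = Vec::new();

    for id in 0..4 {
        let (tx, rx) = channel::<u32>();
        senders.push(tx);
        handles.push(thread::spawn(move || {
            while let Ok(job) = rx.recv() {
                println!("worker {id} got job {job}");
            }
        }));
    }

    // Routing decision lives in the dispatcher (here: round-robin).
    for job in 0..8u32 {
        senders[job as usize % senders.len()].send(job).unwrap();
    }

    drop(senders); // closes every per-worker channel
    handles.into_iter().for_each(|h| h.join().unwrap());
}
```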
@@ -141,29 +129,26 @@ impl LlamaCppBackend {
             ));
         }

-        let cores_allocation = get_cores_allocation(num_cores_per_instance as usize);
+        // Allocate the multi-consumer queue to orchestrate all the workers
+        let (backlog_submitter, backlog_receiver) = mpmc_unbounded();

         // Allocate all the workers
-        let streams = cores_allocation
-            .iter()
-            .map(
-                |affinity| match Self::allocate_worker(path, num_cores_per_instance as u32) {
-                    Ok(worker) => {
-                        let tokenizer = Arc::clone(&tokenizer);
-                        let (sender, receiver) = channel();
-                        let affinity = affinity.clone().collect::<Vec<_>>();
-                        spawn(move || worker_loop(worker, affinity, tokenizer, receiver));
-
-                        Ok(LlamaCppWorker { sender })
-                    }
-                    Err(e) => Err(e),
-                },
-            )
-            .collect::<Result<Vec<_>, _>>()?;
+        let cores_allocation = get_cores_allocation(num_cores_per_instance as usize);
+        cores_allocation.iter().for_each(|affinity| {
+            match Self::allocate_worker(path, num_cores_per_instance as u32) {
+                Ok(worker) => {
+                    let tokenizer = Arc::clone(&tokenizer);
+                    let affinity = affinity.clone().collect::<Vec<_>>();
+                    let backlog_receiver = backlog_receiver.clone();
+                    spawn(move || worker_loop(worker, affinity, tokenizer, backlog_receiver));
+                }
+                Err(e) => {}
+            }
+        });

         // Start the scheduler loop
         let (scheduler_sender, scheduler_receiver) = unbounded_channel();
-        let scheduler_handle = tokio::spawn(scheduler_loop(scheduler_receiver, streams));
+        let scheduler_handle = tokio::spawn(scheduler_loop(scheduler_receiver, backlog_submitter));
         Ok(Self {
             scheduler_sender,
             scheduler_handle,
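
A detail worth calling out in the hunk above: cloning an `async_channel::Receiver` does not create a broadcast subscriber. Clones compete for messages, and each message is delivered to exactly one receiver — that is what turns the clone-per-worker spawn into a work queue. A tiny sketch (hypothetical values, same crate):

```rust
fn main() {
    let (tx, rx1) = async_channel::unbounded::<&str>();
    let rx2 = rx1.clone(); // a competing consumer, not a broadcast copy

    tx.send_blocking("job-a").unwrap();
    tx.send_blocking("job-b").unwrap();

    // Each recv consumes a distinct message, regardless of which clone asks.
    assert_eq!(rx1.recv_blocking().unwrap(), "job-a");
    assert_eq!(rx2.recv_blocking().unwrap(), "job-b");
}
```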
@@ -263,24 +248,16 @@ fn llama_generate_callback(

 async fn scheduler_loop(
     mut queue: UnboundedReceiver<(GenerationContext, UnboundedSender<InferResult>)>,
-    mut workers: Vec<LlamaCppWorker>,
+    backlog: MpmcSender<(GenerationContext, UnboundedSender<InferResult>)>,
 ) {
-    // Semaphore allows us to wait for a worker to become available
-    let permits = Semaphore::new(workers.len());
-
     // Let's receive incoming requests
     loop {
         match queue.recv().await {
             None => break,
             Some((ctx, sender)) => {
-                let permit = permits.try_acquire();
-                if let Err(err) = permit {
-                    let _ = sender.send(Err(InferError::Overloaded(err)));
+                if let Err(e) = backlog.send((ctx, sender)).await {
+                    todo!("What do we do")
                 }
-
-                // We can unwrap because we wouldn't have a semaphore available otherwise
-                let worker = workers.pop().unwrap();
-                worker.submit(ctx, sender);
             }
         }
     }
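
For context on the `todo!` branch above: `send(..).await` on an unbounded `async-channel` never waits for capacity; it only returns `Err` once the channel is closed, i.e. after every worker receiver has been dropped. The `SendError` hands ownership of the message back, so a non-panicking fallback could return the job to the caller. A hedged sketch of such a handler (hypothetical, not part of the commit):

```rust
use async_channel::{SendError, Sender};

// On failure, recover the undeliverable job instead of panicking; the
// caller can then report the error on the per-request stream.
async fn forward<T>(backlog: &Sender<T>, job: T) -> Result<(), T> {
    backlog.send(job).await.map_err(|SendError(job)| job)
}
```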
@@ -290,7 +267,7 @@ fn worker_loop(
     mut backend: UniquePtr<LlamaCppWorkerFrontend>,
     affinity: Vec<usize>,
     tokenizer: Arc<Tokenizer>,
-    backlog: Receiver<(GenerationContext, UnboundedSender<InferResult>)>,
+    backlog: MpmcReceiver<(GenerationContext, UnboundedSender<InferResult>)>,
 ) {
     // This loop will mostly decode single token at every step, so no need to rely on parallelism
     tokenizers::utils::parallelism::set_parallelism(false);
@@ -299,7 +276,7 @@ fn worker_loop(
     set_numactl_core_affinity(&affinity);

     loop {
-        if let Ok((generation, stream)) = backlog.recv() {
+        if let Ok((generation, stream)) = backlog.recv_blocking() {
             let start = Instant::now();
             let generation_params = generation.generation_params; // copy
             let sampling_params = generation.sampling_params; // copy
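
The final hunk swaps `recv()` for `recv_blocking()` because the backlog is no longer a `std::sync::mpsc` receiver: `async_channel::Receiver::recv()` returns a future, and the worker runs on a plain OS thread with no executor, so `recv_blocking()` restores the old synchronous semantics. A small sketch contrasting the two call sites (assumes `futures-lite` for the executor; not from the commit):

```rust
fn main() {
    let (tx, rx) = async_channel::unbounded::<i32>();

    // Sync context (worker thread): park until a message or channel closure.
    tx.send_blocking(1).unwrap();
    assert_eq!(rx.recv_blocking().unwrap(), 1);

    // Async context (scheduler side): the same endpoints can be awaited.
    futures_lite::future::block_on(async {
        tx.send(2).await.unwrap();
        assert_eq!(rx.recv().await.unwrap(), 2);
    });
}
```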