Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-06-09 02:42:05 +00:00)
feat(backend): rely on multi consumer queue to scheduler workers

parent 84eead219a
commit 5a85661661
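The change swaps the backend's per-worker channels (each guarded by a Semaphore permit) for a single multi-producer, multi-consumer queue from the async-channel crate: every worker thread holds a clone of one shared Receiver, and whichever worker is idle pulls the next request. Below is a minimal, self-contained sketch of that pattern — illustrative only, not code from this commit — assuming async-channel 2.x as the sole dependency:

// Sketch of the MPMC work-distribution pattern this commit adopts.
use std::thread;

fn main() {
    // One shared unbounded queue; `tx` is the producer side, `rx` the consumer side.
    let (tx, rx) = async_channel::unbounded::<u32>();

    // Cloning the Receiver adds another consumer of the SAME queue:
    // each job is delivered to exactly one worker (distribution, not broadcast).
    let workers: Vec<_> = (0..4)
        .map(|id| {
            let rx = rx.clone();
            thread::spawn(move || {
                // recv_blocking() parks the thread until a job arrives and
                // returns Err once the channel is closed and drained.
                while let Ok(job) = rx.recv_blocking() {
                    println!("worker {id} took job {job}");
                }
            })
        })
        .collect();

    for job in 0..16 {
        // On an unbounded channel, sending only fails if all receivers are gone.
        tx.send_blocking(job).unwrap();
    }
    drop(tx); // close the queue so the workers exit their loops

    for w in workers {
        w.join().unwrap();
    }
}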
Cargo.lock (generated, 49 lines changed)
@@ -142,6 +142,18 @@ version = "0.7.6"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50"
 
+[[package]]
+name = "async-channel"
+version = "2.3.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "89b47800b0be77592da0afd425cc03468052844aff33b84e33cc696f64e77b6a"
+dependencies = [
+ "concurrent-queue",
+ "event-listener-strategy",
+ "futures-core",
+ "pin-project-lite",
+]
+
 [[package]]
 name = "async-rustls"
 version = "0.3.0"
@@ -758,6 +770,15 @@ dependencies = [
  "static_assertions",
 ]
 
+[[package]]
+name = "concurrent-queue"
+version = "2.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4ca0197aee26d1ae37445ee532fefce43251d24cc7c166799f4d46817f1d3973"
+dependencies = [
+ "crossbeam-utils",
+]
+
 [[package]]
 name = "console"
 version = "0.15.8"
@@ -1158,6 +1179,27 @@ dependencies = [
  "cc",
 ]
 
+[[package]]
+name = "event-listener"
+version = "5.3.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6032be9bd27023a771701cc49f9f053c751055f71efb2e0ae5c15809093675ba"
+dependencies = [
+ "concurrent-queue",
+ "parking",
+ "pin-project-lite",
+]
+
+[[package]]
+name = "event-listener-strategy"
+version = "0.5.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0f214dc438f977e6d4e3500aaa277f5ad94ca83fbbd9b1a15713ce2344ccc5a1"
+dependencies = [
+ "event-listener",
+ "pin-project-lite",
+]
+
 [[package]]
 name = "exr"
 version = "1.72.0"
@@ -2922,6 +2964,12 @@ dependencies = [
  "unicode-width",
 ]
 
+[[package]]
+name = "parking"
+version = "2.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f38d5652c16fde515bb1ecef450ab0f6a219d619a7274976324d5e377f7dceba"
+
 [[package]]
 name = "parking_lot"
 version = "0.12.3"
@@ -4219,6 +4267,7 @@ dependencies = [
 name = "text-generation-backend-llamacpp"
 version = "2.4.1-dev0"
 dependencies = [
+ "async-channel",
  "async-trait",
  "clap 4.5.20",
  "cmake",
Cargo.toml (llamacpp backend crate)

@@ -7,6 +7,7 @@ homepage.workspace = true
 
 [dependencies]
 async-trait = "0.1"
+async-channel = "2.3"
 clap = { version = "4.5.19", features = ["derive"] }
 cxx = "1.0"
 num_cpus = "1"
Backend source (Rust)

@@ -2,6 +2,7 @@ use crate::ffi::{
     create_worker_frontend, set_numactl_core_affinity, GenerationParams, LlamaCppWorkerFrontend,
     SamplingParams,
 };
+use async_channel::{unbounded as mpmc_unbounded, Receiver as MpmcReceiver, Sender as MpmcSender};
 use async_trait::async_trait;
 use cxx::UniquePtr;
 use log::warn;
@@ -19,7 +20,6 @@ use text_generation_router::{FinishReason, Token};
 use thiserror::Error;
 use tokenizers::Tokenizer;
 use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
-use tokio::sync::Semaphore;
 use tokio::task::JoinHandle;
 use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
@@ -102,18 +102,6 @@ pub enum LlamaCppBackendError {
     ModelInitializationFailed(PathBuf, String),
 }
 
-struct LlamaCppWorker {
-    sender: Sender<(GenerationContext, UnboundedSender<InferResult>)>,
-}
-
-impl LlamaCppWorker {
-    fn submit(&self, ctx: GenerationContext, sx: UnboundedSender<InferResult>) {
-        if let Err(err) = self.sender.send((ctx, sx)) {
-            // TODO: What do we do?
-        }
-    }
-}
-
 pub struct LlamaCppBackend {
     scheduler_sender: UnboundedSender<(GenerationContext, UnboundedSender<InferResult>)>,
     scheduler_handle: JoinHandle<()>,
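For contrast, the deleted LlamaCppWorker wrapped a dedicated std::sync::mpsc::Sender per worker, so the scheduler itself had to pick a recipient for every request and track availability (via the Semaphore removed further down). A hedged sketch of that older shape, with a placeholder u32 job type, showing the scheduler-side selection the MPMC queue makes unnecessary:

// Sketch of the dispatch problem the removed code had to solve: with one
// std::sync::mpsc channel per worker, the producer must pick a worker itself.
use std::sync::mpsc::{channel, Sender};
use std::thread;

fn main() {
    let mut senders: Vec<Sender<u32>> = Vec::new();
    let mut handles = Vec::new();
    for id in 0..4 {
        let (tx, rx) = channel::<u32>();
        senders.push(tx);
        handles.push(thread::spawn(move || {
            while let Ok(job) = rx.recv() {
                println!("worker {id} got job {job}");
            }
        }));
    }

    // Scheduler-side selection: this naive round-robin can hand a job to a
    // busy worker while another sits idle -- the imbalance a shared MPMC
    // queue avoids, since idle workers pull jobs themselves.
    for job in 0..8u32 {
        senders[(job as usize) % senders.len()].send(job).unwrap();
    }

    drop(senders); // close every channel so the workers exit
    for h in handles {
        h.join().unwrap();
    }
}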
@@ -141,29 +129,26 @@ impl LlamaCppBackend {
             ));
         }
 
-        let cores_allocation = get_cores_allocation(num_cores_per_instance as usize);
+        // Allocate the multi-consumer queue to orchestrate all the workers
+        let (backlog_submitter, backlog_receiver) = mpmc_unbounded();
 
         // Allocate all the workers
-        let streams = cores_allocation
-            .iter()
-            .map(
-                |affinity| match Self::allocate_worker(path, num_cores_per_instance as u32) {
-                    Ok(worker) => {
-                        let tokenizer = Arc::clone(&tokenizer);
-                        let (sender, receiver) = channel();
-                        let affinity = affinity.clone().collect::<Vec<_>>();
-                        spawn(move || worker_loop(worker, affinity, tokenizer, receiver));
-
-                        Ok(LlamaCppWorker { sender })
-                    }
-                    Err(e) => Err(e),
-                },
-            )
-            .collect::<Result<Vec<_>, _>>()?;
+        let cores_allocation = get_cores_allocation(num_cores_per_instance as usize);
+        cores_allocation.iter().for_each(|affinity| {
+            match Self::allocate_worker(path, num_cores_per_instance as u32) {
+                Ok(worker) => {
+                    let tokenizer = Arc::clone(&tokenizer);
+                    let affinity = affinity.clone().collect::<Vec<_>>();
+                    let backlog_receiver = backlog_receiver.clone();
+                    spawn(move || worker_loop(worker, affinity, tokenizer, backlog_receiver));
+                }
+                Err(e) => {}
+            }
+        });
 
         // Start the scheduler loop
         let (scheduler_sender, scheduler_receiver) = unbounded_channel();
-        let scheduler_handle = tokio::spawn(scheduler_loop(scheduler_receiver, streams));
+        let scheduler_handle = tokio::spawn(scheduler_loop(scheduler_receiver, backlog_submitter));
         Ok(Self {
             scheduler_sender,
             scheduler_handle,
@@ -263,24 +248,16 @@ fn llama_generate_callback(
 
 async fn scheduler_loop(
     mut queue: UnboundedReceiver<(GenerationContext, UnboundedSender<InferResult>)>,
-    mut workers: Vec<LlamaCppWorker>,
+    backlog: MpmcSender<(GenerationContext, UnboundedSender<InferResult>)>,
 ) {
-    // Semaphore allows us to wait for a worker to become available
-    let permits = Semaphore::new(workers.len());
-
     // Let's receive incoming requests
     loop {
         match queue.recv().await {
             None => break,
             Some((ctx, sender)) => {
-                let permit = permits.try_acquire();
-                if let Err(err) = permit {
-                    let _ = sender.send(Err(InferError::Overloaded(err)));
+                if let Err(e) = backlog.send((ctx, sender)).await {
+                    todo!("What do we do")
                 }
-
-                // We can unwrap because we wouldn't have a semaphore available otherwise
-                let worker = workers.pop().unwrap();
-                worker.submit(ctx, sender);
             }
         }
     }
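One note on the new todo!() branch: with async-channel, send on an unbounded channel never blocks and cannot fail for capacity reasons; it errors only once the channel is closed, i.e. when every worker's Receiver has been dropped. A tiny sketch checking that assumption (not commit code; send_blocking is the sync twin of .send().await):

// async-channel send semantics on an unbounded channel.
fn main() {
    let (tx, rx) = async_channel::unbounded::<u32>();

    // Succeeds: capacity is never an issue on an unbounded channel.
    assert!(tx.send_blocking(1).is_ok());

    // Dropping the last Receiver closes the channel...
    drop(rx);

    // ...and only then does send fail (the SendError hands the value back).
    assert!(tx.send_blocking(2).is_err());
    println!("send failed only after the channel was closed");
}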
@@ -290,7 +267,7 @@ fn worker_loop(
     mut backend: UniquePtr<LlamaCppWorkerFrontend>,
     affinity: Vec<usize>,
     tokenizer: Arc<Tokenizer>,
-    backlog: Receiver<(GenerationContext, UnboundedSender<InferResult>)>,
+    backlog: MpmcReceiver<(GenerationContext, UnboundedSender<InferResult>)>,
 ) {
     // This loop will mostly decode single token at every step, so no need to rely on parallelism
     tokenizers::utils::parallelism::set_parallelism(false);

@@ -299,7 +276,7 @@ fn worker_loop(
     set_numactl_core_affinity(&affinity);
 
     loop {
-        if let Ok((generation, stream)) = backlog.recv() {
+        if let Ok((generation, stream)) = backlog.recv_blocking() {
             let start = Instant::now();
             let generation_params = generation.generation_params; // copy
             let sampling_params = generation.sampling_params; // copy
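worker_loop runs on a plain OS thread, hence the switch to recv_blocking() here, while scheduler_loop stays on the tokio runtime and uses .send(...).await; async-channel supports both ends of the same queue. A minimal sketch of that async-producer / blocking-consumer bridge (assumes tokio with the rt and macros features plus async-channel; all names are placeholders):

// Bridge between an async producer and a blocking consumer thread.
use std::thread;

#[tokio::main(flavor = "current_thread")]
async fn main() {
    let (tx, rx) = async_channel::unbounded::<u32>();

    // Blocking consumer on a dedicated OS thread, like worker_loop:
    // recv_blocking parks the thread without involving the async runtime.
    let worker = thread::spawn(move || {
        while let Ok(job) = rx.recv_blocking() {
            println!("worker processed {job}");
        }
    });

    // Async producer, like scheduler_loop: .send().await on an unbounded
    // channel resolves immediately while at least one receiver is alive.
    for job in 0..4 {
        tx.send(job).await.expect("all workers have exited");
    }
    drop(tx); // close the queue so the worker thread returns

    worker.join().unwrap();
}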