diff --git a/backends/llamacpp/src/backend.rs b/backends/llamacpp/src/backend.rs
index bf45a67f..63a93757 100644
--- a/backends/llamacpp/src/backend.rs
+++ b/backends/llamacpp/src/backend.rs
@@ -43,6 +43,15 @@ impl FromStr for LlamacppSplitMode {
     }
 }
 
+#[derive(Debug, Clone, Copy, clap::ValueEnum)]
+pub enum LlamacppNuma {
+    Disabled,
+    Distribute,
+    Isolate,
+    Numactl,
+    Mirror,
+}
+
 pub struct LlamacppConfig {
     pub model_gguf: String,
     pub n_ctx: usize,
@@ -52,6 +61,7 @@ pub struct LlamacppConfig {
     pub n_threads: usize,
     pub n_gpu_layers: usize,
     pub split_mode: LlamacppSplitMode,
+    pub numa: LlamacppNuma,
     pub defrag_threshold: f32,
     pub use_mmap: bool,
     pub use_mlock: bool,
@@ -387,7 +397,13 @@ impl LlamacppBackend {
         INIT.call_once(|| unsafe {
             bindings::llama_log_set(Some(llamacpp_log_callback), std::ptr::null_mut());
             bindings::llama_backend_init();
-            bindings::llama_numa_init(bindings::GGML_NUMA_STRATEGY_NUMACTL); // TODO add option & test
+            bindings::llama_numa_init(match conf.numa {
+                LlamacppNuma::Disabled => bindings::GGML_NUMA_STRATEGY_DISABLED,
+                LlamacppNuma::Distribute => bindings::GGML_NUMA_STRATEGY_DISTRIBUTE,
+                LlamacppNuma::Isolate => bindings::GGML_NUMA_STRATEGY_ISOLATE,
+                LlamacppNuma::Numactl => bindings::GGML_NUMA_STRATEGY_NUMACTL,
+                LlamacppNuma::Mirror => bindings::GGML_NUMA_STRATEGY_MIRROR,
+            });
         });
 
         let (status_tx, status_rx) = watch::channel(false);
diff --git a/backends/llamacpp/src/main.rs b/backends/llamacpp/src/main.rs
index 5fb23d17..a8283c13 100644
--- a/backends/llamacpp/src/main.rs
+++ b/backends/llamacpp/src/main.rs
@@ -1,6 +1,6 @@
 mod backend;
 
-use backend::{LlamacppSplitMode, LlamacppConfig, LlamacppBackend, BackendError};
+use backend::{LlamacppNuma, LlamacppSplitMode, LlamacppConfig, LlamacppBackend, BackendError};
 use clap::{Parser};
 use text_generation_router::{logging, server, usage_stats};
 use thiserror::Error;
@@ -44,6 +44,10 @@ struct Args {
     #[clap(default_value = "-1.0", long, env)]
     defrag_threshold: f32,
 
+    /// Set up NUMA optimizations.
+    #[clap(default_value = "disabled", value_enum, long, env)]
+    numa: LlamacppNuma,
+
     /// Whether to use memory mapping.
     #[clap(default_value = "true", long, env)]
     use_mmap: bool,
@@ -193,6 +197,7 @@ async fn main() -> Result<(), RouterError> {
         n_gpu_layers: args.n_gpu_layers,
         split_mode: args.split_mode,
         defrag_threshold: args.defrag_threshold,
+        numa: args.numa,
         use_mmap: args.use_mmap,
         use_mlock: args.use_mlock,
         flash_attention: args.flash_attention,
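
Not part of the diff: a minimal standalone sketch of how the new flag is expected to parse, assuming clap 4.x with the derive feature. The ValueEnum derive matches variants by their kebab-case names, so `--numa` (or the corresponding environment variable) accepts disabled, distribute, isolate, numactl, or mirror, and the default string must be the lowercase "disabled". The enum is re-declared locally only to keep the snippet self-contained.

use clap::ValueEnum;

#[derive(Debug, Clone, Copy, PartialEq, ValueEnum)]
pub enum LlamacppNuma {
    Disabled,
    Distribute,
    Isolate,
    Numactl,
    Mirror,
}

fn main() {
    // ValueEnum derives kebab-case value names; these are the strings the CLI
    // accepts for `--numa` and that `default_value` must match.
    assert_eq!(
        LlamacppNuma::from_str("distribute", false),
        Ok(LlamacppNuma::Distribute)
    );
    // The capitalized form is rejected when case is not ignored, which is why
    // the default is written as "disabled".
    assert!(LlamacppNuma::from_str("Disabled", false).is_err());
}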