diff --git a/backends/llamacpp/src/backend.rs b/backends/llamacpp/src/backend.rs
index 63a93757..6f8cc59d 100644
--- a/backends/llamacpp/src/backend.rs
+++ b/backends/llamacpp/src/backend.rs
@@ -65,6 +65,7 @@ pub struct LlamacppConfig {
     pub defrag_threshold: f32,
     pub use_mmap: bool,
     pub use_mlock: bool,
+    pub offload_kqv: bool,
     pub flash_attention: bool,
 }
 
@@ -177,6 +178,7 @@ impl Llamacpp {
         params.n_threads = conf.n_threads as _;
         params.n_threads_batch = conf.n_threads as _; // TODO ?
         params.defrag_thold = conf.defrag_threshold;
+        params.offload_kqv = conf.offload_kqv;
         params.flash_attn = conf.flash_attention;
         params.no_perf = true;
         bindings::llama_init_from_model(model, params)
diff --git a/backends/llamacpp/src/main.rs b/backends/llamacpp/src/main.rs
index 1c7c5e4c..b5eec467 100644
--- a/backends/llamacpp/src/main.rs
+++ b/backends/llamacpp/src/main.rs
@@ -56,6 +56,10 @@ struct Args {
     #[clap(default_value = "false", long, env)]
     use_mlock: bool,
 
+    /// Enable offloading of KQV operations to the GPU.
+    #[clap(default_value = "false", long, env)]
+    offload_kqv: bool,
+
     /// Enable flash attention for faster inference. (EXPERIMENTAL)
     #[clap(default_value = "true", long, env)]
     flash_attention: bool,
@@ -201,6 +205,7 @@ async fn main() -> Result<(), RouterError> {
         use_mmap: args.use_mmap,
         use_mlock: args.use_mlock,
         flash_attention: args.flash_attention,
+        offload_kqv: args.offload_kqv,
         max_batch_total_tokens: args.max_batch_total_tokens,
         max_batch_size: args.max_batch_size,
         batch_timeout: tokio::time::Duration::from_millis(5),
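
For reference, a minimal, self-contained sketch (not part of the change) of the `clap` derive pattern used for the new field in `main.rs`: with `long` and `env`, `offload_kqv: bool` is exposed as a `--offload-kqv` command-line flag that can also be controlled through the `OFFLOAD_KQV` environment variable and defaults to `false`. The `DemoArgs` struct name below is hypothetical; only the field and its attributes mirror the diff.

```rust
// Sketch of how the new option surfaces on the CLI (hypothetical DemoArgs struct).
use clap::Parser;

#[derive(Parser, Debug)]
struct DemoArgs {
    /// Enable offloading of KQV operations to the GPU.
    #[clap(default_value = "false", long, env)]
    offload_kqv: bool,
}

fn main() {
    // `long` derives the `--offload-kqv` flag from the field name;
    // `env` additionally reads the `OFFLOAD_KQV` environment variable.
    let args = DemoArgs::parse();
    // In the real backend this value is carried through LlamacppConfig into
    // llama.cpp's context params, as shown in backend.rs above.
    println!("offload_kqv = {}", args.offload_kqv);
}
```

Passing `--offload-kqv` (or setting `OFFLOAD_KQV`) enables the option; the backend then forwards it to `llama_context_params.offload_kqv`, which controls whether the KQV attention operations are offloaded to the GPU.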