diff --git a/backends/llamacpp/src/backend.rs b/backends/llamacpp/src/backend.rs
index 0d3c3950..d7bc31de 100644
--- a/backends/llamacpp/src/backend.rs
+++ b/backends/llamacpp/src/backend.rs
@@ -20,6 +20,28 @@ use tokio::time::{Duration, Instant, timeout};
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tracing::{debug, info, warn, error, trace};
 use tracing::{instrument};
+use std::str::FromStr;
+
+#[derive(Debug, Clone, Copy)]
+pub enum LlamacppSplitMode {
+    GPU(usize),
+    Layer,
+    Row,
+}
+
+impl FromStr for LlamacppSplitMode {
+    type Err = String;
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        match s.to_lowercase().as_str() {
+            "layer" => Ok(LlamacppSplitMode::Layer),
+            "row" => Ok(LlamacppSplitMode::Row),
+            _ => match s.parse::<usize>() {
+                Ok(n) => Ok(LlamacppSplitMode::GPU(n)),
+                Err(_) => Err("Choose a GPU number or `layer` or `row`".to_string()),
+            },
+        }
+    }
+}
 
 pub struct LlamacppConfig {
     pub model_gguf: String,
@@ -28,6 +50,8 @@ pub struct LlamacppConfig {
     pub max_batch_size: usize,
     pub batch_timeout: Duration,
     pub n_threads: usize,
+    pub n_gpu_layers: usize,
+    pub split_mode: LlamacppSplitMode,
     pub use_mmap: bool,
     pub use_mlock: bool,
     pub flash_attention: bool,
@@ -116,6 +140,18 @@ impl Llamacpp {
 
         let model = unsafe {
             let mut params = bindings::llama_model_default_params();
+            params.n_gpu_layers = conf.n_gpu_layers as _;
+            params.split_mode = match conf.split_mode {
+                LlamacppSplitMode::GPU(_) => bindings::LLAMA_SPLIT_MODE_NONE,
+                LlamacppSplitMode::Layer => bindings::LLAMA_SPLIT_MODE_LAYER,
+                LlamacppSplitMode::Row => bindings::LLAMA_SPLIT_MODE_ROW,
+            };
+            params.main_gpu = match conf.split_mode {
+                LlamacppSplitMode::GPU(n) => n as _,
+                _ => 0,
+            };
+            info!(?params.split_mode);
+            info!(?params.main_gpu);
             params.use_mmap = conf.use_mmap;
             params.use_mlock = conf.use_mlock;
             bindings::llama_model_load_from_file(gguf.as_ptr(), params)
diff --git a/backends/llamacpp/src/main.rs b/backends/llamacpp/src/main.rs
index 7eae8315..fe7c1cd1 100644
--- a/backends/llamacpp/src/main.rs
+++ b/backends/llamacpp/src/main.rs
@@ -1,6 +1,6 @@
 mod backend;
 
-use backend::{LlamacppConfig, LlamacppBackend, BackendError};
+use backend::{LlamacppSplitMode, LlamacppConfig, LlamacppBackend, BackendError};
 use clap::{Parser};
 use text_generation_router::{logging, server, usage_stats};
 use thiserror::Error;
@@ -32,6 +32,14 @@ struct Args {
     #[clap(default_value = "1", long, env)]
     n_threads: usize,
 
+    /// Number of layers to store in VRAM.
+    #[clap(default_value = "0", long, env)]
+    n_gpu_layers: usize,
+
+    /// Split the model across multiple GPUs.
+    #[clap(default_value = "layer", long, env)]
+    split_mode: LlamacppSplitMode,
+
     #[clap(default_value = "true", long, env)]
     /// Whether to use memory mapping.
     use_mmap: bool,
@@ -178,6 +186,8 @@ async fn main() -> Result<(), RouterError> {
         model_gguf: args.model_gguf,
         n_ctx: args.n_ctx,
         n_threads: args.n_threads,
+        n_gpu_layers: args.n_gpu_layers,
+        split_mode: args.split_mode,
         use_mmap: args.use_mmap,
         use_mlock: args.use_mlock,
         flash_attention: args.flash_attention,