From bfb8e03e9f3f5799c8d081a7490c78f36f976d5f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrien=20Gallou=C3=ABt?=
Date: Mon, 3 Feb 2025 11:03:47 +0000
Subject: [PATCH] Add specific args for batch
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Adrien Gallouët
---
 backends/llamacpp/src/backend.rs |  6 ++--
 backends/llamacpp/src/main.rs    | 50 ++++++++++++++++++++++----------
 2 files changed, 38 insertions(+), 18 deletions(-)

diff --git a/backends/llamacpp/src/backend.rs b/backends/llamacpp/src/backend.rs
index c76f0308..bf4b19e3 100644
--- a/backends/llamacpp/src/backend.rs
+++ b/backends/llamacpp/src/backend.rs
@@ -56,9 +56,11 @@ pub struct LlamacppConfig {
     pub model_gguf: String,
     pub n_ctx: usize,
     pub max_batch_total_tokens: usize,
+    pub max_physical_batch_total_tokens: usize,
     pub max_batch_size: usize,
     pub batch_timeout: Duration,
     pub n_threads: usize,
+    pub n_threads_batch: usize,
     pub n_gpu_layers: usize,
     pub split_mode: LlamacppSplitMode,
     pub numa: LlamacppNuma,
@@ -173,10 +175,10 @@ impl Llamacpp {
         let mut params = bindings::llama_context_default_params();
         params.n_ctx = conf.n_ctx as _;
         params.n_batch = conf.max_batch_total_tokens as _;
-        params.n_ubatch = conf.max_batch_total_tokens as _; // TODO ?
+        params.n_ubatch = conf.max_physical_batch_total_tokens as _;
         params.n_seq_max = conf.max_batch_size as _;
         params.n_threads = conf.n_threads as _;
-        params.n_threads_batch = conf.n_threads as _; // TODO ?
+        params.n_threads_batch = conf.n_threads_batch as _;
         params.defrag_thold = conf.defrag_threshold;
         params.offload_kqv = conf.offload_kqv;
         params.flash_attn = conf.flash_attention;
diff --git a/backends/llamacpp/src/main.rs b/backends/llamacpp/src/main.rs
index f3e81782..55881b13 100644
--- a/backends/llamacpp/src/main.rs
+++ b/backends/llamacpp/src/main.rs
@@ -28,10 +28,14 @@ struct Args {
     #[clap(default_value = "4096", long, env)]
     n_ctx: usize,
 
-    /// Number of threads to use for inference.
+    /// Number of threads to use for generation.
     #[clap(long, env)]
     n_threads: Option<usize>,
 
+    /// Number of threads to use for batch processing.
+    #[clap(long, env)]
+    n_threads_batch: Option<usize>,
+
     /// Number of layers to store in VRAM.
     #[clap(default_value = "0", long, env)]
     n_gpu_layers: usize,
 
@@ -89,10 +93,14 @@ struct Args {
     // #[clap(default_value = "4096", long, env)]
     // max_batch_prefill_tokens: u32,
 
-    /// Maximum tokens within a batch
+    /// Maximum number of tokens that can be submitted within a batch
     #[clap(default_value = "4096", long, env)]
     max_batch_total_tokens: usize,
 
+    /// Maximum number of tokens within a batch
+    #[clap(long, env)]
+    max_physical_batch_total_tokens: Option<usize>,
+
     // #[clap(default_value = "20", long, env)]
     // max_waiting_tokens: usize,
 
@@ -159,6 +167,14 @@ async fn main() -> Result<(), RouterError> {
         Some(0) | None => num_cpus::get(),
         Some(threads) => threads,
     };
+    let n_threads_batch = match args.n_threads_batch {
+        Some(0) | None => n_threads,
+        Some(threads) => threads,
+    };
+    let max_physical_batch_total_tokens = match args.max_physical_batch_total_tokens {
+        None => args.max_batch_total_tokens,
+        Some(size) => size,
+    };
     if args.max_input_tokens >= args.max_total_tokens {
         return Err(RouterError::ArgumentValidation(
             "`max_input_tokens` must be < `max_total_tokens`".to_string(),
@@ -199,20 +215,22 @@ async fn main() -> Result<(), RouterError> {
 
     let (backend, ok) = LlamacppBackend::new(
         LlamacppConfig {
-            model_gguf: args.model_gguf,
-            n_ctx: args.n_ctx,
-            n_threads: n_threads,
-            n_gpu_layers: args.n_gpu_layers,
-            split_mode: args.split_mode,
-            defrag_threshold: args.defrag_threshold,
-            numa: args.numa,
-            use_mmap: args.use_mmap,
-            use_mlock: args.use_mlock,
-            flash_attention: args.flash_attention,
-            offload_kqv: args.offload_kqv,
-            max_batch_total_tokens: args.max_batch_total_tokens,
-            max_batch_size: args.max_batch_size,
-            batch_timeout: tokio::time::Duration::from_millis(5),
+            model_gguf: args.model_gguf,
+            n_ctx: args.n_ctx,
+            n_threads: n_threads,
+            n_threads_batch: n_threads_batch,
+            n_gpu_layers: args.n_gpu_layers,
+            split_mode: args.split_mode,
+            defrag_threshold: args.defrag_threshold,
+            numa: args.numa,
+            use_mmap: args.use_mmap,
+            use_mlock: args.use_mlock,
+            flash_attention: args.flash_attention,
+            offload_kqv: args.offload_kqv,
+            max_batch_total_tokens: args.max_batch_total_tokens,
+            max_physical_batch_total_tokens: max_physical_batch_total_tokens,
+            max_batch_size: args.max_batch_size,
+            batch_timeout: tokio::time::Duration::from_millis(5),
         },
         tokenizer,
     );
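
Note for reviewers (not part of the patch): the two fallbacks added in
main() differ in a way that is easy to miss in the diff. `n_threads_batch`
treats both an absent flag and an explicit 0 as "follow n_threads", while
`max_physical_batch_total_tokens` only falls back when the flag is absent.
A minimal, self-contained Rust sketch of that logic; the `or_zero_default`
helper and the literal values below are illustrative only and do not exist
in the patch:

    // Hypothetical helper mirroring the `Some(0) | None` arms used for
    // the thread counts: an explicit 0 also selects the default.
    fn or_zero_default(requested: Option<usize>, default: usize) -> usize {
        match requested {
            Some(0) | None => default,
            Some(value) => value,
        }
    }

    fn main() {
        let n_threads = or_zero_default(None, 8); // patch uses num_cpus::get() as the default
        let n_threads_batch = or_zero_default(Some(0), n_threads); // 0 -> follow n_threads
        // The physical batch size has no 0-means-default arm; its match is
        // equivalent to Option::unwrap_or on the logical batch size.
        let max_physical_batch_total_tokens = None::<usize>.unwrap_or(4096);
        assert_eq!((n_threads, n_threads_batch, max_physical_batch_total_tokens), (8, 8, 4096));
    }

In llama.cpp terms, max_batch_total_tokens feeds n_batch (the logical batch
limit) and max_physical_batch_total_tokens feeds n_ubatch (the physical
micro-batch actually dispatched to the device), so the physical value is
normally expected to be at most the logical one. With clap's default
naming, the new fields should surface as --n-threads-batch /
N_THREADS_BATCH and --max-physical-batch-total-tokens /
MAX_PHYSICAL_BATCH_TOTAL_TOKENS.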