Add specific args for batch

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
Author: Adrien Gallouët <angt@huggingface.co>
Date:   2025-02-03 11:03:47 +00:00
Parent: e6a8d33902
Commit: bfb8e03e9f
2 changed files with 38 additions and 18 deletions


@@ -56,9 +56,11 @@ pub struct LlamacppConfig {
     pub model_gguf: String,
     pub n_ctx: usize,
     pub max_batch_total_tokens: usize,
+    pub max_physical_batch_total_tokens: usize,
     pub max_batch_size: usize,
     pub batch_timeout: Duration,
     pub n_threads: usize,
+    pub n_threads_batch: usize,
     pub n_gpu_layers: usize,
     pub split_mode: LlamacppSplitMode,
     pub numa: LlamacppNuma,
@@ -173,10 +175,10 @@ impl Llamacpp {
         let mut params = bindings::llama_context_default_params();
         params.n_ctx = conf.n_ctx as _;
         params.n_batch = conf.max_batch_total_tokens as _;
-        params.n_ubatch = conf.max_batch_total_tokens as _; // TODO ?
+        params.n_ubatch = conf.max_physical_batch_total_tokens as _;
         params.n_seq_max = conf.max_batch_size as _;
         params.n_threads = conf.n_threads as _;
-        params.n_threads_batch = conf.n_threads as _; // TODO ?
+        params.n_threads_batch = conf.n_threads_batch as _;
         params.defrag_thold = conf.defrag_threshold;
         params.offload_kqv = conf.offload_kqv;
         params.flash_attn = conf.flash_attention;
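
In llama.cpp terms, `max_batch_total_tokens` feeds `n_batch` (the logical batch limit) while the new `max_physical_batch_total_tokens` feeds `n_ubatch` (the physical micro-batch actually dispatched to the compute backend), so the physical value should not exceed the logical one. A minimal sketch of how that invariant could be checked at startup is shown below; the `check_batch_sizes` helper is hypothetical and not part of this commit.

// Hypothetical helper, not part of this commit: sketches the invariant implied
// by mapping max_batch_total_tokens -> n_batch and
// max_physical_batch_total_tokens -> n_ubatch.
fn check_batch_sizes(logical: usize, physical: usize) -> Result<(), String> {
    if physical > logical {
        return Err(format!(
            "`max_physical_batch_total_tokens` ({physical}) must be <= `max_batch_total_tokens` ({logical})"
        ));
    }
    Ok(())
}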


@@ -28,10 +28,14 @@ struct Args {
     #[clap(default_value = "4096", long, env)]
     n_ctx: usize,
 
-    /// Number of threads to use for inference.
+    /// Number of threads to use for generation.
     #[clap(long, env)]
     n_threads: Option<usize>,
 
+    /// Number of threads to use for batch processing.
+    #[clap(long, env)]
+    n_threads_batch: Option<usize>,
+
     /// Number of layers to store in VRAM.
     #[clap(default_value = "0", long, env)]
     n_gpu_layers: usize,
@@ -89,10 +93,14 @@ struct Args {
     // #[clap(default_value = "4096", long, env)]
     // max_batch_prefill_tokens: u32,
 
-    /// Maximum tokens within a batch
+    /// Maximum number of tokens that can be submitted within a batch
    #[clap(default_value = "4096", long, env)]
     max_batch_total_tokens: usize,
 
+    /// Maximum number of tokens within a batch
+    #[clap(long, env)]
+    max_physical_batch_total_tokens: Option<usize>,
+
     // #[clap(default_value = "20", long, env)]
     // max_waiting_tokens: usize,
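
Since the Args struct uses clap's derive macro with `#[clap(long, env)]`, the new fields would be exposed under the derived names, roughly `--n-threads-batch` / `N_THREADS_BATCH` and `--max-physical-batch-total-tokens` / `MAX_PHYSICAL_BATCH_TOTAL_TOKENS`, assuming clap's default long-flag and env-var naming. The standalone struct below is an illustrative sketch of that surface, not the real Args.

// Illustrative sketch only: a minimal struct showing how clap would derive the
// flag and environment-variable names for the two new options.
// Assumes the clap "derive" and "env" features, as in the surrounding code.
use clap::Parser;

#[derive(Parser, Debug)]
struct BatchArgs {
    /// Number of threads to use for batch processing.
    #[clap(long, env)]
    n_threads_batch: Option<usize>,

    /// Maximum number of tokens within a physical batch.
    #[clap(long, env)]
    max_physical_batch_total_tokens: Option<usize>,
}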
@@ -159,6 +167,14 @@ async fn main() -> Result<(), RouterError> {
         Some(0) | None => num_cpus::get(),
         Some(threads) => threads,
     };
+    let n_threads_batch = match args.n_threads_batch {
+        Some(0) | None => n_threads,
+        Some(threads) => threads,
+    };
+    let max_physical_batch_total_tokens = match args.max_physical_batch_total_tokens {
+        None => args.max_batch_total_tokens,
+        Some(size) => size,
+    };
     if args.max_input_tokens >= args.max_total_tokens {
         return Err(RouterError::ArgumentValidation(
             "`max_input_tokens` must be < `max_total_tokens`".to_string(),
@@ -199,20 +215,22 @@ async fn main() -> Result<(), RouterError> {
     let (backend, ok) = LlamacppBackend::new(
         LlamacppConfig {
             model_gguf: args.model_gguf,
             n_ctx: args.n_ctx,
             n_threads: n_threads,
+            n_threads_batch: n_threads_batch,
             n_gpu_layers: args.n_gpu_layers,
             split_mode: args.split_mode,
             defrag_threshold: args.defrag_threshold,
             numa: args.numa,
             use_mmap: args.use_mmap,
             use_mlock: args.use_mlock,
             flash_attention: args.flash_attention,
             offload_kqv: args.offload_kqv,
             max_batch_total_tokens: args.max_batch_total_tokens,
+            max_physical_batch_total_tokens: max_physical_batch_total_tokens,
             max_batch_size: args.max_batch_size,
             batch_timeout: tokio::time::Duration::from_millis(5),
         },
         tokenizer,
     );