mirror of https://github.com/huggingface/text-generation-inference.git
commit bfb8e03e9f (parent e6a8d33902)

Add specific args for batch

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
@@ -56,9 +56,11 @@ pub struct LlamacppConfig {
     pub model_gguf: String,
     pub n_ctx: usize,
     pub max_batch_total_tokens: usize,
+    pub max_physical_batch_total_tokens: usize,
     pub max_batch_size: usize,
     pub batch_timeout: Duration,
     pub n_threads: usize,
+    pub n_threads_batch: usize,
     pub n_gpu_layers: usize,
     pub split_mode: LlamacppSplitMode,
     pub numa: LlamacppNuma,
@@ -173,10 +175,10 @@ impl Llamacpp {
         let mut params = bindings::llama_context_default_params();
         params.n_ctx = conf.n_ctx as _;
         params.n_batch = conf.max_batch_total_tokens as _;
-        params.n_ubatch = conf.max_batch_total_tokens as _; // TODO ?
+        params.n_ubatch = conf.max_physical_batch_total_tokens as _;
         params.n_seq_max = conf.max_batch_size as _;
         params.n_threads = conf.n_threads as _;
-        params.n_threads_batch = conf.n_threads as _; // TODO ?
+        params.n_threads_batch = conf.n_threads_batch as _;
         params.defrag_thold = conf.defrag_threshold;
         params.offload_kqv = conf.offload_kqv;
         params.flash_attn = conf.flash_attention;
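Note: in llama.cpp, `n_batch` is the maximum logical batch size that can be submitted to `llama_decode`, while `n_ubatch` is the maximum physical batch size actually processed per forward pass; similarly, `n_threads` is used for single-token generation and `n_threads_batch` for prompt/batch processing. Before this change both pairs shared one value (hence the two `// TODO ?` markers). A minimal standalone sketch, with illustrative numbers only, of why a smaller physical batch bounds per-pass work without shrinking the logical batch:

    // Illustrative sketch, not part of the commit: one logical batch of
    // `n_batch` tokens is processed in ceil(n_batch / n_ubatch) forward
    // passes, so lowering n_ubatch caps peak per-pass memory and compute
    // while the scheduler can still queue the full logical batch.
    fn forward_passes(n_batch: usize, n_ubatch: usize) -> usize {
        n_batch.div_ceil(n_ubatch)
    }

    fn main() {
        assert_eq!(forward_passes(8192, 8192), 1); // old behavior: n_ubatch == n_batch
        assert_eq!(forward_passes(8192, 512), 16); // decoupled: 16 smaller passes
        println!("ok");
    }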
@@ -28,10 +28,14 @@ struct Args {
     #[clap(default_value = "4096", long, env)]
     n_ctx: usize,
 
-    /// Number of threads to use for inference.
+    /// Number of threads to use for generation.
     #[clap(long, env)]
     n_threads: Option<usize>,
 
+    /// Number of threads to use for batch processing.
+    #[clap(long, env)]
+    n_threads_batch: Option<usize>,
+
     /// Number of layers to store in VRAM.
     #[clap(default_value = "0", long, env)]
     n_gpu_layers: usize,
@@ -89,10 +93,14 @@ struct Args {
     // #[clap(default_value = "4096", long, env)]
     // max_batch_prefill_tokens: u32,
 
-    /// Maximum tokens within a batch
+    /// Maximum number of tokens that can be submitted within a batch
     #[clap(default_value = "4096", long, env)]
     max_batch_total_tokens: usize,
 
+    /// Maximum number of tokens within a batch
+    #[clap(long, env)]
+    max_physical_batch_total_tokens: Option<usize>,
+
     // #[clap(default_value = "20", long, env)]
     // max_waiting_tokens: usize,
 
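With clap's derive defaults, each new field surfaces as a kebab-case flag and, via `env`, an upper-snake-case environment variable. A standalone sketch of just the two new arguments, assuming clap 4 with the `derive` and `env` features enabled:

    // Standalone sketch, not part of the commit: how the two new options
    // are exposed on the command line and through the environment.
    use clap::Parser;

    #[derive(Parser, Debug)]
    struct Args {
        /// Number of threads to use for batch processing.
        #[clap(long, env)]
        n_threads_batch: Option<usize>,

        /// Maximum number of tokens within a (physical) batch.
        #[clap(long, env)]
        max_physical_batch_total_tokens: Option<usize>,
    }

    fn main() {
        // Invoked as e.g. `app --n-threads-batch 16 --max-physical-batch-total-tokens 512`
        // or `N_THREADS_BATCH=16 MAX_PHYSICAL_BATCH_TOTAL_TOKENS=512 app`.
        let args = Args::parse();
        println!("{args:?}");
    }

Note the naming: `max_batch_total_tokens` is the logical ceiling on queued tokens, while `max_physical_batch_total_tokens` caps what is actually computed per pass; leaving the latter unset keeps the two equal.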
@@ -159,6 +167,14 @@ async fn main() -> Result<(), RouterError> {
         Some(0) | None => num_cpus::get(),
         Some(threads) => threads,
     };
+    let n_threads_batch = match args.n_threads_batch {
+        Some(0) | None => n_threads,
+        Some(threads) => threads,
+    };
+    let max_physical_batch_total_tokens = match args.max_physical_batch_total_tokens {
+        None => args.max_batch_total_tokens,
+        Some(size) => size,
+    };
     if args.max_input_tokens >= args.max_total_tokens {
         return Err(RouterError::ArgumentValidation(
             "`max_input_tokens` must be < `max_total_tokens`".to_string(),
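Both new settings degrade gracefully when unset: an `n_threads_batch` of `0` or `None` inherits `n_threads`, and an unset `max_physical_batch_total_tokens` inherits `max_batch_total_tokens`, which reproduces the previous single-value behavior. A standalone sketch of the resolution rules (the `resolve` helper is hypothetical):

    // Standalone sketch of the defaulting added in main(): returns the
    // effective (n_threads_batch, max_physical_batch_total_tokens).
    fn resolve(
        n_threads: usize,
        n_threads_batch: Option<usize>,
        max_batch_total_tokens: usize,
        max_physical_batch_total_tokens: Option<usize>,
    ) -> (usize, usize) {
        let threads_batch = match n_threads_batch {
            Some(0) | None => n_threads, // 0 or unset: reuse the generation thread count
            Some(t) => t,
        };
        // Equivalent to the commit's match: None falls back to the logical maximum.
        let physical = max_physical_batch_total_tokens.unwrap_or(max_batch_total_tokens);
        (threads_batch, physical)
    }

    fn main() {
        assert_eq!(resolve(8, None, 8192, None), (8, 8192)); // old behavior preserved
        assert_eq!(resolve(8, Some(0), 8192, Some(512)), (8, 512));
        assert_eq!(resolve(8, Some(16), 8192, Some(512)), (16, 512));
        println!("ok");
    }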
@@ -199,20 +215,22 @@ async fn main() -> Result<(), RouterError> {
 
     let (backend, ok) = LlamacppBackend::new(
         LlamacppConfig {
             model_gguf: args.model_gguf,
             n_ctx: args.n_ctx,
             n_threads: n_threads,
+            n_threads_batch: n_threads_batch,
             n_gpu_layers: args.n_gpu_layers,
             split_mode: args.split_mode,
             defrag_threshold: args.defrag_threshold,
             numa: args.numa,
             use_mmap: args.use_mmap,
             use_mlock: args.use_mlock,
             flash_attention: args.flash_attention,
             offload_kqv: args.offload_kqv,
             max_batch_total_tokens: args.max_batch_total_tokens,
+            max_physical_batch_total_tokens: max_physical_batch_total_tokens,
             max_batch_size: args.max_batch_size,
             batch_timeout: tokio::time::Duration::from_millis(5),
         },
         tokenizer,
     );