mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-19 15:52:08 +00:00
Use max_batch_total_tokens
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
This commit is contained in:
parent
bd0cc9905c
commit
3eb4823f3e
@ -24,7 +24,7 @@ use tracing::{instrument};
|
|||||||
pub struct LlamacppConfig {
|
pub struct LlamacppConfig {
|
||||||
pub model_gguf: String,
|
pub model_gguf: String,
|
||||||
pub n_ctx: u32,
|
pub n_ctx: u32,
|
||||||
pub batch_size: usize,
|
pub max_batch_total_tokens: u32,
|
||||||
pub batch_timeout: Duration,
|
pub batch_timeout: Duration,
|
||||||
pub n_threads: i32,
|
pub n_threads: i32,
|
||||||
pub use_mmap: bool,
|
pub use_mmap: bool,
|
||||||
@ -142,7 +142,7 @@ impl Llamacpp {
|
|||||||
return Err(BackendError::Llamacpp("Failed to get vocab".to_string()));
|
return Err(BackendError::Llamacpp("Failed to get vocab".to_string()));
|
||||||
}
|
}
|
||||||
let batch = unsafe {
|
let batch = unsafe {
|
||||||
bindings::llama_batch_init(4096, 0, 5)
|
bindings::llama_batch_init(conf.max_batch_total_tokens as _, 0, 1)
|
||||||
};
|
};
|
||||||
// TODO check batch
|
// TODO check batch
|
||||||
Ok(Llamacpp{model, ctx, vocab, n_ctx, batch})
|
Ok(Llamacpp{model, ctx, vocab, n_ctx, batch})
|
||||||
@ -313,21 +313,25 @@ impl LlamacppBackend {
|
|||||||
let (sync_tx, sync_rx) = mpsc::channel();
|
let (sync_tx, sync_rx) = mpsc::channel();
|
||||||
|
|
||||||
spawn(async move {
|
spawn(async move {
|
||||||
|
let mut n_tokens = 0;
|
||||||
let mut requests = Vec::new();
|
let mut requests = Vec::new();
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
match timeout(conf.batch_timeout, rx.recv()).await {
|
match timeout(conf.batch_timeout, rx.recv()).await {
|
||||||
Ok(None) => break, // closed
|
Ok(None) => break, // closed
|
||||||
Ok(Some(request)) => {
|
Ok(Some(request)) => {
|
||||||
requests.push(request);
|
if n_tokens + request.input_ids.len() > conf.max_batch_total_tokens as usize {
|
||||||
if requests.len() >= conf.batch_size {
|
|
||||||
let _ = sync_tx.send(requests);
|
let _ = sync_tx.send(requests);
|
||||||
requests = Vec::new();
|
n_tokens = request.input_ids.len();
|
||||||
|
requests = vec![request];
|
||||||
|
} else {
|
||||||
|
requests.push(request);
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
if !requests.is_empty() {
|
if !requests.is_empty() {
|
||||||
let _ = sync_tx.send(requests);
|
let _ = sync_tx.send(requests);
|
||||||
|
n_tokens = 0;
|
||||||
requests = Vec::new();
|
requests = Vec::new();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -68,8 +68,11 @@ struct Args {
|
|||||||
// waiting_served_ratio: f32,
|
// waiting_served_ratio: f32,
|
||||||
// #[clap(default_value = "4096", long, env)]
|
// #[clap(default_value = "4096", long, env)]
|
||||||
// max_batch_prefill_tokens: u32,
|
// max_batch_prefill_tokens: u32,
|
||||||
// #[clap(long, env)]
|
|
||||||
// max_batch_total_tokens: Option<u32>,
|
/// Maximum tokens within a batch
|
||||||
|
#[clap(default_value = "1024", long, env)]
|
||||||
|
max_batch_total_tokens: u32,
|
||||||
|
|
||||||
// #[clap(default_value = "20", long, env)]
|
// #[clap(default_value = "20", long, env)]
|
||||||
// max_waiting_tokens: usize,
|
// max_waiting_tokens: usize,
|
||||||
// #[clap(long, env)]
|
// #[clap(long, env)]
|
||||||
@ -155,14 +158,14 @@ async fn main() -> Result<(), RouterError> {
|
|||||||
|
|
||||||
let (backend, ok) = LlamacppBackend::new(
|
let (backend, ok) = LlamacppBackend::new(
|
||||||
LlamacppConfig {
|
LlamacppConfig {
|
||||||
model_gguf: args.model_gguf,
|
model_gguf: args.model_gguf,
|
||||||
n_ctx: args.n_ctx,
|
n_ctx: args.n_ctx,
|
||||||
n_threads: args.n_threads,
|
n_threads: args.n_threads,
|
||||||
use_mmap: args.use_mmap,
|
use_mmap: args.use_mmap,
|
||||||
use_mlock: args.use_mlock,
|
use_mlock: args.use_mlock,
|
||||||
flash_attention: args.flash_attention,
|
flash_attention: args.flash_attention,
|
||||||
batch_size: 5,
|
max_batch_total_tokens: args.max_batch_total_tokens,
|
||||||
batch_timeout: tokio::time::Duration::from_millis(100),
|
batch_timeout: tokio::time::Duration::from_millis(100),
|
||||||
},
|
},
|
||||||
tokenizer,
|
tokenizer,
|
||||||
);
|
);
|
||||||
|
Loading…
Reference in New Issue
Block a user