mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-19 15:52:08 +00:00
Handle max_batch_size
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
This commit is contained in:
parent
3eb4823f3e
commit
e7facf692f
@ -25,6 +25,7 @@ pub struct LlamacppConfig {
|
||||
pub model_gguf: String,
|
||||
pub n_ctx: u32,
|
||||
pub max_batch_total_tokens: u32,
|
||||
pub max_batch_size: Option<usize>,
|
||||
pub batch_timeout: Duration,
|
||||
pub n_threads: i32,
|
||||
pub use_mmap: bool,
|
||||
@ -320,13 +321,22 @@ impl LlamacppBackend {
|
||||
match timeout(conf.batch_timeout, rx.recv()).await {
|
||||
Ok(None) => break, // closed
|
||||
Ok(Some(request)) => {
|
||||
if let Some(max_batch_size) = conf.max_batch_size {
|
||||
if requests.len() + 1 == max_batch_size {
|
||||
requests.push(request);
|
||||
let _ = sync_tx.send(requests);
|
||||
n_tokens = 0;
|
||||
requests = Vec::new();
|
||||
continue;
|
||||
}
|
||||
}
|
||||
if n_tokens + request.input_ids.len() > conf.max_batch_total_tokens as usize {
|
||||
let _ = sync_tx.send(requests);
|
||||
n_tokens = request.input_ids.len();
|
||||
requests = vec![request];
|
||||
} else {
|
||||
requests.push(request);
|
||||
continue;
|
||||
}
|
||||
requests.push(request);
|
||||
},
|
||||
Err(_) => {
|
||||
if !requests.is_empty() {
|
||||
|
@ -75,8 +75,10 @@ struct Args {
|
||||
|
||||
// #[clap(default_value = "20", long, env)]
|
||||
// max_waiting_tokens: usize,
|
||||
// #[clap(long, env)]
|
||||
// max_batch_size: Option<usize>,
|
||||
|
||||
/// Maximum number of requests per batch
|
||||
#[clap(long, env)]
|
||||
max_batch_size: Option<usize>,
|
||||
|
||||
/// The IP address to listen on
|
||||
#[clap(default_value = "0.0.0.0", long, env)]
|
||||
@ -165,6 +167,7 @@ async fn main() -> Result<(), RouterError> {
|
||||
use_mlock: args.use_mlock,
|
||||
flash_attention: args.flash_attention,
|
||||
max_batch_total_tokens: args.max_batch_total_tokens,
|
||||
max_batch_size: args.max_batch_size,
|
||||
batch_timeout: tokio::time::Duration::from_millis(100),
|
||||
},
|
||||
tokenizer,
|
||||
|
Loading…
Reference in New Issue
Block a user