fixes and some extra comments

This commit is contained in:
Nick Hill 2023-05-01 21:22:18 +01:00
parent 08aee68f79
commit 7552483dde
3 changed files with 8 additions and 4 deletions

View File

@ -240,7 +240,6 @@ impl<B: BatchType> Infer<B> {
/// Batches requests and sends them to the inference server /// Batches requests and sends them to the inference server
async fn batching_task<B: BatchType>( async fn batching_task<B: BatchType>(
mut client: ShardedClient, mut client: ShardedClient,
// max_batch_size: usize,
max_waiting_tokens: usize, max_waiting_tokens: usize,
queue: Queue<B>, queue: Queue<B>,
shared: Arc<Shared>, shared: Arc<Shared>,

View File

@ -108,8 +108,11 @@ async fn queue_task<B: BatchType>(
#[derive(Debug)] #[derive(Debug)]
pub(crate) struct BatchingConfig { pub(crate) struct BatchingConfig {
/// Upper bound on number of requests in a batch
pub(crate) size_limit: usize, pub(crate) size_limit: usize,
/// Maximum batch "weight" at any point of time (takes sequence lengths into account)
pub(crate) weight_limit: usize, pub(crate) weight_limit: usize,
/// Maximum weight of individual prefill batches
pub(crate) prefill_weight_limit: usize, pub(crate) prefill_weight_limit: usize,
} }
@ -214,6 +217,8 @@ impl BatchType for FlashBatch {
for (bs, (ol, il, _)) in tree.iter().rev().enumerate() { for (bs, (ol, il, _)) in tree.iter().rev().enumerate() {
let this_ol = *ol; let this_ol = *ol;
in_sum += *il; in_sum += *il;
// Only need to check segments with output_len <= current_output_len; segments
// with output_len > current_output_len will have been checked in a prior iteration
if this_ol <= current_output_len { if this_ol <= current_output_len {
// Check if we breach max space for this segment // Check if we breach max space for this segment
let token_count = in_sum + (bs + 1) * this_ol; let token_count = in_sum + (bs + 1) * this_ol;
@ -244,7 +249,7 @@ impl BatchType for PaddedBatch {
fn batch_weight(max_in_out_lengths: &Self::Stats, batch_size: usize) -> usize { fn batch_weight(max_in_out_lengths: &Self::Stats, batch_size: usize) -> usize {
let (max_input_length, max_output_length) = max_in_out_lengths; let (max_input_length, max_output_length) = max_in_out_lengths;
let max_seq_len = max_input_length + max_output_length; let max_seq_len = max_input_length + max_output_length;
// Memory requirement roughly propotionall to batch_size * seq_len^2 // Memory requirement roughly proportional to batch_size * seq_len^2
batch_size * max_seq_len.pow(2) batch_size * max_seq_len.pow(2)
} }

View File

@ -635,7 +635,7 @@ async fn do_run<B: BatchType>(
validation_workers: usize, validation_workers: usize,
addr: SocketAddr, addr: SocketAddr,
allow_origin: Option<AllowOrigin>, allow_origin: Option<AllowOrigin>,
_batch_type: B, batch_type: B,
) { ) {
// OpenAPI documentation // OpenAPI documentation
#[derive(OpenApi)] #[derive(OpenApi)]
@ -701,7 +701,7 @@ async fn do_run<B: BatchType>(
max_prefill_weight, max_prefill_weight,
max_waiting_tokens, max_waiting_tokens,
max_concurrent_requests, max_concurrent_requests,
FlashBatch{} batch_type,
); );
// Duration buckets // Duration buckets