Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 19:34:53 +00:00)
fixes and some extra comments
This commit is contained in: parent 08aee68f79, commit 7552483dde
@@ -240,7 +240,6 @@ impl<B: BatchType> Infer<B> {
 /// Batches requests and sends them to the inference server
 async fn batching_task<B: BatchType>(
     mut client: ShardedClient,
-    // max_batch_size: usize,
     max_waiting_tokens: usize,
     queue: Queue<B>,
     shared: Arc<Shared>,
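This hunk only deletes a stale commented-out parameter; the live signature is unchanged. For orientation, a minimal sketch of how the router might launch the loop, assuming a Tokio runtime and that client, queue, and shared are built elsewhere (none of that setup appears in the diff):

    // Illustrative only: launching the batching loop with the current parameter list.
    tokio::spawn(batching_task(
        client,             // ShardedClient
        max_waiting_tokens, // usize
        queue,              // Queue<B>
        shared,             // Arc<Shared>
    ));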
@@ -108,8 +108,11 @@ async fn queue_task<B: BatchType>(

 #[derive(Debug)]
 pub(crate) struct BatchingConfig {
+    /// Upper bound on number of requests in a batch
     pub(crate) size_limit: usize,
+    /// Maximum batch "weight" at any point of time (takes sequence lengths into account)
     pub(crate) weight_limit: usize,
+    /// Maximum weight of individual prefill batches
     pub(crate) prefill_weight_limit: usize,
 }

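With the new doc comments, each limit in BatchingConfig is now described inline. A minimal sketch of what populating the struct might look like; the numbers are invented, and the real values come from router configuration, not from this diff:

    // Hypothetical limits, for illustration only.
    let config = BatchingConfig {
        size_limit: 32,                // at most 32 requests per batch
        weight_limit: 2_000_000,       // overall batch "weight" cap
        prefill_weight_limit: 500_000, // tighter cap for prefill-only batches
    };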
@@ -214,6 +217,8 @@ impl BatchType for FlashBatch {
         for (bs, (ol, il, _)) in tree.iter().rev().enumerate() {
             let this_ol = *ol;
             in_sum += *il;
+            // No need to check segments with output_len > current_output_len;
+            // they will have been checked in a prior iteration
             if this_ol <= current_output_len {
                 // Check if we breach max space for this segment
                 let token_count = in_sum + (bs + 1) * this_ol;
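To make the weight check concrete, here is a self-contained sketch that reproduces the same arithmetic outside the BatchType impl. The function name, the entries, and the limit are invented, and the this_ol <= current_output_len filter from the real loop is omitted for brevity:

    /// Entries are (output_len, input_len), iterated from largest output_len down,
    /// mirroring tree.iter().rev() in the diff.
    fn exceeds_weight_limit(entries_desc: &[(usize, usize)], weight_limit: usize) -> bool {
        let mut in_sum = 0;
        for (bs, (ol, il)) in entries_desc.iter().copied().enumerate() {
            in_sum += il;
            // Same formula as above: inputs seen so far plus the batch size so far
            // times this segment's output length.
            let token_count = in_sum + (bs + 1) * ol;
            if token_count > weight_limit {
                return true;
            }
        }
        false
    }

    fn main() {
        // (output_len, input_len) pairs of (100, 20), (80, 50), (60, 10) against an
        // invented limit of 300: the middle segment gives (20 + 50) + 2 * 80 = 230.
        assert!(!exceeds_weight_limit(&[(100, 20), (80, 50), (60, 10)], 300));
    }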
@@ -244,7 +249,7 @@ impl BatchType for PaddedBatch {
     fn batch_weight(max_in_out_lengths: &Self::Stats, batch_size: usize) -> usize {
         let (max_input_length, max_output_length) = max_in_out_lengths;
         let max_seq_len = max_input_length + max_output_length;
-        // Memory requirement roughly propotionall to batch_size * seq_len^2
+        // Memory requirement roughly proportional to batch_size * seq_len^2
        batch_size * max_seq_len.pow(2)
     }

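The corrected comment matches the expression beneath it: for padded batches, every sequence is presumably padded toward the longest input plus the longest output, so the weight estimate grows quadratically in sequence length. A quick worked example with invented lengths:

    fn main() {
        // Invented numbers, only to exercise batch_size * (max_in + max_out)^2.
        let max_input_length: usize = 512;
        let max_output_length: usize = 128;
        let batch_size: usize = 8;
        let max_seq_len = max_input_length + max_output_length; // 640
        let weight = batch_size * max_seq_len.pow(2);            // 8 * 409_600
        assert_eq!(weight, 3_276_800);
    }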
@@ -635,7 +635,7 @@ async fn do_run<B: BatchType>(
     validation_workers: usize,
     addr: SocketAddr,
     allow_origin: Option<AllowOrigin>,
-    _batch_type: B,
+    batch_type: B,
 ) {
     // OpenAPI documentation
     #[derive(OpenApi)]
@@ -701,7 +701,7 @@ async fn do_run<B: BatchType>(
         max_prefill_weight,
         max_waiting_tokens,
         max_concurrent_requests,
-        FlashBatch{}
+        batch_type,
     );

     // Duration buckets
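Taken together, the last two hunks stop hard-coding FlashBatch{} inside do_run: the caller now supplies batch_type, and the previously unused _batch_type parameter is put to work. A self-contained sketch of the resulting dispatch pattern, using simplified stand-in types rather than the real ones from the repository:

    // Stand-in trait and types, for illustration only.
    trait BatchType {
        fn name(&self) -> &'static str;
    }

    struct FlashBatch;
    struct PaddedBatch;

    impl BatchType for FlashBatch {
        fn name(&self) -> &'static str { "flash" }
    }
    impl BatchType for PaddedBatch {
        fn name(&self) -> &'static str { "padded" }
    }

    // Mirrors the shape of the change: the generic parameter is used
    // where a concrete FlashBatch{} used to be constructed.
    fn do_run<B: BatchType>(batch_type: B) {
        println!("running with {} batching", batch_type.name());
    }

    fn main() {
        let flash_attention = true; // hypothetical startup flag
        if flash_attention {
            do_run(FlashBatch);
        } else {
            do_run(PaddedBatch);
        }
    }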