Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 19:34:53 +00:00)

fixes and some extra comments

commit 7552483dde
parent 08aee68f79

@@ -240,7 +240,6 @@ impl<B: BatchType> Infer<B> {
 /// Batches requests and sends them to the inference server
 async fn batching_task<B: BatchType>(
     mut client: ShardedClient,
-    // max_batch_size: usize,
     max_waiting_tokens: usize,
     queue: Queue<B>,
     shared: Arc<Shared>,
@@ -108,8 +108,11 @@ async fn queue_task<B: BatchType>(
 
 #[derive(Debug)]
 pub(crate) struct BatchingConfig {
+    /// Upper bound on number of requests in a batch
     pub(crate) size_limit: usize,
+    /// Maximum batch "weight" at any point of time (takes sequence lengths into account)
     pub(crate) weight_limit: usize,
+    /// Maximum weight of individual prefill batches
     pub(crate) prefill_weight_limit: usize,
 }
 
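
For orientation, here is a minimal self-contained sketch of how the three limits documented above fit together; the struct is re-declared locally and all numeric values are hypothetical, not taken from the commit.

    // Sketch only: a local re-declaration of the struct above, with hypothetical limits.
    #[derive(Debug)]
    struct BatchingConfig {
        size_limit: usize,
        weight_limit: usize,
        prefill_weight_limit: usize,
    }

    fn main() {
        let config = BatchingConfig {
            // at most 32 requests in a batch (hypothetical value)
            size_limit: 32,
            // total batch "weight" bound; a token count for flash batching,
            // batch_size * max_seq_len^2 for padded batching (see the hunks below)
            weight_limit: 65_536,
            // tighter bound applied to each individual prefill batch (hypothetical value)
            prefill_weight_limit: 16_384,
        };
        println!("{config:?}");
    }
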
@@ -214,6 +217,8 @@ impl BatchType for FlashBatch {
         for (bs, (ol, il, _)) in tree.iter().rev().enumerate() {
             let this_ol = *ol;
             in_sum += *il;
+            // Segments with output_len > current_output_len will have been checked in a
+            // prior iteration, so only the remaining segments need to be checked here
             if this_ol <= current_output_len {
                 // Check if we breach max space for this segment
                 let token_count = in_sum + (bs + 1) * this_ol;
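
To make the arithmetic in this hunk concrete, here is a standalone sketch of the same per-segment weight check (simplified: it drops the current_output_len short-circuit and uses plain tuples instead of the router's tree of stats).

    // Sketch of the segment check annotated above. Entries are (output_len, input_len).
    fn exceeds_weight(mut entries: Vec<(usize, usize)>, weight_limit: usize) -> bool {
        // Consider entries from longest to shortest projected output.
        entries.sort_by(|a, b| b.0.cmp(&a.0));
        let mut in_sum = 0;
        for (i, &(ol, il)) in entries.iter().enumerate() {
            in_sum += il;
            // At the generation step where this entry finishes, i + 1 sequences are still
            // live, each holding its prompt tokens plus ol generated tokens.
            let token_count = in_sum + (i + 1) * ol;
            if token_count > weight_limit {
                return true;
            }
        }
        false
    }

    fn main() {
        // Hypothetical (output_len, input_len) pairs and a hypothetical weight limit.
        let batch = vec![(200, 30), (100, 50), (50, 20)];
        assert!(!exceeds_weight(batch, 1024)); // peak segment weight here is 280 tokens
    }
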
@@ -244,7 +249,7 @@ impl BatchType for PaddedBatch {
     fn batch_weight(max_in_out_lengths: &Self::Stats, batch_size: usize) -> usize {
         let (max_input_length, max_output_length) = max_in_out_lengths;
         let max_seq_len = max_input_length + max_output_length;
-        // Memory requirement roughly propotionall to batch_size * seq_len^2
+        // Memory requirement roughly proportional to batch_size * seq_len^2
         batch_size * max_seq_len.pow(2)
     }
 
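
As a quick numeric check of the padded formula, here is a hedged standalone sketch (the free-function signature is illustrative and differs from the router's trait method).

    // Sketch of the padded weight above: batch_size * (max_in + max_out)^2.
    fn padded_batch_weight(max_input_length: usize, max_output_length: usize, batch_size: usize) -> usize {
        let max_seq_len = max_input_length + max_output_length;
        batch_size * max_seq_len.pow(2)
    }

    fn main() {
        // Hypothetical batch: 4 requests padded to 512 input + 128 output tokens.
        assert_eq!(padded_batch_weight(512, 128, 4), 4 * 640 * 640); // 1_638_400
    }
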
@@ -635,7 +635,7 @@ async fn do_run<B: BatchType>(
     validation_workers: usize,
     addr: SocketAddr,
     allow_origin: Option<AllowOrigin>,
-    _batch_type: B,
+    batch_type: B,
 ) {
     // OpenAPI documentation
     #[derive(OpenApi)]
@@ -701,7 +701,7 @@ async fn do_run<B: BatchType>(
         max_prefill_weight,
         max_waiting_tokens,
         max_concurrent_requests,
-        FlashBatch{}
+        batch_type,
     );
 
     // Duration buckets
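
The last two hunks thread the caller-supplied batch type through do_run instead of hard-coding FlashBatch{}. Below is a self-contained sketch of that generic-dispatch pattern; the trait method, its implementations, and the selection flag are invented for illustration, and only the type and parameter names come from the diff.

    // Sketch of the pattern: a function generic over B: BatchType now uses its
    // batch_type parameter, so the caller's choice flows through.
    trait BatchType {
        fn name(&self) -> &'static str;
    }

    struct FlashBatch;
    struct PaddedBatch;

    impl BatchType for FlashBatch {
        fn name(&self) -> &'static str { "flash" }
    }

    impl BatchType for PaddedBatch {
        fn name(&self) -> &'static str { "padded" }
    }

    // Stand-in for do_run: it forwards batch_type instead of ignoring it.
    fn do_run<B: BatchType>(batch_type: B) {
        println!("starting batcher with {} batching", batch_type.name());
    }

    fn main() {
        // Hypothetical selection flag; in the real server this would come from configuration.
        let use_flash = true;
        if use_flash { do_run(FlashBatch) } else { do_run(PaddedBatch) }
    }
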