Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-06-10 19:32:06 +00:00)

Control prefill and decode batch size separately (#6)
commit 252ccde104 (parent 1be2d9a8ec)
@@ -74,7 +74,8 @@ Environment Variables Added:
 | PROF_STEP | interger | 5 | Control profile step | add -e in docker run command |
 | PROF_PATH | string | /root/text-generation-inference | Define profile folder | add -e in docker run command |
 | LIMIT_HPU_GRAPH | True/False | False | Skip HPU graph usage for prefill to save memory | add -e in docker run command |
-| BATCH_BUCKET_SIZE | integer | 8 | Batch size will be rounded to the nearest multiple of this number. This limits the number of cached graphs | add -e in docker run command |
+| BATCH_BUCKET_SIZE | integer | 8 | Batch size for decode operation will be rounded to the nearest multiple of this number. This limits the number of cached graphs | add -e in docker run command |
+| PREFILL_BATCH_BUCKET_SIZE | integer | 4 | Batch size for prefill operation will be rounded to the nearest multiple of this number. This limits the number of cached graphs | add -e in docker run command |
 
 </div>
 
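A brief sketch of what the two knobs control, assuming batch sizes are rounded up to the nearest bucket as the table describes (round_up below is a local reimplementation for illustration, not the server's helper, and the defaults are taken from the table):

import os

# Defaults taken from the table above; override with -e in the docker run command.
BATCH_BUCKET_SIZE = int(os.environ.get("BATCH_BUCKET_SIZE", 8))
PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get("PREFILL_BATCH_BUCKET_SIZE", 4))


def round_up(value: int, multiple: int) -> int:
    # Round value up to the nearest multiple of the bucket size.
    return ((value + multiple - 1) // multiple) * multiple


# Every batch size between 1 and 32 collapses to a handful of bucketed shapes,
# which is what keeps the number of cached HPU graphs small.
prefill_shapes = {round_up(bs, PREFILL_BATCH_BUCKET_SIZE) for bs in range(1, 33)}
decode_shapes = {round_up(bs, BATCH_BUCKET_SIZE) for bs in range(1, 33)}
print(sorted(prefill_shapes))  # [4, 8, 12, 16, 20, 24, 28, 32]
print(sorted(decode_shapes))   # [8, 16, 24, 32]

With separate bucket sizes, prefill batches can stay smaller (saving memory for the longer prefill graphs) while decode batches are still padded to larger, coarser shapes.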
@@ -50,11 +50,18 @@ impl Infer {
         max_concurrent_requests: usize,
         requires_padding: bool,
         max_input_length: u32,
+        max_total_tokens: u32,
         window_size: Option<u32>,
         generation_health: Arc<AtomicBool>,
     ) -> Self {
         // Infer shared state
-        let queue = Queue::new(requires_padding, max_input_length, 16, window_size);
+        let queue = Queue::new(
+            requires_padding,
+            max_input_length,
+            max_total_tokens,
+            16,
+            window_size
+        );
         let shared = Arc::new(Shared {
             batching_task: Notify::new(),
         });
@@ -37,6 +37,7 @@ impl Queue {
     pub(crate) fn new(
         requires_padding: bool,
         max_input_length: u32,
+        max_total_tokens: u32,
         block_size: u32,
         window_size: Option<u32>
     ) -> Self {
@@ -47,6 +48,7 @@ impl Queue {
         tokio::spawn(queue_task(
             requires_padding,
             max_input_length,
+            max_total_tokens,
             block_size,
             window_size,
             queue_receiver,
@@ -96,11 +98,18 @@ impl Queue {
 async fn queue_task(
     requires_padding: bool,
     max_input_length: u32,
+    max_total_tokens: u32,
     block_size: u32,
     window_size: Option<u32>,
     mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
 ) {
-    let mut state = State::new(requires_padding, max_input_length, block_size, window_size);
+    let mut state = State::new(
+        requires_padding,
+        max_input_length,
+        max_total_tokens,
+        block_size,
+        window_size
+    );
 
     while let Some(cmd) = receiver.recv().await {
         match cmd {
@@ -138,9 +147,12 @@ struct State {
     /// Whether the model is using padding
     requires_padding: bool,
 
-    /// Maximum inpult length, required for padding scenario
+    /// Maximum input length, required for padding scenario
     max_input_length: u32,
 
+    /// Maximum input and output length, required for padding scenario
+    max_total_tokens: u32,
+
     /// Paged Attention block size
     block_size: u32,
 
@@ -152,6 +164,7 @@ impl State {
     fn new(
         requires_padding: bool,
         max_input_length: u32,
+        max_total_tokens: u32,
         block_size: u32,
         window_size: Option<u32>
     ) -> Self {
@@ -161,6 +174,7 @@ impl State {
             next_batch_id: 0,
             requires_padding,
             max_input_length,
+            max_total_tokens,
             block_size,
             window_size,
         }
@@ -218,7 +232,7 @@ impl State {
             if self.requires_padding {
                 // We pad to max input length in the Python shards
                 // We need to take these padding tokens into the equation
-                prefill_tokens = (batch_requests.len() + 1) as u32 * self.max_input_length
+                prefill_tokens = (batch_requests.len() + 1) as u32 * self.max_input_length;
             } else {
                 // pad to block size
                 prefill_tokens += ((entry.request.input_length + self.block_size - 1)
@@ -227,7 +241,9 @@ impl State {
             }
 
             if self.requires_padding {
-                decode_tokens += entry.request.stopping_parameters.max_new_tokens;
+                // We pad to max total tokens in the Python shards
+                // We need to take these padding tokens into the equation
+                decode_tokens = (batch_requests.len() + 1) as u32 * (self.max_total_tokens - self.max_input_length);
             } else {
                 let max_new_tokens = match self.window_size {
                     None => entry.request.stopping_parameters.max_new_tokens,
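To make the padded budget concrete, here is a small worked sketch in Python with illustrative numbers (MAX_INPUT_LENGTH and MAX_TOTAL_TOKENS are assumed values, not defaults): under padding, every slot in the batch is treated as costing max_input_length prefill tokens and (max_total_tokens - max_input_length) decode tokens, regardless of a request's actual max_new_tokens.

# Worked example of the padded token budget (illustrative numbers only).
MAX_INPUT_LENGTH = 1024   # assumed max input length
MAX_TOTAL_TOKENS = 2048   # assumed max total tokens

def padded_budget(num_batched_requests: int) -> tuple[int, int]:
    # Mirrors the router logic above: the candidate entry plus the requests
    # already in the batch each reserve a fully padded slot.
    slots = num_batched_requests + 1
    prefill_tokens = slots * MAX_INPUT_LENGTH
    decode_tokens = slots * (MAX_TOTAL_TOKENS - MAX_INPUT_LENGTH)
    return prefill_tokens, decode_tokens

print(padded_budget(3))  # (4096, 4096): 4 padded slots of 1024 prefill + 1024 decode tokens each

This is why the router now needs max_total_tokens: without it, the decode budget under padding could not be computed from the batch size alone.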
@@ -596,6 +596,7 @@ pub async fn run(
         max_concurrent_requests,
         shard_info.requires_padding,
         max_input_length as u32,
+        max_total_tokens as u32,
         shard_info.window_size,
         generation_health,
     );
@@ -36,6 +36,7 @@ from loguru import logger
 tracer = trace.get_tracer(__name__)
 
 BATCH_BUCKET_SIZE = int(os.environ.get('BATCH_BUCKET_SIZE', 8))
+PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get('PREFILL_BATCH_BUCKET_SIZE', 4))
 TRACE_FILENAME = os.environ.get('TRACE_FILENAME')
 
 def trace(txt):
@@ -234,7 +235,11 @@ class CausalLMBatch(Batch):
 
         top_n_tokens = [r.data.top_n_tokens for r in requests]
         top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
-        next_token_chooser = HeterogeneousNextTokenChooser.from_pb([r.data.parameters for r in requests], batches[0].next_token_chooser.device, batches[0].next_token_chooser.dtype)
+        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
+            [r.data.parameters for r in requests],
+            batches[0].next_token_chooser.device,
+            batches[0].next_token_chooser.dtype
+        )
 
         htorch.core.mark_step()
 
@@ -286,7 +291,7 @@ class CausalLMBatch(Batch):
         # TODO: by tokenizing all inputs at once we loose information on actual input lengths
         # this means that we cannot shift inputs to the left after a long input sequence
         # was filtered out
-        new_bs = round_up(len(requests), BATCH_BUCKET_SIZE)
+        new_bs = round_up(len(requests), PREFILL_BATCH_BUCKET_SIZE)
         dummy_inputs = ["?"] * (new_bs - len(requests))
         tokenized_inputs = tokenizer(
             [r.data.inputs for r in requests] + dummy_inputs,
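The effect of the hunk above, sketched in plain Python (the padding helper below is an illustration of the behaviour, not the server's code): when a prefill batch does not land exactly on a multiple of PREFILL_BATCH_BUCKET_SIZE, dummy "?" inputs are appended so that only a small set of batch shapes ever reaches the tokenizer and the cached HPU graphs.

PREFILL_BATCH_BUCKET_SIZE = 4  # default from the table above


def round_up(value: int, multiple: int) -> int:
    return ((value + multiple - 1) // multiple) * multiple


def pad_prefill_inputs(inputs: list[str]) -> list[str]:
    # Round the prefill batch up to the bucket size and fill the gap with
    # dummy "?" prompts, mirroring the dummy_inputs logic in the diff above.
    new_bs = round_up(len(inputs), PREFILL_BATCH_BUCKET_SIZE)
    return inputs + ["?"] * (new_bs - len(inputs))


print(pad_prefill_inputs(["Hello", "Hi there", "What is TGI?"]))
# ['Hello', 'Hi there', 'What is TGI?', '?']  -> a batch of 4, one cached graph shape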