Control prefill and decode batch size separately (#6)

Karol Damaszke 2024-01-02 18:21:01 +01:00 committed by GitHub
parent 1be2d9a8ec
commit 252ccde104
5 changed files with 38 additions and 8 deletions

View File

@@ -74,7 +74,8 @@ Environment Variables Added:
| PROF_STEP | integer | 5 | Control profile step | add -e in docker run command |
| PROF_PATH | string | /root/text-generation-inference | Define profile folder | add -e in docker run command |
| LIMIT_HPU_GRAPH | True/False | False | Skip HPU graph usage for prefill to save memory | add -e in docker run command |
| BATCH_BUCKET_SIZE | integer | 8 | Batch size will be rounded to the nearest multiple of this number. This limits the number of cached graphs | add -e in docker run command |
| BATCH_BUCKET_SIZE | integer | 8 | Batch size for decode operation will be rounded to the nearest multiple of this number. This limits the number of cached graphs | add -e in docker run command |
| PREFILL_BATCH_BUCKET_SIZE | integer | 4 | Batch size for prefill operation will be rounded to the nearest multiple of this number. This limits the number of cached graphs | add -e in docker run command |
</div>
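
The two bucket sizes trade a little batch padding for a bounded set of batch shapes, which is what keeps the HPU graph cache small. A minimal sketch of the effect, assuming a `round_up` helper equivalent to the one used in the Python server below and an illustrative maximum batch size of 32 (not part of this commit):

```python
import os

# Both knobs are plain environment variables; the defaults match the table above
# and can be overridden with `-e` in the `docker run` command.
BATCH_BUCKET_SIZE = int(os.environ.get("BATCH_BUCKET_SIZE", 8))                  # decode
PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get("PREFILL_BATCH_BUCKET_SIZE", 4))  # prefill


def round_up(value: int, multiple: int) -> int:
    # Assumed equivalent of the `round_up` helper used in the server code.
    return ((value + multiple - 1) // multiple) * multiple


# Illustrative maximum batch size: decode batches then only ever take 4 distinct
# shapes and prefill batches 8, which bounds the number of cached HPU graphs.
max_batch_size = 32
decode_shapes = sorted({round_up(n, BATCH_BUCKET_SIZE) for n in range(1, max_batch_size + 1)})
prefill_shapes = sorted({round_up(n, PREFILL_BATCH_BUCKET_SIZE) for n in range(1, max_batch_size + 1)})
print(decode_shapes)   # [8, 16, 24, 32] with the default decode bucket of 8
print(prefill_shapes)  # [4, 8, 12, 16, 20, 24, 28, 32] with the default prefill bucket of 4
```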

View File

@@ -50,11 +50,18 @@ impl Infer {
max_concurrent_requests: usize,
requires_padding: bool,
max_input_length: u32,
max_total_tokens: u32,
window_size: Option<u32>,
generation_health: Arc<AtomicBool>,
) -> Self {
// Infer shared state
let queue = Queue::new(requires_padding, max_input_length, 16, window_size);
let queue = Queue::new(
requires_padding,
max_input_length,
max_total_tokens,
16,
window_size
);
let shared = Arc::new(Shared {
batching_task: Notify::new(),
});

View File

@@ -37,6 +37,7 @@ impl Queue {
pub(crate) fn new(
requires_padding: bool,
max_input_length: u32,
max_total_tokens: u32,
block_size: u32,
window_size: Option<u32>
) -> Self {
@@ -47,6 +48,7 @@ impl Queue {
tokio::spawn(queue_task(
requires_padding,
max_input_length,
max_total_tokens,
block_size,
window_size,
queue_receiver,
@@ -96,11 +98,18 @@ impl Queue {
async fn queue_task(
requires_padding: bool,
max_input_length: u32,
max_total_tokens: u32,
block_size: u32,
window_size: Option<u32>,
mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
) {
let mut state = State::new(requires_padding, max_input_length, block_size, window_size);
let mut state = State::new(
requires_padding,
max_input_length,
max_total_tokens,
block_size,
window_size
);
while let Some(cmd) = receiver.recv().await {
match cmd {
@@ -138,9 +147,12 @@ struct State {
/// Whether the model is using padding
requires_padding: bool,
/// Maximum inpult length, required for padding scenario
/// Maximum input length, required for padding scenario
max_input_length: u32,
/// Maximum input and output length, required for padding scenario
max_total_tokens: u32,
/// Paged Attention block size
block_size: u32,
@@ -152,6 +164,7 @@ impl State {
fn new(
requires_padding: bool,
max_input_length: u32,
max_total_tokens: u32,
block_size: u32,
window_size: Option<u32>
) -> Self {
@@ -161,6 +174,7 @@ impl State {
next_batch_id: 0,
requires_padding,
max_input_length,
max_total_tokens,
block_size,
window_size,
}
@@ -218,7 +232,7 @@ impl State {
if self.requires_padding {
// We pad to max input length in the Python shards
// We need to take these padding tokens into the equation
prefill_tokens = (batch_requests.len() + 1) as u32 * self.max_input_length
prefill_tokens = (batch_requests.len() + 1) as u32 * self.max_input_length;
} else {
// pad to block size
prefill_tokens += ((entry.request.input_length + self.block_size - 1)
@@ -227,7 +241,9 @@
}
if self.requires_padding {
decode_tokens += entry.request.stopping_parameters.max_new_tokens;
// We pad to max total tokens in the Python shards
// We need to take these padding tokens into the equation
decode_tokens = (batch_requests.len() + 1) as u32 * (self.max_total_tokens - self.max_input_length);
} else {
let max_new_tokens = match self.window_size {
None => entry.request.stopping_parameters.max_new_tokens,
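
With padding enabled, the queue now budgets tokens per padded slot rather than per request: each prefill slot is assumed to hold `max_input_length` tokens and each decode slot `max_total_tokens - max_input_length` tokens, with the `+ 1` accounting for the entry currently being evaluated before it is pushed into `batch_requests`. A rough sketch of the arithmetic in the two hunks above, using purely illustrative numbers:

```python
# Illustrative values only; they are not taken from this commit.
max_input_length = 1024
max_total_tokens = 1536
accepted = 3                 # entries already added to the next batch
candidate = accepted + 1     # plus the entry currently being evaluated

# Worst-case budget for the padded batch, mirroring the Rust code above:
prefill_tokens = candidate * max_input_length                      # 4 * 1024 = 4096
decode_tokens = candidate * (max_total_tokens - max_input_length)  # 4 *  512 = 2048

# The queue compares these totals against its prefill/decode token limits
# before admitting the entry into the batch.
print(prefill_tokens, decode_tokens)
```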

View File

@@ -596,6 +596,7 @@ pub async fn run(
max_concurrent_requests,
shard_info.requires_padding,
max_input_length as u32,
max_total_tokens as u32,
shard_info.window_size,
generation_health,
);

View File

@@ -36,6 +36,7 @@ from loguru import logger
tracer = trace.get_tracer(__name__)
BATCH_BUCKET_SIZE = int(os.environ.get('BATCH_BUCKET_SIZE', 8))
PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get('PREFILL_BATCH_BUCKET_SIZE', 4))
TRACE_FILENAME = os.environ.get('TRACE_FILENAME')
def trace(txt):
@@ -234,7 +235,11 @@ class CausalLMBatch(Batch):
top_n_tokens = [r.data.top_n_tokens for r in requests]
top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
next_token_chooser = HeterogeneousNextTokenChooser.from_pb([r.data.parameters for r in requests], batches[0].next_token_chooser.device, batches[0].next_token_chooser.dtype)
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
[r.data.parameters for r in requests],
batches[0].next_token_chooser.device,
batches[0].next_token_chooser.dtype
)
htorch.core.mark_step()
@@ -286,7 +291,7 @@ class CausalLMBatch(Batch):
# TODO: by tokenizing all inputs at once we lose information on actual input lengths
# this means that we cannot shift inputs to the left after a long input sequence
# was filtered out
new_bs = round_up(len(requests), BATCH_BUCKET_SIZE)
new_bs = round_up(len(requests), PREFILL_BATCH_BUCKET_SIZE)
dummy_inputs = ["?"] * (new_bs - len(requests))
tokenized_inputs = tokenizer(
[r.data.inputs for r in requests] + dummy_inputs,
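
Since the filtered batch is re-tokenized from scratch, the request list itself is padded with throwaway prompts up to the next prefill bucket boundary, so filtering never produces a batch shape outside the cached set. A small sketch of that step, with hypothetical prompts and an assumed `round_up` implementation:

```python
PREFILL_BATCH_BUCKET_SIZE = 4  # default from the table above


def round_up(value: int, multiple: int) -> int:
    # Assumed equivalent of the `round_up` helper used in causal_lm.py.
    return ((value + multiple - 1) // multiple) * multiple


requests = ["What is Gaudi?", "Summarize this.", "Translate: hello"]  # 3 surviving requests

new_bs = round_up(len(requests), PREFILL_BATCH_BUCKET_SIZE)  # 3 -> 4
dummy_inputs = ["?"] * (new_bs - len(requests))              # ["?"]

batch_inputs = requests + dummy_inputs
# Tokenizing `batch_inputs` instead of `requests` keeps the batch dimension on a
# prefill bucket boundary, so no new (uncached) shape is introduced.
print(len(batch_inputs))  # 4
```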