Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-06-09 10:52:07 +00:00)
Control prefill and decode batch size separately (#6)
parent 1be2d9a8ec
commit 252ccde104
@@ -74,7 +74,8 @@ Environment Variables Added:
 | PROF_STEP | integer | 5 | Control profile step | add -e in docker run command |
 | PROF_PATH | string | /root/text-generation-inference | Define profile folder | add -e in docker run command |
 | LIMIT_HPU_GRAPH | True/False | False | Skip HPU graph usage for prefill to save memory | add -e in docker run command |
-| BATCH_BUCKET_SIZE | integer | 8 | Batch size will be rounded to the nearest multiple of this number. This limits the number of cached graphs | add -e in docker run command |
+| BATCH_BUCKET_SIZE | integer | 8 | Batch size for decode operation will be rounded to the nearest multiple of this number. This limits the number of cached graphs | add -e in docker run command |
+| PREFILL_BATCH_BUCKET_SIZE | integer | 4 | Batch size for prefill operation will be rounded to the nearest multiple of this number. This limits the number of cached graphs | add -e in docker run command |
 </div>
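Both knobs describe the same bucketing idea: the runtime batch size is rounded up to the nearest multiple of the bucket, so only a small set of batch shapes has to be compiled and cached as HPU graphs. A minimal sketch of that rounding, assuming the `round_up` helper referenced in `causal_lm.py` further down rounds up to the next multiple:

```python
def round_up(value, multiple):
    # Assumed semantics of the round_up helper used later in this diff:
    # round `value` up to the next multiple of `multiple`.
    return (value + multiple - 1) // multiple * multiple

BATCH_BUCKET_SIZE = 8          # default decode bucket from the table
PREFILL_BATCH_BUCKET_SIZE = 4  # default prefill bucket from the table

assert round_up(5, PREFILL_BATCH_BUCKET_SIZE) == 8   # 5 prefill requests -> batch padded to 8
assert round_up(3, BATCH_BUCKET_SIZE) == 8           # 3 decode requests  -> batch padded to 8
assert round_up(9, BATCH_BUCKET_SIZE) == 16          # 9 decode requests  -> batch padded to 16
```

With the defaults, prefill batches only ever come in multiples of 4 and decode batches in multiples of 8, which is what limits the number of cached graphs.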
@@ -50,11 +50,18 @@ impl Infer {
         max_concurrent_requests: usize,
         requires_padding: bool,
         max_input_length: u32,
+        max_total_tokens: u32,
         window_size: Option<u32>,
         generation_health: Arc<AtomicBool>,
     ) -> Self {
         // Infer shared state
-        let queue = Queue::new(requires_padding, max_input_length, 16, window_size);
+        let queue = Queue::new(
+            requires_padding,
+            max_input_length,
+            max_total_tokens,
+            16,
+            window_size
+        );
         let shared = Arc::new(Shared {
             batching_task: Notify::new(),
         });
@@ -37,6 +37,7 @@ impl Queue {
     pub(crate) fn new(
         requires_padding: bool,
         max_input_length: u32,
+        max_total_tokens: u32,
         block_size: u32,
         window_size: Option<u32>
     ) -> Self {
@@ -47,6 +48,7 @@ impl Queue {
         tokio::spawn(queue_task(
             requires_padding,
             max_input_length,
+            max_total_tokens,
             block_size,
             window_size,
             queue_receiver,
@@ -96,11 +98,18 @@ impl Queue {
 async fn queue_task(
     requires_padding: bool,
     max_input_length: u32,
+    max_total_tokens: u32,
     block_size: u32,
     window_size: Option<u32>,
     mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
 ) {
-    let mut state = State::new(requires_padding, max_input_length, block_size, window_size);
+    let mut state = State::new(
+        requires_padding,
+        max_input_length,
+        max_total_tokens,
+        block_size,
+        window_size
+    );

     while let Some(cmd) = receiver.recv().await {
         match cmd {
@@ -138,9 +147,12 @@ struct State {
     /// Whether the model is using padding
     requires_padding: bool,

-    /// Maximum inpult length, required for padding scenario
+    /// Maximum input length, required for padding scenario
     max_input_length: u32,

+    /// Maximum input and output length, required for padding scenario
+    max_total_tokens: u32,
+
     /// Paged Attention block size
     block_size: u32,

@@ -152,6 +164,7 @@ impl State {
     fn new(
         requires_padding: bool,
         max_input_length: u32,
+        max_total_tokens: u32,
         block_size: u32,
         window_size: Option<u32>
     ) -> Self {
@@ -161,6 +174,7 @@ impl State {
             next_batch_id: 0,
             requires_padding,
             max_input_length,
+            max_total_tokens,
             block_size,
             window_size,
         }
@@ -218,7 +232,7 @@ impl State {
             if self.requires_padding {
                 // We pad to max input length in the Python shards
                 // We need to take these padding tokens into the equation
-                prefill_tokens = (batch_requests.len() + 1) as u32 * self.max_input_length
+                prefill_tokens = (batch_requests.len() + 1) as u32 * self.max_input_length;
             } else {
                 // pad to block size
                 prefill_tokens += ((entry.request.input_length + self.block_size - 1)
@@ -227,7 +241,9 @@ impl State {
             }

             if self.requires_padding {
-                decode_tokens += entry.request.stopping_parameters.max_new_tokens;
+                // We pad to max total tokens in the Python shards
+                // We need to take these padding tokens into the equation
+                decode_tokens = (batch_requests.len() + 1) as u32 * (self.max_total_tokens - self.max_input_length);
             } else {
                 let max_new_tokens = match self.window_size {
                     None => entry.request.stopping_parameters.max_new_tokens,
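In the padding path the router now budgets tokens from the configured lengths rather than from each request's max_new_tokens: every request is assumed to occupy max_input_length during prefill and max_total_tokens - max_input_length during decode. A rough worked example with made-up numbers (the `+ 1` presumably accounts for the entry currently being considered, which is not yet in batch_requests):

```python
# Illustrative values only, not taken from the source.
max_input_length = 1024
max_total_tokens = 2048
batch_len = 3  # requests already accepted into the batch being built

# Mirrors the padded branch above: pad every request to the configured limits.
prefill_tokens = (batch_len + 1) * max_input_length                      # 4096
decode_tokens = (batch_len + 1) * (max_total_tokens - max_input_length)  # 4096

# The router compares these totals against its token budgets (outside this hunk)
# before admitting the next entry into the batch.
print(prefill_tokens, decode_tokens)
```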
@@ -596,6 +596,7 @@ pub async fn run(
         max_concurrent_requests,
         shard_info.requires_padding,
         max_input_length as u32,
+        max_total_tokens as u32,
         shard_info.window_size,
         generation_health,
     );
@@ -36,6 +36,7 @@ from loguru import logger
 tracer = trace.get_tracer(__name__)

 BATCH_BUCKET_SIZE = int(os.environ.get('BATCH_BUCKET_SIZE', 8))
+PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get('PREFILL_BATCH_BUCKET_SIZE', 4))
 TRACE_FILENAME = os.environ.get('TRACE_FILENAME')

 def trace(txt):
@@ -234,7 +235,11 @@ class CausalLMBatch(Batch):

         top_n_tokens = [r.data.top_n_tokens for r in requests]
         top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
-        next_token_chooser = HeterogeneousNextTokenChooser.from_pb([r.data.parameters for r in requests], batches[0].next_token_chooser.device, batches[0].next_token_chooser.dtype)
+        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
+            [r.data.parameters for r in requests],
+            batches[0].next_token_chooser.device,
+            batches[0].next_token_chooser.dtype
+        )

         htorch.core.mark_step()

@@ -286,7 +291,7 @@ class CausalLMBatch(Batch):
         # TODO: by tokenizing all inputs at once we lose information on actual input lengths
         # this means that we cannot shift inputs to the left after a long input sequence
         # was filtered out
-        new_bs = round_up(len(requests), BATCH_BUCKET_SIZE)
+        new_bs = round_up(len(requests), PREFILL_BATCH_BUCKET_SIZE)
         dummy_inputs = ["?"] * (new_bs - len(requests))
         tokenized_inputs = tokenizer(
             [r.data.inputs for r in requests] + dummy_inputs,
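After filtering or concatenation, the surviving requests are re-tokenized, and the batch is now padded up to a multiple of PREFILL_BATCH_BUCKET_SIZE (instead of BATCH_BUCKET_SIZE) with dummy "?" prompts, presumably because this re-tokenized batch goes through a prefill-shaped forward pass. A minimal sketch of that padding step, again assuming round_up rounds up to the next multiple:

```python
def round_up(value, multiple):
    # Assumed semantics of the round_up helper used above.
    return (value + multiple - 1) // multiple * multiple

PREFILL_BATCH_BUCKET_SIZE = 4  # default from this diff

surviving_inputs = ["prompt 1", "prompt 2", "prompt 3", "prompt 4", "prompt 5"]
new_bs = round_up(len(surviving_inputs), PREFILL_BATCH_BUCKET_SIZE)   # 8
dummy_inputs = ["?"] * (new_bs - len(surviving_inputs))               # 3 filler prompts

padded_inputs = surviving_inputs + dummy_inputs
assert len(padded_inputs) == new_bs
# The padded list is what gets tokenized, so this path only ever produces
# batch sizes that are multiples of PREFILL_BATCH_BUCKET_SIZE.
```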