diff --git a/README.md b/README.md
index d033ffcd..b187bed1 100644
--- a/README.md
+++ b/README.md
@@ -74,7 +74,8 @@ Environment Variables Added:
 | PROF_STEP | interger | 5 | Control profile step | add -e in docker run command |
 | PROF_PATH | string | /root/text-generation-inference | Define profile folder | add -e in docker run command |
 | LIMIT_HPU_GRAPH | True/False | False | Skip HPU graph usage for prefill to save memory | add -e in docker run command |
-| BATCH_BUCKET_SIZE | integer | 8 | Batch size will be rounded to the nearest multiple of this number. This limits the number of cached graphs | add -e in docker run command |
+| BATCH_BUCKET_SIZE | integer | 8 | Batch size for decode operation will be rounded to the nearest multiple of this number. This limits the number of cached graphs | add -e in docker run command |
+| PREFILL_BATCH_BUCKET_SIZE | integer | 4 | Batch size for prefill operation will be rounded to the nearest multiple of this number. This limits the number of cached graphs | add -e in docker run command |
diff --git a/router/src/infer.rs b/router/src/infer.rs
index 79c5eb17..6195a4fa 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -50,11 +50,18 @@ impl Infer {
         max_concurrent_requests: usize,
         requires_padding: bool,
         max_input_length: u32,
+        max_total_tokens: u32,
         window_size: Option<u32>,
         generation_health: Arc<AtomicBool>,
     ) -> Self {
         // Infer shared state
-        let queue = Queue::new(requires_padding, max_input_length, 16, window_size);
+        let queue = Queue::new(
+            requires_padding,
+            max_input_length,
+            max_total_tokens,
+            16,
+            window_size
+        );
         let shared = Arc::new(Shared {
             batching_task: Notify::new(),
         });
diff --git a/router/src/queue.rs b/router/src/queue.rs
index 4e6e99a8..9961d27e 100644
--- a/router/src/queue.rs
+++ b/router/src/queue.rs
@@ -37,6 +37,7 @@ impl Queue {
     pub(crate) fn new(
         requires_padding: bool,
         max_input_length: u32,
+        max_total_tokens: u32,
         block_size: u32,
         window_size: Option<u32>
     ) -> Self {
@@ -47,6 +48,7 @@ impl Queue {
         tokio::spawn(queue_task(
             requires_padding,
             max_input_length,
+            max_total_tokens,
             block_size,
             window_size,
             queue_receiver,
@@ -96,11 +98,18 @@ impl Queue {
 async fn queue_task(
     requires_padding: bool,
     max_input_length: u32,
+    max_total_tokens: u32,
     block_size: u32,
     window_size: Option<u32>,
     mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
 ) {
-    let mut state = State::new(requires_padding, max_input_length, block_size, window_size);
+    let mut state = State::new(
+        requires_padding,
+        max_input_length,
+        max_total_tokens,
+        block_size,
+        window_size
+    );
 
     while let Some(cmd) = receiver.recv().await {
         match cmd {
@@ -138,9 +147,12 @@ struct State {
     /// Whether the model is using padding
     requires_padding: bool,
 
-    /// Maximum inpult length, required for padding scenario
+    /// Maximum input length, required for padding scenario
     max_input_length: u32,
 
+    /// Maximum input and output length, required for padding scenario
+    max_total_tokens: u32,
+
     /// Paged Attention block size
     block_size: u32,
 
@@ -152,6 +164,7 @@ impl State {
     fn new(
         requires_padding: bool,
         max_input_length: u32,
+        max_total_tokens: u32,
         block_size: u32,
         window_size: Option<u32>
     ) -> Self {
@@ -161,6 +174,7 @@ impl State {
             next_batch_id: 0,
             requires_padding,
             max_input_length,
+            max_total_tokens,
             block_size,
             window_size,
         }
@@ -218,7 +232,7 @@ impl State {
             if self.requires_padding {
                 // We pad to max input length in the Python shards
                 // We need to take these padding tokens into the equation
-                prefill_tokens = (batch_requests.len() + 1) as u32 * self.max_input_length
+                prefill_tokens = (batch_requests.len() + 1) as u32 * self.max_input_length;
             } else {
                 // pad to block size
                 prefill_tokens += ((entry.request.input_length + self.block_size - 1)
@@ -227,7 +241,9 @@
             }
 
             if self.requires_padding {
-                decode_tokens += entry.request.stopping_parameters.max_new_tokens;
+                // We pad to max total tokens in the Python shards
+                // We need to take these padding tokens into the equation
+                decode_tokens = (batch_requests.len() + 1) as u32 * (self.max_total_tokens - self.max_input_length);
             } else {
                 let max_new_tokens = match self.window_size {
                     None => entry.request.stopping_parameters.max_new_tokens,
diff --git a/router/src/server.rs b/router/src/server.rs
index 97bc20c2..ca339490 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -596,6 +596,7 @@ pub async fn run(
         max_concurrent_requests,
         shard_info.requires_padding,
         max_input_length as u32,
+        max_total_tokens as u32,
         shard_info.window_size,
         generation_health,
     );
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 8630bbd1..cf6f2a1e 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -36,6 +36,7 @@ from loguru import logger
 tracer = trace.get_tracer(__name__)
 
 BATCH_BUCKET_SIZE = int(os.environ.get('BATCH_BUCKET_SIZE', 8))
+PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get('PREFILL_BATCH_BUCKET_SIZE', 4))
 TRACE_FILENAME = os.environ.get('TRACE_FILENAME')
 
 def trace(txt):
@@ -234,7 +235,11 @@ class CausalLMBatch(Batch):
         top_n_tokens = [r.data.top_n_tokens for r in requests]
         top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
 
-        next_token_chooser = HeterogeneousNextTokenChooser.from_pb([r.data.parameters for r in requests], batches[0].next_token_chooser.device, batches[0].next_token_chooser.dtype)
+        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
+            [r.data.parameters for r in requests],
+            batches[0].next_token_chooser.device,
+            batches[0].next_token_chooser.dtype
+        )
 
         htorch.core.mark_step()
 
@@ -286,7 +291,7 @@ class CausalLMBatch(Batch):
         # TODO: by tokenizing all inputs at once we loose information on actual input lengths
         # this means that we cannot shift inputs to the left after a long input sequence
         # was filtered out
-        new_bs = round_up(len(requests), BATCH_BUCKET_SIZE)
+        new_bs = round_up(len(requests), PREFILL_BATCH_BUCKET_SIZE)
         dummy_inputs = ["?"] * (new_bs - len(requests))
         tokenized_inputs = tokenizer(
             [r.data.inputs for r in requests] + dummy_inputs,
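
Reviewer's note (not part of the patch): the sketch below restates the sizing math the hunks above implement, so the change is easier to check. With padding enabled, the router budgets prefill at (len + 1) * max_input_length and decode at (len + 1) * (max_total_tokens - max_input_length) instead of summing each request's max_new_tokens, while the Python shard rounds batch sizes up to BATCH_BUCKET_SIZE for decode and PREFILL_BATCH_BUCKET_SIZE for prefill. The helper names prefill_budget, decode_budget and round_up_to_bucket are illustrative only; the actual logic lives in router/src/queue.rs and in round_up in causal_lm.py.

// Reviewer's sketch only -- helper names are hypothetical; the patched logic is in
// router/src/queue.rs and server/text_generation_server/models/causal_lm.py.

/// Prefill budget under padding: each of the (batch_len + 1) candidate sequences
/// is padded to max_input_length tokens.
fn prefill_budget(batch_len: u32, max_input_length: u32) -> u32 {
    (batch_len + 1) * max_input_length
}

/// Decode budget under padding: each candidate sequence may generate up to
/// max_total_tokens - max_input_length new tokens.
fn decode_budget(batch_len: u32, max_input_length: u32, max_total_tokens: u32) -> u32 {
    (batch_len + 1) * (max_total_tokens - max_input_length)
}

/// Bucketing done on the Python side: batch sizes are rounded up to the nearest
/// multiple of PREFILL_BATCH_BUCKET_SIZE (prefill) or BATCH_BUCKET_SIZE (decode)
/// so that only a bounded set of graph shapes gets cached.
fn round_up_to_bucket(batch_size: u32, bucket: u32) -> u32 {
    ((batch_size + bucket - 1) / bucket) * bucket
}

fn main() {
    // Five live prefill requests with PREFILL_BATCH_BUCKET_SIZE=4 pad to a batch of 8.
    assert_eq!(round_up_to_bucket(5, 4), 8);
    // Adding a fourth request to a batch of three with max_input_length=1024 and
    // max_total_tokens=2048 reserves 4 * 1024 tokens for prefill and for decode.
    assert_eq!(prefill_budget(3, 1024), 4096);
    assert_eq!(decode_budget(3, 1024, 2048), 4096);
}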