From c94f415af4ccfba3321c152589bc3046f272ed88 Mon Sep 17 00:00:00 2001
From: kaixuanliu
Date: Sat, 10 May 2025 21:41:18 +0800
Subject: [PATCH] Change HPU warmup logic: seq length should be with exponential growth (#3217)

Signed-off-by: Liu, Kaixuan
Co-authored-by: regisss <15324346+regisss@users.noreply.github.com>
---
 .../models/causal_lm.py | 33 ++++++++++++-------
 1 file changed, 22 insertions(+), 11 deletions(-)

diff --git a/backends/gaudi/server/text_generation_server/models/causal_lm.py b/backends/gaudi/server/text_generation_server/models/causal_lm.py
index 374b6fd6..b501d488 100644
--- a/backends/gaudi/server/text_generation_server/models/causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/causal_lm.py
@@ -56,6 +56,7 @@ PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get("PAD_SEQUENCE_TO_MULTIPLE_OF",
 CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
 LAZY_MODE = int(os.environ.get("PT_HPU_LAZY_MODE", 1))
 BATCH_SIZE_EXPONENT_BASE = int(os.environ.get("BATCH_SIZE_EXPONENT_BASE", 2))
+SEQ_LEN_EXPONENT_BASE = int(os.environ.get("SEQ_LEN_EXPONENT_BASE", 2))
 MAX_BATCH_SIZE = (
     int(os.environ.get("MAX_BATCH_SIZE"))
     if os.environ.get("MAX_BATCH_SIZE") is not None
@@ -71,8 +72,21 @@ def torch_compile_for_eager(func):
     )
 
 
-def round_up_seq(number, k):
-    return (number + k - 1) // k * k
+def round_up_seq(number, k, base):
+    exponent = math.ceil(math.log(number / k, base))
+    return int(k * (base**exponent))
+
+
+def iterate_powers_of_base(max_value, start, base):
+    current = start
+    result = []
+    assert (
+        max_value >= start
+    ), f"max_value {max_value} must be greater than start {start}"
+    while current < max_value:
+        result.append(current)
+        current *= base
+    return result
 
 
 def round_up_batch(number):
@@ -575,7 +589,9 @@ class CausalLMBatch(Batch):
         assert (
             PAD_SEQUENCE_TO_MULTIPLE_OF <= max_input_length
         ), "PAD_SEQUENCE_TO_MULTIPLE_OF cannot be higher than max_input_length"
-        rounded_seq_len = round_up_seq(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF)
+        rounded_seq_len = round_up_seq(
+            input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF, SEQ_LEN_EXPONENT_BASE
+        )
         if rounded_seq_len <= max_input_length:
             bucket_size = rounded_seq_len - 1
         else:
@@ -1345,14 +1361,9 @@ class CausalLM(Model):
                 max_exp + 1,
             )
         ]
-        prefill_seqlen_list = [
-            seq
-            for seq in range(
-                PAD_SEQUENCE_TO_MULTIPLE_OF,
-                max_input_tokens,
-                PAD_SEQUENCE_TO_MULTIPLE_OF,
-            )
-        ]
+        prefill_seqlen_list = iterate_powers_of_base(
+            max_input_tokens, PAD_SEQUENCE_TO_MULTIPLE_OF, SEQ_LEN_EXPONENT_BASE
+        )
         prefill_seqlen_list.append(max_input_tokens)
         prefill_batch_size_list.sort(reverse=True)
         prefill_seqlen_list.sort(reverse=True)
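
Note (not part of the patch): the standalone sketch below illustrates the effect of the new helpers. It copies round_up_seq and iterate_powers_of_base from the hunks above and assumes PAD_SEQUENCE_TO_MULTIPLE_OF=128, the default SEQ_LEN_EXPONENT_BASE=2, and a hypothetical max_input_tokens of 2048; those values are examples only, not taken from the patch. It shows that prefill warmup sequence lengths now grow geometrically (128, 256, 512, 1024, ...) instead of linearly in steps of PAD_SEQUENCE_TO_MULTIPLE_OF.

import math

# Copies of the helpers added by this patch, reproduced for illustration only.
def round_up_seq(number, k, base):
    # Round `number` up to the nearest bucket of the form k * base**n.
    exponent = math.ceil(math.log(number / k, base))
    return int(k * (base**exponent))

def iterate_powers_of_base(max_value, start, base):
    # Collect start, start*base, start*base**2, ... strictly below max_value.
    current = start
    result = []
    assert max_value >= start, f"max_value {max_value} must be >= start {start}"
    while current < max_value:
        result.append(current)
        current *= base
    return result

if __name__ == "__main__":
    PAD_SEQUENCE_TO_MULTIPLE_OF = 128  # assumed value; env-configurable in the server
    SEQ_LEN_EXPONENT_BASE = 2          # default introduced by this patch
    max_input_tokens = 2048            # assumed deployment limit, for illustration

    # Warmup buckets: a geometric progression plus the maximum itself,
    # mirroring how warmup builds prefill_seqlen_list in the patched code.
    seqlen_list = iterate_powers_of_base(
        max_input_tokens, PAD_SEQUENCE_TO_MULTIPLE_OF, SEQ_LEN_EXPONENT_BASE
    )
    seqlen_list.append(max_input_tokens)
    print(seqlen_list)  # [128, 256, 512, 1024, 2048]

    # An incoming prompt of 700 tokens is padded up to the 1024 bucket.
    print(round_up_seq(700, PAD_SEQUENCE_TO_MULTIPLE_OF, SEQ_LEN_EXPONENT_BASE))  # 1024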