Change HPU warmup logic: seq length should be with exponential growth (#3217)

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
Co-authored-by: regisss <15324346+regisss@users.noreply.github.com>
This commit is contained in:
kaixuanliu 2025-05-10 21:41:18 +08:00 committed by GitHub
parent 56c8189467
commit c94f415af4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -56,6 +56,7 @@ PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get("PAD_SEQUENCE_TO_MULTIPLE_OF",
CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
LAZY_MODE = int(os.environ.get("PT_HPU_LAZY_MODE", 1))
BATCH_SIZE_EXPONENT_BASE = int(os.environ.get("BATCH_SIZE_EXPONENT_BASE", 2))
SEQ_LEN_EXPONENT_BASE = int(os.environ.get("SEQ_LEN_EXPONENT_BASE", 2))
MAX_BATCH_SIZE = (
int(os.environ.get("MAX_BATCH_SIZE"))
if os.environ.get("MAX_BATCH_SIZE") is not None
@ -71,8 +72,21 @@ def torch_compile_for_eager(func):
)
def round_up_seq(number, k):
return (number + k - 1) // k * k
def round_up_seq(number, k, base):
exponent = math.ceil(math.log(number / k, base))
return int(k * (base**exponent))
def iterate_powers_of_base(max_value, start, base):
current = start
result = []
assert (
max_value >= start
), f"max_value {max_value} must be greater than start {start}"
while current < max_value:
result.append(current)
current *= base
return result
def round_up_batch(number):
@ -575,7 +589,9 @@ class CausalLMBatch(Batch):
assert (
PAD_SEQUENCE_TO_MULTIPLE_OF <= max_input_length
), "PAD_SEQUENCE_TO_MULTIPLE_OF cannot be higher than max_input_length"
rounded_seq_len = round_up_seq(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF)
rounded_seq_len = round_up_seq(
input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF, SEQ_LEN_EXPONENT_BASE
)
if rounded_seq_len <= max_input_length:
bucket_size = rounded_seq_len - 1
else:
@ -1345,14 +1361,9 @@ class CausalLM(Model):
max_exp + 1,
)
]
prefill_seqlen_list = [
seq
for seq in range(
PAD_SEQUENCE_TO_MULTIPLE_OF,
max_input_tokens,
PAD_SEQUENCE_TO_MULTIPLE_OF,
)
]
prefill_seqlen_list = iterate_powers_of_base(
max_input_tokens, PAD_SEQUENCE_TO_MULTIPLE_OF, SEQ_LEN_EXPONENT_BASE
)
prefill_seqlen_list.append(max_input_tokens)
prefill_batch_size_list.sort(reverse=True)
prefill_seqlen_list.sort(reverse=True)