Simplify the warmup

Signed-off-by: yuanwu <yuan.wu@intel.com>
Author: yuanwu, 2024-10-24 06:26:48 +00:00
Parent: 8686a0fc6d
Commit: 8ebe77b3be

@@ -1213,10 +1213,9 @@ class CausalLM(Model):
         #warmup decode batch size
         max_prefill_batch_size = batch.input_ids.shape[0]
         max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS)
-        batch_size = max_decode_batch_size
         self.limit_hpu_graph = True
         try:
-            while batch_size > 1:
+            for batch_size in range(max_decode_batch_size, BATCH_BUCKET_SIZE, -BATCH_BUCKET_SIZE):
                 batches= []
                 iters = math.floor(batch_size/max_prefill_batch_size)
                 DECODE_WARMUP_BATCH_SIZE_LIST.append(batch_size)
@@ -1231,8 +1230,6 @@ class CausalLM(Model):
                     batches.append(batch)
                 _, decode_batch, _ = self.generate_token(batches, is_warmup)
-                logger.info(f"DECODE_DIVISOR={BATCH_BUCKET_SIZE}")
-                batch_size = math.floor(batch_size / BATCH_BUCKET_SIZE)
         except:
             DECODE_WARMUP_BATCH_SIZE_LIST.pop(-1)
             self.model.clear_cache()
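Read together, these two hunks replace the geometric while-loop (dividing by BATCH_BUCKET_SIZE each pass) with an arithmetic range that steps down bucket by bucket. A minimal standalone sketch of the difference, with illustrative values for MAX_BATCH_TOTAL_TOKENS, MAX_TOTAL_TOKENS, and BATCH_BUCKET_SIZE (none of these numbers come from the commit):

import math

# Hypothetical values for illustration only.
MAX_BATCH_TOTAL_TOKENS = 65536
MAX_TOTAL_TOKENS = 2048
BATCH_BUCKET_SIZE = 8

max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS)  # 32

# Old: divide by BATCH_BUCKET_SIZE each iteration -> sparse, geometric buckets.
old_sizes = []
batch_size = max_decode_batch_size
while batch_size > 1:
    old_sizes.append(batch_size)
    batch_size = math.floor(batch_size / BATCH_BUCKET_SIZE)
print(old_sizes)  # [32, 4]

# New: step down by BATCH_BUCKET_SIZE -> every bucket above BATCH_BUCKET_SIZE.
new_sizes = list(range(max_decode_batch_size, BATCH_BUCKET_SIZE, -BATCH_BUCKET_SIZE))
print(new_sizes)  # [32, 24, 16]

With these numbers the old loop only ever warmed batch sizes 32 and 4, while the new loop touches each bucket-aligned decode batch size above the bucket size itself.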
@@ -1257,26 +1254,21 @@ class CausalLM(Model):
         # Warmup prefill batch_size
         max_input_length = request.max_input_length
         max_prefill_batch_size = batch.input_ids.shape[0]
-        batch_size = max_prefill_batch_size
-        while batch_size >= 1:
-            PREFILL_WARMUP_BATCH_SIZE_LIST.append(batch_size)
-            batch_size = math.floor(batch_size / PREFILL_BATCH_BUCKET_SIZE)
-        seq_len = PAD_SEQUENCE_TO_MULTIPLE_OF
-        i = 0
-        while seq_len <= max_input_length:
+        seq_len = max_input_length
+        while seq_len >= PAD_SEQUENCE_TO_MULTIPLE_OF:
             PREFILL_WARMUP_SEQLEN_LIST.append(seq_len)
-            seq_len += PAD_SEQUENCE_TO_MULTIPLE_OF*(2**i)
-            i += 1
-        if PREFILL_WARMUP_SEQLEN_LIST[-1] < max_input_length:
-            PREFILL_WARMUP_SEQLEN_LIST.append(max_input_length)
+            seq_len = math.floor(seq_len/2)
+        if PREFILL_WARMUP_SEQLEN_LIST[-1] > PAD_SEQUENCE_TO_MULTIPLE_OF:
+            PREFILL_WARMUP_SEQLEN_LIST.append(PAD_SEQUENCE_TO_MULTIPLE_OF)
         #Prefill and decode warmup
         prefill_batch = None
-        PREFILL_WARMUP_BATCH_SIZE_LIST.sort()
         PREFILL_WARMUP_SEQLEN_LIST.sort()
         try:
-            for batch_size in PREFILL_WARMUP_BATCH_SIZE_LIST :
+            for batch_size in range(PREFILL_BATCH_BUCKET_SIZE, max_prefill_batch_size, PREFILL_BATCH_BUCKET_SIZE):
+                DECODE_WARMUP_BATCH_SIZE_LIST.append(batch_size)
                 for seq_len in PREFILL_WARMUP_SEQLEN_LIST :
                     batch = self.generate_warmup_batch(request, seq_len - 1, batch_size)
                     _, prefill_batch, _ = self.generate_token([batch], is_warmup)
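On the reading of the hunk above (the old code grows the prefill sequence-length list upward from PAD_SEQUENCE_TO_MULTIPLE_OF by doubling increments, the new code halves downward from max_input_length and drops the `i` counter), here is a minimal sketch comparing the two enumerations; the concrete values are invented for illustration:

import math

# Hypothetical values for illustration only.
PAD_SEQUENCE_TO_MULTIPLE_OF = 128
max_input_length = 1024

# Old: grow from the pad multiple by doubling increments, then top up with the max.
old_list, seq_len, i = [], PAD_SEQUENCE_TO_MULTIPLE_OF, 0
while seq_len <= max_input_length:
    old_list.append(seq_len)
    seq_len += PAD_SEQUENCE_TO_MULTIPLE_OF * (2**i)
    i += 1
if old_list[-1] < max_input_length:
    old_list.append(max_input_length)
print(old_list)  # [128, 256, 512, 1024]

# New: halve down from the max, bottom out at the pad multiple, then sort.
new_list, seq_len = [], max_input_length
while seq_len >= PAD_SEQUENCE_TO_MULTIPLE_OF:
    new_list.append(seq_len)
    seq_len = math.floor(seq_len / 2)
if new_list[-1] > PAD_SEQUENCE_TO_MULTIPLE_OF:
    new_list.append(PAD_SEQUENCE_TO_MULTIPLE_OF)
new_list.sort()
print(new_list)  # [128, 256, 512, 1024]

Here the two strategies happen to land on the same buckets because max_input_length is a power-of-two multiple of the pad size; for other maxima the lists diverge, which is why the new code still sorts the list and explicitly bottoms it out at PAD_SEQUENCE_TO_MULTIPLE_OF.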