Refine logging for Gaudi warmup

This commit is contained in:
regisss 2025-05-10 12:59:36 +00:00
parent 56c8189467
commit 2b2b4a814d
2 changed files with 22 additions and 14 deletions

View File

@ -1356,9 +1356,16 @@ class CausalLM(Model):
prefill_seqlen_list.append(max_input_tokens)
prefill_batch_size_list.sort(reverse=True)
prefill_seqlen_list.sort(reverse=True)
logger.info(
f"Prefill batch size list:{prefill_batch_size_list}\n"
f"Prefill sequence length list:{prefill_seqlen_list}\n"
)
try:
for batch_size in prefill_batch_size_list:
for seq_len in prefill_seqlen_list:
logger.info(
f"Prefill warmup for `batch_size={batch_size}` and `sequence_length={seq_len}`, this may take a while..."
)
batch = self.generate_warmup_batch(request, seq_len - 1, batch_size)
_, prefill_batch, _ = self.generate_token([batch])
except Exception:
@ -1374,9 +1381,7 @@ class CausalLM(Model):
prefill_batch_size_list.sort()
mem_stats = get_hpu_memory_stats(self.device)
logger.info(
f"\nFollowing prefill warmup successfully.\n"
f"Prefill batch size list:{prefill_batch_size_list}\n"
f"Prefill sequence length list:{prefill_seqlen_list}\n"
f"Prefill warmup successful.\n"
f"Memory stats: {mem_stats} "
)
@ -1386,9 +1391,11 @@ class CausalLM(Model):
BATCH_SIZE_EXPONENT_BASE**exp for exp in range(0, max_exp + 1)
]
decode_batch_size_list.sort(reverse=True)
logger.info(f"Decode batch size list:{decode_batch_size_list}\n")
try:
for batch_size in decode_batch_size_list:
logger.info(f"Decode warmup for `batch_size={batch_size}`, this may take a while...")
batches = []
iters = math.floor(batch_size / max_prefill_batch_size)
for i in range(iters):
@ -1422,8 +1429,7 @@ class CausalLM(Model):
max_supported_total_tokens = MAX_TOTAL_TOKENS * decode_batch_size_list[-1]
mem_stats = get_hpu_memory_stats(self.device)
logger.info(
f"\nFollowing decode warmup successfully.\n"
f"Decode batch size list:{decode_batch_size_list}\n"
f"Decode warmup successful.\n"
f"Memory stats: {mem_stats} "
)

View File

@ -1511,9 +1511,16 @@ class VlmCausalLM(Model):
DECODE_WARMUP_BATCH_SIZE_LIST = []
prefill_batch = None
decode_batch = None
logger.info(
f"Prefill batch size list:{PREFILL_WARMUP_BATCH_SIZE_LIST}"
f"Prefill sequence length list:{PREFILL_WARMUP_SEQLEN_LIST}"
)
try:
for batch_size in PREFILL_WARMUP_BATCH_SIZE_LIST:
for seq_len in PREFILL_WARMUP_SEQLEN_LIST:
logger.info(
f"Prefill warmup for `batch_size={batch_size}` and `sequence_length={seq_len}`, this may take a while..."
)
batch = self.generate_warmup_batch(
request, seq_len, batch_size, is_warmup=True
)
@ -1528,23 +1535,18 @@ class VlmCausalLM(Model):
except Exception:
raise RuntimeError(
f"Not enough memory to handle following prefill and decode warmup."
f"Prefill batch size list:{PREFILL_WARMUP_BATCH_SIZE_LIST}"
f"Prefill sequence length list:{PREFILL_WARMUP_SEQLEN_LIST}"
f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}"
f"You need to decrease `--max-batch-prefill-tokens`"
)
mem_stats = get_hpu_memory_stats(self.device)
logger.info(
f"\nFollowing prefill and decode warmup successfully.\n"
f"Prefill batch size list:{PREFILL_WARMUP_BATCH_SIZE_LIST}\n"
f"Prefill sequence length list:{PREFILL_WARMUP_SEQLEN_LIST}\n"
f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}\n"
f"Prefill warmup successful.\n"
f"Memory stats: {mem_stats} "
)
max_decode_batch_size = MAX_BATCH_SIZE
batch_size = max_prefill_batch_size * 2
logger.info(f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}\n")
# Decode warmup with bigger batch_size
try:
if (
@ -1554,6 +1556,7 @@ class VlmCausalLM(Model):
batches = []
while batch_size <= max_decode_batch_size:
for i in range(int(batch_size / max_prefill_batch_size)):
logger.info(f"Decode warmup for `batch_size={batch_size}`, this may take a while...")
batch = self.generate_warmup_batch(
request,
PREFILL_WARMUP_SEQLEN_LIST[0] - 1,
@ -1597,8 +1600,7 @@ class VlmCausalLM(Model):
mem_stats = get_hpu_memory_stats(self.device)
logger.info(
f"\nFollowing decode warmup successfully.\n"
f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}\n"
f"Decode warmup successful.\n"
f"Memory stats: {mem_stats}"
)