Refine logging for Gaudi warmup

regisss 2025-05-10 12:59:36 +00:00
parent 56c8189467
commit 2b2b4a814d
2 changed files with 22 additions and 14 deletions

View File

@@ -1356,9 +1356,16 @@ class CausalLM(Model):
         prefill_seqlen_list.append(max_input_tokens)
         prefill_batch_size_list.sort(reverse=True)
         prefill_seqlen_list.sort(reverse=True)
+        logger.info(
+            f"Prefill batch size list:{prefill_batch_size_list}\n"
+            f"Prefill sequence length list:{prefill_seqlen_list}\n"
+        )
         try:
             for batch_size in prefill_batch_size_list:
                 for seq_len in prefill_seqlen_list:
+                    logger.info(
+                        f"Prefill warmup for `batch_size={batch_size}` and `sequence_length={seq_len}`, this may take a while..."
+                    )
                     batch = self.generate_warmup_batch(request, seq_len - 1, batch_size)
                     _, prefill_batch, _ = self.generate_token([batch])
         except Exception:
@@ -1374,9 +1381,7 @@ class CausalLM(Model):
         prefill_batch_size_list.sort()
         mem_stats = get_hpu_memory_stats(self.device)
         logger.info(
-            f"\nFollowing prefill warmup successfully.\n"
-            f"Prefill batch size list:{prefill_batch_size_list}\n"
-            f"Prefill sequence length list:{prefill_seqlen_list}\n"
+            f"Prefill warmup successful.\n"
             f"Memory stats: {mem_stats} "
         )
@@ -1386,9 +1391,11 @@ class CausalLM(Model):
             BATCH_SIZE_EXPONENT_BASE**exp for exp in range(0, max_exp + 1)
         ]
         decode_batch_size_list.sort(reverse=True)
+        logger.info(f"Decode batch size list:{decode_batch_size_list}\n")
         try:
             for batch_size in decode_batch_size_list:
+                logger.info(f"Decode warmup for `batch_size={batch_size}`, this may take a while...")
                 batches = []
                 iters = math.floor(batch_size / max_prefill_batch_size)
                 for i in range(iters):
@@ -1422,8 +1429,7 @@ class CausalLM(Model):
         max_supported_total_tokens = MAX_TOTAL_TOKENS * decode_batch_size_list[-1]
         mem_stats = get_hpu_memory_stats(self.device)
         logger.info(
-            f"\nFollowing decode warmup successfully.\n"
-            f"Decode batch size list:{decode_batch_size_list}\n"
+            f"Decode warmup successful.\n"
             f"Memory stats: {mem_stats} "
         )
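
Below is a minimal, standalone sketch of the refined decode-warmup logging pattern shown above. It is not the actual CausalLM code: `decode_warmup_sketch`, `run_decode_warmup`, and `memory_stats` are illustrative stand-ins for the real warmup internals and `get_hpu_memory_stats`; only the log messages come from the diff. It assumes loguru, which the TGI Python server uses for logging.

# Sketch only: stand-in names, not the real CausalLM.warmup implementation.
from loguru import logger


def decode_warmup_sketch(decode_batch_size_list, run_decode_warmup, memory_stats):
    # The full batch size list is now logged once up front.
    decode_batch_size_list.sort(reverse=True)
    logger.info(f"Decode batch size list:{decode_batch_size_list}\n")
    for batch_size in decode_batch_size_list:
        # Each shape announces itself before the potentially slow warmup step.
        logger.info(f"Decode warmup for `batch_size={batch_size}`, this may take a while...")
        run_decode_warmup(batch_size)
    # The success message is reduced to a short status plus memory stats.
    logger.info(f"Decode warmup successful.\nMemory stats: {memory_stats()} ")


# Example invocation with dummy stand-ins:
decode_warmup_sketch([1, 2, 4, 8], lambda batch_size: None, lambda: {})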

View File

@@ -1511,9 +1511,16 @@ class VlmCausalLM(Model):
         DECODE_WARMUP_BATCH_SIZE_LIST = []
         prefill_batch = None
         decode_batch = None
+        logger.info(
+            f"Prefill batch size list:{PREFILL_WARMUP_BATCH_SIZE_LIST}"
+            f"Prefill sequence length list:{PREFILL_WARMUP_SEQLEN_LIST}"
+        )
         try:
             for batch_size in PREFILL_WARMUP_BATCH_SIZE_LIST:
                 for seq_len in PREFILL_WARMUP_SEQLEN_LIST:
+                    logger.info(
+                        f"Prefill warmup for `batch_size={batch_size}` and `sequence_length={seq_len}`, this may take a while..."
+                    )
                     batch = self.generate_warmup_batch(
                         request, seq_len, batch_size, is_warmup=True
                     )
@@ -1528,23 +1535,18 @@ class VlmCausalLM(Model):
         except Exception:
             raise RuntimeError(
                 f"Not enough memory to handle following prefill and decode warmup."
-                f"Prefill batch size list:{PREFILL_WARMUP_BATCH_SIZE_LIST}"
-                f"Prefill sequence length list:{PREFILL_WARMUP_SEQLEN_LIST}"
-                f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}"
                 f"You need to decrease `--max-batch-prefill-tokens`"
             )
         mem_stats = get_hpu_memory_stats(self.device)
         logger.info(
-            f"\nFollowing prefill and decode warmup successfully.\n"
-            f"Prefill batch size list:{PREFILL_WARMUP_BATCH_SIZE_LIST}\n"
-            f"Prefill sequence length list:{PREFILL_WARMUP_SEQLEN_LIST}\n"
-            f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}\n"
+            f"Prefill warmup successful.\n"
             f"Memory stats: {mem_stats} "
         )
         max_decode_batch_size = MAX_BATCH_SIZE
         batch_size = max_prefill_batch_size * 2
+        logger.info(f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}\n")
         # Decode warmup with bigger batch_size
         try:
             if (
@@ -1554,6 +1556,7 @@ class VlmCausalLM(Model):
                 batches = []
                 while batch_size <= max_decode_batch_size:
                     for i in range(int(batch_size / max_prefill_batch_size)):
+                        logger.info(f"Decode warmup for `batch_size={batch_size}`, this may take a while...")
                         batch = self.generate_warmup_batch(
                             request,
                             PREFILL_WARMUP_SEQLEN_LIST[0] - 1,
@@ -1597,8 +1600,7 @@ class VlmCausalLM(Model):
         mem_stats = get_hpu_memory_stats(self.device)
         logger.info(
-            f"\nFollowing decode warmup successfully.\n"
-            f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}\n"
+            f"Decode warmup successful.\n"
             f"Memory stats: {mem_stats}"
         )
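
The same pattern applies to the multimodal path. The following is a hedged, standalone sketch, not the actual VlmCausalLM code: `prefill_warmup_sketch` and `run_prefill_warmup` are stand-ins for the real generate_warmup_batch/generate_token calls, and only the log and error strings come from the diff. The prefill side now logs the shape lists once, emits a per-shape progress line, and keeps the out-of-memory hint short because the lists were already logged.

# Sketch only: stand-in names, not the real VlmCausalLM.warmup implementation.
from loguru import logger


def prefill_warmup_sketch(batch_size_list, seqlen_list, run_prefill_warmup):
    # Shape lists are logged once before the loops instead of in every message.
    logger.info(
        f"Prefill batch size list:{batch_size_list}"
        f"Prefill sequence length list:{seqlen_list}"
    )
    try:
        for batch_size in batch_size_list:
            for seq_len in seqlen_list:
                # Per-shape progress line, since warming up a new shape can be slow.
                logger.info(
                    f"Prefill warmup for `batch_size={batch_size}` and `sequence_length={seq_len}`, this may take a while..."
                )
                run_prefill_warmup(batch_size, seq_len)
    except Exception:
        # The error no longer repeats the lists; they were already logged above.
        raise RuntimeError(
            "Not enough memory to handle following prefill and decode warmup."
            "You need to decrease `--max-batch-prefill-tokens`"
        )
    logger.info("Prefill warmup successful.")


# Example invocation with dummy stand-ins:
prefill_warmup_sketch([1, 2], [256, 512], lambda batch_size, seq_len: None)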