[gaudi] Refine logging for Gaudi warmup (#3222)

* Refine logging for Gaudi warmup

* Make style

* Make style 2

* Flash causal LM case

* Add log_master & VLM cases

* Black
regisss 2025-06-18 04:34:00 -06:00 committed by GitHub
parent b4d17f18ff
commit f13e28c98d
10 changed files with 37 additions and 7 deletions
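
All new messages are routed through log_master, so on multi-card runs each warmup line is printed once rather than once per shard. A minimal sketch of such a helper, assuming the rank is read from the RANK environment variable (the actual log_master used by the server may be implemented differently):

import os
from loguru import logger

def log_master(log, msg: str):
    # Emit only on the master shard (rank 0) so every warmup message
    # appears once instead of once per HPU.
    if int(os.getenv("RANK", "0")) == 0:
        log(msg)

log_master(logger.info, "Prefill warmup successful.\n")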


@@ -1721,6 +1721,7 @@ class FlashCausalLM(Model):
f"{dim}:{seq_len} "
f"bypass:{bypass} "
f"free_mem:{free_mem}"
", this may take a while..."
)
log_master(logger.info, msg)
@@ -1772,6 +1773,11 @@ class FlashCausalLM(Model):
total_batch_seq = 0.001
total_mem = 0
available_mem = prompt_available_memory
msg = (
f"Prefill batch size list:{[bsz[0] for bsz in buckets]}\n"
f"Prefill sequence length list:{[seq[1] for seq in buckets]}\n"
)
log_master(logger.info, msg)
for i, (batch_size, seq_len) in enumerate(buckets):
if batch_size * seq_len > self.max_batch_prefill_tokens:
continue
@@ -1798,6 +1804,8 @@ class FlashCausalLM(Model):
total_mem += used_mem
total_batch_seq += batch_seq
log_master(logger.info, "Prefill warmup successful.\n")
def ordering_function_max_bs(b):
return (-b[0], b[1])
@@ -1809,6 +1817,9 @@ class FlashCausalLM(Model):
total_batch_seq = 0.001
total_mem = 0
available_mem = free_mem - self.mem_reserved
log_master(
logger.info, f"Decode batch size list:{[bsz[0] for bsz in buckets]}\n"
)
for i, (batch_size, block_num) in enumerate(buckets):
if batch_size > block_num:
continue
@@ -1833,6 +1844,8 @@ class FlashCausalLM(Model):
total_mem += used_mem
total_batch_seq += batch_seq
log_master(logger.info, "Decode warmup successful.\n")
log_master(
logger.info,
f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",


@@ -822,6 +822,9 @@ class FlashVlmCausalLM(FlashCausalLM):
total_batch_seq = 0.001
total_mem = 0
available_mem = decode_available_memory
log_master(
logger.info, f"Decode batch size list:{[bsz[0] for bsz in buckets]}\n"
)
for i, (batch_size, block_num) in enumerate(buckets):
if batch_size > block_num:
continue
@@ -847,6 +850,8 @@ class FlashVlmCausalLM(FlashCausalLM):
total_mem += used_mem
total_batch_seq += batch_seq
log_master(logger.info, "Decode warmup successful.\n")
log_master(
logger.info,
f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",


@@ -398,6 +398,11 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
total_batch_seq = 0.001
total_mem = 0
available_mem = prompt_available_memory
msg = (
f"Prefill batch size list:{[bsz[0] for bsz in buckets]}\n"
f"Prefill sequence length list:{[seq[1] for seq in buckets]}\n"
)
log_master(logger.info, msg)
for i, (batch_size, seq_len) in enumerate(buckets):
if batch_size * seq_len > self.max_batch_prefill_tokens:
continue
@@ -424,6 +429,8 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
total_mem += used_mem
total_batch_seq += batch_seq
log_master(logger.info, "Prefill warmup successful.\n")
def ordering_function_max_bs(b):
return (-b[0], b[1])
@@ -435,6 +442,9 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
total_batch_seq = 0.001
total_mem = 0
available_mem = free_mem - self.mem_reserved
log_master(
logger.info, f"Decode batch size list:{[bsz[0] for bsz in buckets]}\n"
)
for i, (batch_size, block_num) in enumerate(buckets):
if batch_size > block_num:
continue
@@ -459,6 +469,8 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
total_mem += used_mem
total_batch_seq += batch_seq
log_master(logger.info, "Decode warmup successful.\n")
log_master(
logger.info,
f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",


@@ -8,7 +8,7 @@ import torch
def find_segments(
adapter_indices: Union[torch.Tensor, List[int]],
) -> Tuple[List[int], List[int]]:
segments = [0]
segment_indices = []