[gaudi] Refine logging for Gaudi warmup (#3222)

* Refine logging for Gaudi warmup

* Make style

* Make style 2

* Flash causal LM case

* Add log_master & VLM cases

* Black
Author: regisss, 2025-06-18 04:34:00 -06:00 (committed by GitHub)
Parent: b4d17f18ff
Commit: f13e28c98d
10 changed files with 37 additions and 7 deletions


@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" PyTorch Llava-NeXT model."""
+"""PyTorch Llava-NeXT model."""
 from typing import List, Optional, Tuple
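
The six docstring hunks in this commit (this one and the five below) are pure style fixes from the "Make style" / "Black" entries in the commit message: Black's docstring normalization strips the stray space after the opening triple quote, with no behavior change. In general:

    """ Some model."""   # before: leading space inside the docstring
    """Some model."""    # after Black's docstring normalization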


@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" PyTorch Idefics2 model."""
+"""PyTorch Idefics2 model."""
 from typing import List, Optional, Tuple


@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" PyTorch Idefics3 model."""
+"""PyTorch Idefics3 model."""
 from typing import List, Optional, Tuple


@@ -17,7 +17,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" Idefics model configuration"""
+"""Idefics model configuration"""
 import copy
 from transformers import PretrainedConfig


@@ -17,7 +17,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" PyTorch Idefics model."""
+"""PyTorch Idefics model."""
 from typing import List, Optional, Tuple, Union
 import torch


@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" PyTorch IdeficsVision model: a copy of CLIPVisionModel using a simpler config object"""
+"""PyTorch IdeficsVision model: a copy of CLIPVisionModel using a simpler config object"""
 from dataclasses import dataclass


@@ -1721,6 +1721,7 @@ class FlashCausalLM(Model):
                 f"{dim}:{seq_len} "
                 f"bypass:{bypass} "
                 f"free_mem:{free_mem}"
+                ", this may take a while..."
             )
             log_master(logger.info, msg)
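
All of the new messages go through log_master, which emits a log line only from the master rank, so multi-card Gaudi runs do not print one copy of every warmup message per device. A minimal sketch of such a helper, assuming the rank is read from the usual RANK environment variable (the real helper lives in the server's logging utils and may differ in detail):

    import os
    from loguru import logger

    def log_master(log, msg: str):
        # Emit only on rank 0 so sharded runs log each warmup message once.
        if int(os.getenv("RANK", "0")) == 0:
            log(msg)

    log_master(logger.info, "Prefill warmup successful.\n")  # usage as in the diff
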
@@ -1772,6 +1773,11 @@ class FlashCausalLM(Model):
         total_batch_seq = 0.001
         total_mem = 0
         available_mem = prompt_available_memory
+        msg = (
+            f"Prefill batch size list:{[bsz[0] for bsz in buckets]}\n"
+            f"Prefill sequence length list:{[seq[1] for seq in buckets]}\n"
+        )
+        log_master(logger.info, msg)
         for i, (batch_size, seq_len) in enumerate(buckets):
             if batch_size * seq_len > self.max_batch_prefill_tokens:
                 continue
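
Each bucket is a (batch_size, seq_len) pair, so the two comprehensions in the new message simply project out one column apiece. With hypothetical bucket values:

    buckets = [(1, 128), (2, 128), (4, 256)]  # hypothetical warmup buckets
    print(f"Prefill batch size list:{[bsz[0] for bsz in buckets]}")
    # Prefill batch size list:[1, 2, 4]
    print(f"Prefill sequence length list:{[seq[1] for seq in buckets]}")
    # Prefill sequence length list:[128, 128, 256]
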
@@ -1798,6 +1804,8 @@ class FlashCausalLM(Model):
                 total_mem += used_mem
                 total_batch_seq += batch_seq
+        log_master(logger.info, "Prefill warmup successful.\n")
+
         def ordering_function_max_bs(b):
             return (-b[0], b[1])
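
ordering_function_max_bs is unchanged context here, but it determines the order in which the buckets being logged are warmed: the key (-b[0], b[1]) sorts by descending batch size, breaking ties by ascending second element. A quick illustration with made-up buckets:

    def ordering_function_max_bs(b):
        return (-b[0], b[1])

    buckets = [(2, 64), (4, 32), (4, 16), (1, 128)]  # hypothetical
    print(sorted(buckets, key=ordering_function_max_bs))
    # [(4, 16), (4, 32), (2, 64), (1, 128)]
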
@@ -1809,6 +1817,9 @@ class FlashCausalLM(Model):
         total_batch_seq = 0.001
         total_mem = 0
         available_mem = free_mem - self.mem_reserved
+        log_master(
+            logger.info, f"Decode batch size list:{[bsz[0] for bsz in buckets]}\n"
+        )
         for i, (batch_size, block_num) in enumerate(buckets):
             if batch_size > block_num:
                 continue
@@ -1833,6 +1844,8 @@ class FlashCausalLM(Model):
                 total_mem += used_mem
                 total_batch_seq += batch_seq
+        log_master(logger.info, "Decode warmup successful.\n")
+
         log_master(
             logger.info,
             f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",
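
The closing log_master call is pre-existing context: it reports total HPU graph warmup time and how many shapes were captured. A sketch of how those two values fit around the warmup loops (bucket values hypothetical; the code that increments warmup_shape_count is outside the hunks shown):

    import time

    start_time = time.time()
    warmup_shape_count = 0
    for batch_size, seq_len in [(1, 128), (2, 256)]:  # hypothetical buckets
        # ... capture one HPU graph per (batch_size, seq_len) shape ...
        warmup_shape_count += 1
    print(
        f"warmup hpu graph time {int(time.time() - start_time)}s "
        f"warmup shape count {warmup_shape_count}"
    )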


@@ -822,6 +822,9 @@ class FlashVlmCausalLM(FlashCausalLM):
         total_batch_seq = 0.001
         total_mem = 0
         available_mem = decode_available_memory
+        log_master(
+            logger.info, f"Decode batch size list:{[bsz[0] for bsz in buckets]}\n"
+        )
         for i, (batch_size, block_num) in enumerate(buckets):
             if batch_size > block_num:
                 continue
@@ -847,6 +850,8 @@ class FlashVlmCausalLM(FlashCausalLM):
                 total_mem += used_mem
                 total_batch_seq += batch_seq
+        log_master(logger.info, "Decode warmup successful.\n")
+
         log_master(
             logger.info,
             f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",


@@ -398,6 +398,11 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
         total_batch_seq = 0.001
         total_mem = 0
         available_mem = prompt_available_memory
+        msg = (
+            f"Prefill batch size list:{[bsz[0] for bsz in buckets]}\n"
+            f"Prefill sequence length list:{[seq[1] for seq in buckets]}\n"
+        )
+        log_master(logger.info, msg)
         for i, (batch_size, seq_len) in enumerate(buckets):
             if batch_size * seq_len > self.max_batch_prefill_tokens:
                 continue
@@ -424,6 +429,8 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
                 total_mem += used_mem
                 total_batch_seq += batch_seq
+        log_master(logger.info, "Prefill warmup successful.\n")
+
         def ordering_function_max_bs(b):
             return (-b[0], b[1])
@@ -435,6 +442,9 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
         total_batch_seq = 0.001
         total_mem = 0
         available_mem = free_mem - self.mem_reserved
+        log_master(
+            logger.info, f"Decode batch size list:{[bsz[0] for bsz in buckets]}\n"
+        )
         for i, (batch_size, block_num) in enumerate(buckets):
             if batch_size > block_num:
                 continue
@@ -459,6 +469,8 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
                 total_mem += used_mem
                 total_batch_seq += batch_seq
+        log_master(logger.info, "Decode warmup successful.\n")
+
         log_master(
             logger.info,
             f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",


@@ -8,7 +8,7 @@ import torch
 def find_segments(
-    adapter_indices: Union[torch.Tensor, List[int]]
+    adapter_indices: Union[torch.Tensor, List[int]],
 ) -> Tuple[List[int], List[int]]:
     segments = [0]
     segment_indices = []
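
The only change here is Black's magic trailing comma on the multi-line signature; behavior is identical. For context, a sketch of what a function with this signature plausibly computes, run boundaries over consecutive identical adapter ids (an assumption; the real body is not shown in this hunk):

    from typing import List, Tuple, Union
    import torch

    def find_segments_sketch(
        adapter_indices: Union[torch.Tensor, List[int]],
    ) -> Tuple[List[int], List[int]]:
        # Boundaries of runs of equal adapter ids, plus each run's id.
        if isinstance(adapter_indices, torch.Tensor):
            adapter_indices = adapter_indices.tolist()
        segments, segment_indices = [0], []
        for i in range(1, len(adapter_indices)):
            if adapter_indices[i] != adapter_indices[i - 1]:
                segments.append(i)
                segment_indices.append(adapter_indices[i - 1])
        if adapter_indices:
            segments.append(len(adapter_indices))
            segment_indices.append(adapter_indices[-1])
        return segments, segment_indices

    # find_segments_sketch([0, 0, 1, 1, 1, 2]) -> ([0, 2, 5, 6], [0, 1, 2])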