[gaudi] Refine logging for Gaudi warmup (#3222)

* Refine logging for Gaudi warmup

* Make style

* Make style 2

* Flash causal LM case

* Add log_master & VLM cases

* Black
regisss 2025-06-18 04:34:00 -06:00 committed by GitHub
parent b4d17f18ff
commit f13e28c98d
10 changed files with 37 additions and 7 deletions
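
All of the new messages below go through log_master, so in multi-card Gaudi deployments only the master rank prints them rather than one copy per process. A minimal sketch of that helper as it exists in text_generation_server's logging utils (treat the exact module contents as an assumption, not a quote of the file):

    import os
    from loguru import logger

    RANK = int(os.getenv("RANK", "0"))  # torchrun-style rank env var

    def log_master(log, msg: str):
        # Only the master process (rank 0) emits the message.
        if RANK == 0:
            log(msg)

    # Usage, matching the calls added in this commit:
    log_master(logger.info, "Prefill warmup successful.\n")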

View File

@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" PyTorch Llava-NeXT model."""
+"""PyTorch Llava-NeXT model."""
 from typing import List, Optional, Tuple

View File

@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" PyTorch Idefics2 model."""
+"""PyTorch Idefics2 model."""
 from typing import List, Optional, Tuple

View File

@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" PyTorch Idefics3 model."""
+"""PyTorch Idefics3 model."""
 from typing import List, Optional, Tuple

View File

@@ -17,7 +17,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" Idefics model configuration"""
+"""Idefics model configuration"""
 import copy
 from transformers import PretrainedConfig

View File

@@ -17,7 +17,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" PyTorch Idefics model."""
+"""PyTorch Idefics model."""
 from typing import List, Optional, Tuple, Union
 import torch

View File

@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" PyTorch IdeficsVision model: a copy of CLIPVisionModel using a simpler config object"""
+"""PyTorch IdeficsVision model: a copy of CLIPVisionModel using a simpler config object"""
 from dataclasses import dataclass

View File

@@ -1721,6 +1721,7 @@ class FlashCausalLM(Model):
             f"{dim}:{seq_len} "
             f"bypass:{bypass} "
             f"free_mem:{free_mem}"
+            ", this may take a while..."
         )
         log_master(logger.info, msg)

@@ -1772,6 +1773,11 @@ class FlashCausalLM(Model):
         total_batch_seq = 0.001
         total_mem = 0
         available_mem = prompt_available_memory
+        msg = (
+            f"Prefill batch size list:{[bsz[0] for bsz in buckets]}\n"
+            f"Prefill sequence length list:{[seq[1] for seq in buckets]}\n"
+        )
+        log_master(logger.info, msg)
         for i, (batch_size, seq_len) in enumerate(buckets):
             if batch_size * seq_len > self.max_batch_prefill_tokens:
                 continue

@@ -1798,6 +1804,8 @@ class FlashCausalLM(Model):
             total_mem += used_mem
             total_batch_seq += batch_seq
+        log_master(logger.info, "Prefill warmup successful.\n")
+
         def ordering_function_max_bs(b):
             return (-b[0], b[1])

@@ -1809,6 +1817,9 @@ class FlashCausalLM(Model):
         total_batch_seq = 0.001
         total_mem = 0
         available_mem = free_mem - self.mem_reserved
+        log_master(
+            logger.info, f"Decode batch size list:{[bsz[0] for bsz in buckets]}\n"
+        )
         for i, (batch_size, block_num) in enumerate(buckets):
             if batch_size > block_num:
                 continue

@@ -1833,6 +1844,8 @@ class FlashCausalLM(Model):
             total_mem += used_mem
             total_batch_seq += batch_seq
+        log_master(logger.info, "Decode warmup successful.\n")
+
         log_master(
             logger.info,
             f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",

View File

@@ -822,6 +822,9 @@ class FlashVlmCausalLM(FlashCausalLM):
         total_batch_seq = 0.001
         total_mem = 0
         available_mem = decode_available_memory
+        log_master(
+            logger.info, f"Decode batch size list:{[bsz[0] for bsz in buckets]}\n"
+        )
         for i, (batch_size, block_num) in enumerate(buckets):
             if batch_size > block_num:
                 continue

@@ -847,6 +850,8 @@ class FlashVlmCausalLM(FlashCausalLM):
             total_mem += used_mem
             total_batch_seq += batch_seq
+        log_master(logger.info, "Decode warmup successful.\n")
+
         log_master(
             logger.info,
             f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",

View File

@@ -398,6 +398,11 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
         total_batch_seq = 0.001
         total_mem = 0
         available_mem = prompt_available_memory
+        msg = (
+            f"Prefill batch size list:{[bsz[0] for bsz in buckets]}\n"
+            f"Prefill sequence length list:{[seq[1] for seq in buckets]}\n"
+        )
+        log_master(logger.info, msg)
         for i, (batch_size, seq_len) in enumerate(buckets):
             if batch_size * seq_len > self.max_batch_prefill_tokens:
                 continue

@@ -424,6 +429,8 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
             total_mem += used_mem
             total_batch_seq += batch_seq
+        log_master(logger.info, "Prefill warmup successful.\n")
+
         def ordering_function_max_bs(b):
             return (-b[0], b[1])

@@ -435,6 +442,9 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
         total_batch_seq = 0.001
         total_mem = 0
         available_mem = free_mem - self.mem_reserved
+        log_master(
+            logger.info, f"Decode batch size list:{[bsz[0] for bsz in buckets]}\n"
+        )
         for i, (batch_size, block_num) in enumerate(buckets):
             if batch_size > block_num:
                 continue

@@ -459,6 +469,8 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
             total_mem += used_mem
             total_batch_seq += batch_seq
+        log_master(logger.info, "Decode warmup successful.\n")
+
         log_master(
             logger.info,
             f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",

View File

@@ -8,7 +8,7 @@ import torch
 def find_segments(
-    adapter_indices: Union[torch.Tensor, List[int]]
+    adapter_indices: Union[torch.Tensor, List[int]],
 ) -> Tuple[List[int], List[int]]:
     segments = [0]
     segment_indices = []
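
Only Black's magic trailing comma changed here; find_segments itself collapses runs of equal adapter indices into segment boundaries plus the adapter id of each run. A hedged usage sketch (the import path and the expected outputs are my reading of the surrounding code, not confirmed by the diff):

    import torch
    from text_generation_server.utils.segments import find_segments  # path assumed

    # [0, 0, 1, 1, 2] has three runs, so three segments.
    segments, segment_indices = find_segments(torch.tensor([0, 0, 1, 1, 2]))
    print(segments)         # expected: [0, 2, 4, 5]  (run boundaries, end-exclusive)
    print(segment_indices)  # expected: [0, 1, 2]     (adapter id per run)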