Add mark_step into llama4

Signed-off-by: yuanwu <yuan.wu@intel.com>
yuanwu 2025-05-22 07:20:21 +03:00
parent ad41abd68c
commit 2e8d3e91ea

@@ -22,6 +22,7 @@ import torch.utils.checkpoint
 from torch import nn
 import torch.nn.functional as F
+import habana_frameworks.torch as htorch
 from transformers.cache_utils import Cache
 from transformers.activations import ACT2FN
 from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
@@ -567,6 +568,9 @@ class Llama4TextModel(nn.Module):
         )
         freqs_ci = self.rotary_emb(hidden_states, position_ids.view(bs, -1))
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
         for i, layer in enumerate(self.layers):
             hidden_states = layer(
@@ -582,6 +586,8 @@ class Llama4TextModel(nn.Module):
                 position_ids=position_ids,
                 hpu_attention_meta=hpu_attention_meta,
             )
+            if lazy_mode:
+                htorch.core.mark_step()
         hidden_states, _ = self.norm(hidden_states)
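
For background: on Gaudi/HPU, htorch.core.mark_step() is the point at which the lazy-mode backend stops accumulating ops and launches the recorded graph. Calling it once before the decoder loop and once after each layer cuts the forward pass into per-layer graphs rather than one large graph, which is the usual way to keep lazy-mode graphs small and reusable. A minimal standalone sketch of the same pattern, separate from the model code above (the run_layers helper and its arguments are illustrative, not part of this change):

    import habana_frameworks.torch as htorch

    def run_layers(layers, hidden_states):
        # mark_step() only matters in lazy mode, where ops are recorded
        # into a graph until the step boundary; other modes skip it.
        lazy_mode = htorch.utils.internal.is_lazy()
        if lazy_mode:
            htorch.core.mark_step()
        for layer in layers:
            hidden_states = layer(hidden_states)
            if lazy_mode:
                # Flush the graph recorded for this layer before the next one.
                htorch.core.mark_step()
        return hidden_states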