diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py
index 98994e48..11864c52 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py
@@ -22,6 +22,7 @@ import torch.utils.checkpoint
 from torch import nn
 import torch.nn.functional as F
+import habana_frameworks.torch as htorch
 
 from transformers.cache_utils import Cache
 from transformers.activations import ACT2FN
 from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
@@ -567,6 +568,9 @@ class Llama4TextModel(nn.Module):
         )
 
         freqs_ci = self.rotary_emb(hidden_states, position_ids.view(bs, -1))
 
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
         for i, layer in enumerate(self.layers):
             hidden_states = layer(
@@ -582,6 +586,8 @@ class Llama4TextModel(nn.Module):
                 position_ids=position_ids,
                 hpu_attention_meta=hpu_attention_meta,
             )
+            if lazy_mode:
+                htorch.core.mark_step()
 
         hidden_states, _ = self.norm(hidden_states)
 
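
The diff inserts `htorch.core.mark_step()` calls around the decoder-layer loop when the Habana backend runs in lazy mode: once before entering the loop, and once after each layer call. In HPU lazy mode, ops accumulate into a graph until a `mark_step()` flushes them; breaking the graph at layer boundaries keeps each iteration as a small, repeated graph rather than one monolithic graph for the whole stack. Below is a minimal standalone sketch of that pattern. It is not the model code from this file: `run_decoder_stack`, `layers`, and the single-argument layer call are simplified placeholders; only `htorch.utils.internal.is_lazy()` and `htorch.core.mark_step()` come from the diff itself.

```python
import torch
from torch import nn
import habana_frameworks.torch as htorch

# Detect HPU lazy mode once; mark_step() only has an effect there.
lazy_mode = htorch.utils.internal.is_lazy()


def run_decoder_stack(
    layers: nn.ModuleList, hidden_states: torch.Tensor
) -> torch.Tensor:
    # Flush the lazily accumulated ops once before the layer loop...
    if lazy_mode:
        htorch.core.mark_step()
    for layer in layers:
        hidden_states = layer(hidden_states)
        # ...and once after each layer, so every iteration is compiled
        # and executed as its own small HPU graph instead of the whole
        # stack building one monolithic graph.
        if lazy_mode:
            htorch.core.mark_step()
    return hidden_states
```

Since every loop iteration emits the same ops on same-shaped tensors, the per-layer graph can be compiled once and reused across layers and forward passes, which is the usual motivation for placing the graph breaks at this granularity.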