From 2e8d3e91ea67af97a08fab636e44d94576c78fa1 Mon Sep 17 00:00:00 2001
From: yuanwu
Date: Thu, 22 May 2025 07:20:21 +0300
Subject: [PATCH] Add mark_step into llama4

Signed-off-by: yuanwu
---
 .../models/custom_modeling/flash_llama4_modeling.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py
index 98994e48..11864c52 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py
@@ -22,6 +22,7 @@
 import torch.utils.checkpoint
 from torch import nn
 import torch.nn.functional as F
+import habana_frameworks.torch as htorch
 from transformers.cache_utils import Cache
 from transformers.activations import ACT2FN
 from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
@@ -567,6 +568,9 @@ class Llama4TextModel(nn.Module):
         )
         freqs_ci = self.rotary_emb(hidden_states, position_ids.view(bs, -1))
 
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
         for i, layer in enumerate(self.layers):
             hidden_states = layer(
                 hidden_states,
@@ -582,6 +586,8 @@ class Llama4TextModel(nn.Module):
                 position_ids=position_ids,
                 hpu_attention_meta=hpu_attention_meta,
             )
+            if lazy_mode:
+                htorch.core.mark_step()
 
         hidden_states, _ = self.norm(hidden_states)
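
Reviewer note (illustration only, not part of the patch): in Habana lazy mode, htorch.core.mark_step() flushes the ops accumulated so far and launches them as a graph, so the calls added here effectively bound each decoder layer into its own graph segment. Below is a minimal standalone sketch of the same pattern on a toy module; TinyStack and its Linear layers are hypothetical placeholders, not the Llama4 model, and running it requires a Gaudi device with habana_frameworks installed. Only htorch.utils.internal.is_lazy() and htorch.core.mark_step(), which appear in the patch itself, are assumed from the Habana bridge.

    import torch
    import habana_frameworks.torch as htorch

    class TinyStack(torch.nn.Module):
        """Toy stand-in for a decoder stack, used only to show the mark_step pattern."""

        def __init__(self, num_layers: int = 4, hidden: int = 16):
            super().__init__()
            self.layers = torch.nn.ModuleList(
                torch.nn.Linear(hidden, hidden) for _ in range(num_layers)
            )

        def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
            # Only emit mark_step() when the lazy-mode graph builder is active.
            lazy_mode = htorch.utils.internal.is_lazy()
            if lazy_mode:
                htorch.core.mark_step()  # flush ops accumulated before the loop
            for layer in self.layers:
                hidden_states = layer(hidden_states)
                if lazy_mode:
                    htorch.core.mark_step()  # cut the graph after each layer
            return hidden_states

The design choice mirrors the patch: guarding on is_lazy() keeps eager-mode runs unaffected, while the per-layer mark_step keeps lazy-mode graphs small and predictable instead of accumulating the whole decoder stack into one graph.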