diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py index 2c8662eb..66a17877 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py @@ -14,6 +14,7 @@ from typing import Optional, Tuple, List import torch from torch import nn +import habana_frameworks.torch as htorch from text_generation_server.layers.attention import ( paged_attention, attention, @@ -274,6 +275,10 @@ class Qwen3Model(nn.Module): ) residual = None + lazy_mode = htorch.utils.internal.is_lazy() + if lazy_mode: + htorch.core.mark_step() + for i, decoder_layer in enumerate(self.layers): hidden_states = decoder_layer( hidden_states, @@ -286,6 +291,8 @@ class Qwen3Model(nn.Module): seqlen, hpu_attention_meta, ) + if lazy_mode: + htorch.core.mark_step() hidden_states, _ = self.norm(hidden_states)