From ad41abd68caa71230ae79e5c2196530acbe7d741 Mon Sep 17 00:00:00 2001
From: yuanwu
Date: Thu, 22 May 2025 07:17:49 +0300
Subject: [PATCH] Add mark_step into qwen3

Signed-off-by: yuanwu
---
 .../models/custom_modeling/flash_qwen3_modeling.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py
index 2c8662eb..66a17877 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py
@@ -14,6 +14,7 @@ from typing import Optional, Tuple, List
 
 import torch
 from torch import nn
+import habana_frameworks.torch as htorch
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
@@ -274,6 +275,10 @@ class Qwen3Model(nn.Module):
         )
 
         residual = None
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
+
         for i, decoder_layer in enumerate(self.layers):
             hidden_states = decoder_layer(
                 hidden_states,
@@ -286,6 +291,8 @@ class Qwen3Model(nn.Module):
                 seqlen,
                 hpu_attention_meta,
             )
+            if lazy_mode:
+                htorch.core.mark_step()
 
         hidden_states, _ = self.norm(hidden_states)
 