Add mark_step into qwen3

Signed-off-by: yuanwu <yuan.wu@intel.com>
This commit is contained in:
yuanwu 2025-05-22 07:17:49 +03:00
parent 3d20c79007
commit ad41abd68c

View File

@@ -14,6 +14,7 @@ from typing import Optional, Tuple, List
import torch
from torch import nn
import habana_frameworks.torch as htorch
from text_generation_server.layers.attention import (
    paged_attention,
    attention,
@@ -274,6 +275,10 @@ class Qwen3Model(nn.Module):
        )
        residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
        for i, decoder_layer in enumerate(self.layers):
            hidden_states = decoder_layer(
                hidden_states,
@@ -286,6 +291,8 @@ class Qwen3Model(nn.Module):
                seqlen,
                hpu_attention_meta,
            )
if lazy_mode:
htorch.core.mark_step()
        hidden_states, _ = self.norm(hidden_states)