add mark_step in vlm part

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Wang, Yi A 2025-06-10 19:02:14 -07:00
parent d68edc4a2f
commit f72b290020
3 changed files with 19 additions and 1 deletion
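
Note: each diff below applies the same Gaudi lazy-mode pattern to a vision tower (MllamaVisionEncoder, Qwen2_5VisionModel, Qwen2VisionModel): check htorch.utils.internal.is_lazy() once, then call htorch.core.mark_step() before and inside the transformer-block loop, so the accumulated lazy graph is flushed per block instead of spanning the whole vision encoder. A minimal, self-contained sketch of that pattern follows; TinyVisionEncoder and its nn.Linear blocks are hypothetical stand-ins for the real attention blocks, and it assumes a Gaudi/HPU host with habana_frameworks installed.

# Minimal sketch of the lazy-mode graph-break pattern used in the diffs below.
# Assumptions: a Gaudi/HPU host with habana_frameworks installed; TinyVisionEncoder
# and its nn.Linear blocks are hypothetical stand-ins for the real attention blocks.
import torch
import torch.nn as nn
import habana_frameworks.torch as htorch


class TinyVisionEncoder(nn.Module):
    def __init__(self, num_blocks: int = 4, hidden_size: int = 64):
        super().__init__()
        self.blocks = nn.ModuleList(
            [nn.Linear(hidden_size, hidden_size) for _ in range(num_blocks)]
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # mark_step() only has an effect in lazy mode, so gate it on is_lazy()
        lazy_mode = htorch.utils.internal.is_lazy()
        if lazy_mode:
            # flush whatever graph has accumulated before entering the block loop
            htorch.core.mark_step()
        for block in self.blocks:
            hidden_states = block(hidden_states)
            if lazy_mode:
                # break the graph after each block so it stays small per layer
                htorch.core.mark_step()
        return hidden_states


if __name__ == "__main__":
    model = TinyVisionEncoder().to("hpu")
    print(model(torch.randn(2, 64, device="hpu")).shape)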


@@ -38,6 +38,7 @@ from text_generation_server.models.custom_modeling.flash_llama_modeling import (
 )
 from habana_frameworks.torch.hpex.kernels import FusedSDPA
 from vllm_hpu_extension.utils import ModuleFusedSDPA
+import habana_frameworks.torch as htorch


 def _prepare_aspect_ratio_attention_mask(
@@ -320,6 +321,9 @@ class MllamaVisionEncoder(nn.Module):
         attention_mask: Optional[torch.Tensor] = None,
     ):
         encoder_states = [hidden_states]
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
         for encoder_layer in self.layers:
             layer_outputs = encoder_layer(
                 hidden_states,
@@ -328,6 +332,8 @@
             hidden_states = layer_outputs
             encoder_states.append(hidden_states)
+            if lazy_mode:
+                htorch.core.mark_step()

         return hidden_states, encoder_states


@@ -49,6 +49,7 @@ from habana_frameworks.torch.hpex.kernels import (
     RotaryPosEmbeddingMode,
     apply_rotary_pos_emb,
 )
+import habana_frameworks.torch as htorch

 # Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py
 from typing import Union
@@ -595,7 +596,7 @@ class Qwen2_5VisionModel(nn.Module):
             config=config,
             weights=weights,
         )
-        # import ipdb; ipdb.set_trace()
         self.temporal_patch_size = config.temporal_patch_size
         self.spatial_patch_size = config.spatial_patch_size
         self.in_channels = config.in_channels
@@ -745,6 +746,9 @@ class Qwen2_5VisionModel(nn.Module):
         max_seqlen = torch.max(cu_seqlens[1:] - cu_seqlens[:-1])

         # iterately apply the blocks to the hidden states
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
         for layer_num, block in enumerate(self.blocks):
             # NOTE: qwen2_5_vl.py has a concept of full attention blocks
             # that are applied at specific layers.
@@ -754,6 +758,8 @@
                 cu_seqlens_now = cu_window_seqlens
             hidden_states = block(hidden_states, cu_seqlens_now, cos, sin, max_seqlen)
+            if lazy_mode:
+                htorch.core.mark_step()

         # apply the final patch merger to the hidden states
         hidden_states = self.merger(hidden_states)


@@ -48,6 +48,7 @@ from habana_frameworks.torch.hpex.kernels import (
     RotaryPosEmbeddingMode,
     apply_rotary_pos_emb,
 )
+import habana_frameworks.torch as htorch


 class Qwen2VLAttention(nn.Module):
@@ -330,8 +331,13 @@ class Qwen2VisionModel(nn.Module):
         cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
         max_seqlen = torch.max(cu_seqlens[1:] - cu_seqlens[:-1])

         # iterately apply the blocks to the hidden states
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
         for block in self.blocks:
             hidden_states = block(hidden_states, cu_seqlens, cos, sin, max_seqlen)
+            if lazy_mode:
+                htorch.core.mark_step()

         # apply the final patch merger to the hidden states
         hidden_states = self.merger(hidden_states)
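
Whether these graph breaks actually fire depends on the execution mode. A quick check on a Gaudi host (assumption: PT_HPU_LAZY_MODE=1 selects lazy execution, 0 selects eager):

import habana_frameworks.torch as htorch

# True under lazy mode (e.g. PT_HPU_LAZY_MODE=1), where mark_step() flushes the
# accumulated graph; False under eager mode, where the added calls are skipped.
print(htorch.utils.internal.is_lazy())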