diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py
index 421a0a65..ea3129f9 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py
@@ -38,6 +38,7 @@ from text_generation_server.models.custom_modeling.flash_llama_modeling import (
 )
 from habana_frameworks.torch.hpex.kernels import FusedSDPA
 from vllm_hpu_extension.utils import ModuleFusedSDPA
+import habana_frameworks.torch as htorch
 
 
 def _prepare_aspect_ratio_attention_mask(
@@ -320,6 +321,9 @@ class MllamaVisionEncoder(nn.Module):
         attention_mask: Optional[torch.Tensor] = None,
     ):
         encoder_states = [hidden_states]
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
         for encoder_layer in self.layers:
             layer_outputs = encoder_layer(
                 hidden_states,
@@ -328,6 +332,8 @@ class MllamaVisionEncoder(nn.Module):
             hidden_states = layer_outputs
             encoder_states.append(hidden_states)
 
+        if lazy_mode:
+            htorch.core.mark_step()
         return hidden_states, encoder_states
 
 
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py
index 7cd651db..b0efbe1c 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py
@@ -49,6 +49,7 @@ from habana_frameworks.torch.hpex.kernels import (
     RotaryPosEmbeddingMode,
     apply_rotary_pos_emb,
 )
+import habana_frameworks.torch as htorch
 
 # Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py
 from typing import Union
@@ -595,7 +596,7 @@ class Qwen2_5VisionModel(nn.Module):
             config=config,
             weights=weights,
         )
-        # import ipdb; ipdb.set_trace()
+
         self.temporal_patch_size = config.temporal_patch_size
         self.spatial_patch_size = config.spatial_patch_size
         self.in_channels = config.in_channels
@@ -745,6 +746,9 @@ class Qwen2_5VisionModel(nn.Module):
         max_seqlen = torch.max(cu_seqlens[1:] - cu_seqlens[:-1])
 
         # iterately apply the blocks to the hidden states
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
         for layer_num, block in enumerate(self.blocks):
             # NOTE: qwen2_5_vl.py has a concept of full attention blocks
             # that are applied at specific layers.
@@ -754,6 +758,8 @@ class Qwen2_5VisionModel(nn.Module):
                 cu_seqlens_now = cu_window_seqlens
 
             hidden_states = block(hidden_states, cu_seqlens_now, cos, sin, max_seqlen)
+        if lazy_mode:
+            htorch.core.mark_step()
 
         # apply the final patch merger to the hidden states
         hidden_states = self.merger(hidden_states)
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_vl.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_vl.py
index d9c07f7d..652def22 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_vl.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_vl.py
@@ -48,6 +48,7 @@ from habana_frameworks.torch.hpex.kernels import (
     RotaryPosEmbeddingMode,
     apply_rotary_pos_emb,
 )
+import habana_frameworks.torch as htorch
 
 
 class Qwen2VLAttention(nn.Module):
@@ -330,8 +331,13 @@ class Qwen2VisionModel(nn.Module):
         cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
         max_seqlen = torch.max(cu_seqlens[1:] - cu_seqlens[:-1])
         # iterately apply the blocks to the hidden states
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
         for block in self.blocks:
             hidden_states = block(hidden_states, cu_seqlens, cos, sin, max_seqlen)
+        if lazy_mode:
+            htorch.core.mark_step()
 
         # apply the final patch merger to the hidden states
         hidden_states = self.merger(hidden_states)
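For reference, a minimal sketch of the lazy-mode graph-break pattern the hunks above apply: query `htorch.utils.internal.is_lazy()` once, then bracket the layer loop with `htorch.core.mark_step()` so the accumulated lazy graph is flushed at stable boundaries. `DummyEncoder`, its layer count, and the hidden size are hypothetical stand-ins rather than the actual TGI modules, and the snippet assumes a Gaudi environment with `habana_frameworks.torch` installed.

```python
import torch
import torch.nn as nn
import habana_frameworks.torch as htorch  # requires a Habana Gaudi software stack


class DummyEncoder(nn.Module):
    """Hypothetical stand-in for a vision encoder that runs a stack of layers."""

    def __init__(self, num_layers: int = 4, hidden: int = 64):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(hidden, hidden) for _ in range(num_layers))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Only insert explicit graph breaks when the HPU bridge runs in lazy mode;
        # outside lazy mode the mark_step() calls are unnecessary.
        lazy_mode = htorch.utils.internal.is_lazy()
        if lazy_mode:
            htorch.core.mark_step()
        for layer in self.layers:
            hidden_states = layer(hidden_states)
        if lazy_mode:
            htorch.core.mark_step()
        return hidden_states


# Example usage on an HPU device (hypothetical shapes):
#   out = DummyEncoder()(torch.randn(2, 64, device="hpu"))
```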