Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-06-19 15:52:08 +00:00)
add mark_step in vlm part
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Commit: f72b290020
Parent: d68edc4a2f
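The change applies one pattern to all three vision towers touched below (MllamaVisionEncoder, Qwen2_5VisionModel, Qwen2VisionModel): when the HPU lazy backend is active, call htorch.core.mark_step() once before the transformer-block loop and again after each block, so the accumulated lazy graph is cut at layer boundaries rather than spanning the whole vision stack. A minimal, self-contained sketch of that pattern follows; it assumes a Gaudi environment where habana_frameworks is installed, and ToyVisionEncoder with its sizes is illustrative, not code from the repository.

import torch
import torch.nn as nn
import habana_frameworks.torch as htorch


class ToyVisionEncoder(nn.Module):
    """Illustrative stand-in for the vision towers patched in this commit."""

    def __init__(self, hidden_size: int = 64, num_layers: int = 4):
        super().__init__()
        self.layers = nn.ModuleList(
            [nn.Linear(hidden_size, hidden_size) for _ in range(num_layers)]
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Only insert graph breaks when the HPU lazy backend is active.
        lazy_mode = htorch.utils.internal.is_lazy()
        if lazy_mode:
            htorch.core.mark_step()
        for layer in self.layers:
            hidden_states = layer(hidden_states)
            if lazy_mode:
                # Flush the accumulated lazy graph after every block so each
                # layer is launched as its own graph.
                htorch.core.mark_step()
        return hidden_states

The hunks below add exactly this guard-and-flush sequence inside the existing block loops of each model.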
@@ -38,6 +38,7 @@ from text_generation_server.models.custom_modeling.flash_llama_modeling import (
 )
 from habana_frameworks.torch.hpex.kernels import FusedSDPA
 from vllm_hpu_extension.utils import ModuleFusedSDPA
+import habana_frameworks.torch as htorch
 
 
 def _prepare_aspect_ratio_attention_mask(
@@ -320,6 +321,9 @@ class MllamaVisionEncoder(nn.Module):
         attention_mask: Optional[torch.Tensor] = None,
     ):
         encoder_states = [hidden_states]
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
         for encoder_layer in self.layers:
             layer_outputs = encoder_layer(
                 hidden_states,
@@ -328,6 +332,8 @@ class MllamaVisionEncoder(nn.Module):
 
             hidden_states = layer_outputs
             encoder_states.append(hidden_states)
+            if lazy_mode:
+                htorch.core.mark_step()
 
         return hidden_states, encoder_states
 
@@ -49,6 +49,7 @@ from habana_frameworks.torch.hpex.kernels import (
     RotaryPosEmbeddingMode,
     apply_rotary_pos_emb,
 )
+import habana_frameworks.torch as htorch
 
 # Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py
 from typing import Union
@@ -595,7 +596,7 @@ class Qwen2_5VisionModel(nn.Module):
             config=config,
             weights=weights,
         )
-        # import ipdb; ipdb.set_trace()
+
         self.temporal_patch_size = config.temporal_patch_size
         self.spatial_patch_size = config.spatial_patch_size
         self.in_channels = config.in_channels
@@ -745,6 +746,9 @@ class Qwen2_5VisionModel(nn.Module):
         max_seqlen = torch.max(cu_seqlens[1:] - cu_seqlens[:-1])
 
         # iterately apply the blocks to the hidden states
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
         for layer_num, block in enumerate(self.blocks):
             # NOTE: qwen2_5_vl.py has a concept of full attention blocks
             # that are applied at specific layers.
@@ -754,6 +758,8 @@ class Qwen2_5VisionModel(nn.Module):
                 cu_seqlens_now = cu_window_seqlens
 
             hidden_states = block(hidden_states, cu_seqlens_now, cos, sin, max_seqlen)
+            if lazy_mode:
+                htorch.core.mark_step()
 
         # apply the final patch merger to the hidden states
         hidden_states = self.merger(hidden_states)
@@ -48,6 +48,7 @@ from habana_frameworks.torch.hpex.kernels import (
     RotaryPosEmbeddingMode,
     apply_rotary_pos_emb,
 )
+import habana_frameworks.torch as htorch
 
 
 class Qwen2VLAttention(nn.Module):
@@ -330,8 +331,13 @@ class Qwen2VisionModel(nn.Module):
         cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
         max_seqlen = torch.max(cu_seqlens[1:] - cu_seqlens[:-1])
         # iterately apply the blocks to the hidden states
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
         for block in self.blocks:
             hidden_states = block(hidden_states, cu_seqlens, cos, sin, max_seqlen)
+            if lazy_mode:
+                htorch.core.mark_step()
 
         # apply the final patch merger to the hidden states
         hidden_states = self.merger(hidden_states)