# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Idefics2 model."""

from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
import math

from transformers.activations import ACT2FN
from transformers.image_processing_utils import select_best_resolution
from text_generation_server.models.custom_modeling.vlm import (
    load_text_model,
    load_vision_model,
)
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask

from text_generation_server.layers import (
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
    TensorParallelRowLinear,
)


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Duplicate each key/value head `n_rep` times along the head axis.

    Gives the same result as torch.repeat_interleave(x, dim=1, repeats=n_rep): an input of shape
    (batch, num_key_value_heads, seqlen, head_dim) comes back as
    (batch, num_key_value_heads * n_rep, seqlen, head_dim). When `n_rep` is 1 the input tensor itself is returned.
    """
    # Unpacking first keeps the 4-D shape requirement even on the `n_rep == 1` shortcut.
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis right after the head axis, broadcast it, then fold it into the heads.
    tiled = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
    return tiled.reshape(bsz, kv_heads * n_rep, seq_len, dim)


class Idefics2VisionEmbeddings(nn.Module):
    """
    Patch and position embeddings of the vision tower. This is a modified version of
    `siglip.modeling_siglip.SiglipVisionEmbeddings` that accepts images of variable resolution.

    The modifications are adapted from [Patch n' Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution](https://arxiv.org/abs/2307.06304)
    which allows treating images in their native aspect ratio, without resizing them all to one fixed size.
    The starting point is the original pre-trained SigLIP model (which uses fixed-size square images),
    adapted by training on images of variable resolutions.
    """

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # Non-overlapping patch projection; the checkpoint tensors replace the freshly initialised ones.
        patch_projection = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )
        patch_projection.weight = nn.Parameter(
            weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False
        )
        patch_projection.bias = nn.Parameter(
            weights.get_tensor(f"{prefix}.patch_embedding.bias"), requires_grad=False
        )
        self.patch_embedding = patch_projection

        # The position table is indexed on a `num_patches_per_side` x `num_patches_per_side` grid.
        self.num_patches_per_side = self.image_size // self.patch_size
        self.num_patches = self.num_patches_per_side**2
        self.num_positions = self.num_patches
        self.position_embedding = TensorParallelEmbedding(
            prefix=f"{prefix}.position_embedding", weights=weights
        )

    def forward(
        self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor
    ) -> torch.Tensor:
        """
        Embed the image patches and add their position embeddings.

        Args:
            pixel_values: 4-D image batch; the last two axes are height and width.
            patch_attention_mask: one boolean 2-D (rows x columns) patch mask per image. The valid
                patch count per axis is read from the mask's first column and first row.

        Returns:
            Patch embeddings with position embeddings added, one row per patch.
        """
        batch_size, _, max_im_h, max_im_w = pixel_values.shape

        # Conv output (batch, embed_dim, h, w) -> (batch, h * w, embed_dim)
        embeddings = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)

        grid_side = self.num_patches_per_side
        max_nb_patches_h = max_im_h // self.patch_size
        max_nb_patches_w = max_im_w // self.patch_size
        # Bucket edges splitting [0, 1) into `grid_side` equal intervals.
        boundaries = torch.arange(1 / grid_side, 1.0, 1 / grid_side)
        # Masked-out patches keep position id 0.
        position_ids = torch.full(
            size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0
        )

        for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
            nb_patches_h = p_attn_mask[:, 0].sum()
            nb_patches_w = p_attn_mask[0].sum()

            # Fractional coordinate of each valid row / column, mapped onto the position grid.
            row_buckets = torch.bucketize(
                torch.arange(0, 1 - 1e-6, 1 / nb_patches_h), boundaries, right=True
            )
            col_buckets = torch.bucketize(
                torch.arange(0, 1 - 1e-6, 1 / nb_patches_w), boundaries, right=True
            )

            pos_ids = (row_buckets[:, None] * grid_side + col_buckets).flatten()
            # `position_ids` lives on the CPU, so the selection mask is moved there as well.
            position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids

        position_ids = position_ids.to(self.position_embedding.weight.device)
        return embeddings + self.position_embedding(position_ids)


class Idefics2VisionAttention(nn.Module):
    """Non-causal multi-head self-attention of the vision encoder, with sharded q/k/v and output projections."""

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_size = self.embed_dim // self.num_heads
        if self.head_size * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_size**-0.5
        self.dropout = config.attention_dropout

        # From here on the head count and width describe this shard only.
        world_size = weights.process_group.size()
        self.num_heads = self.num_heads // world_size
        self.embed_dim = self.embed_dim // world_size

        # q, k and v are loaded as one fused column-parallel projection.
        self.qkv = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
            dim=0,
            weights=weights,
            bias=True,
        )
        self.out_proj = TensorParallelRowLinear.load(
            config=config, prefix=f"{prefix}.out_proj", weights=weights, bias=True
        )
        self.is_causal = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Run self-attention over `hidden_states` (batch, seq_len, features).

        Args:
            hidden_states: input sequence.
            attention_mask: optional additive mask of shape (batch, 1, seq_len, seq_len).

        Returns:
            The output of `out_proj` applied to the merged attention heads.

        Raises:
            ValueError: if the attention weights, the mask or the attention output have an unexpected shape.
        """
        batch_size, q_len, _ = hidden_states.size()
        shard_width = self.head_size * self.num_heads

        def split_heads(projected: torch.Tensor) -> torch.Tensor:
            # (batch, seq, heads * head_size) -> (batch, heads, seq, head_size)
            return projected.view(
                batch_size, q_len, self.num_heads, self.head_size
            ).transpose(1, 2)

        fused = self.qkv(hidden_states)
        query_states, key_states, value_states = (
            split_heads(part)
            for part in fused.split([shard_width, shard_width, shard_width], dim=2)
        )

        k_v_seq_len = key_states.shape[-2]
        scores = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale

        scores_shape = (batch_size, self.num_heads, q_len, k_v_seq_len)
        if scores.size() != scores_shape:
            raise ValueError(
                f"Attention weights should be of size {scores_shape}, but is"
                f" {scores.size()}"
            )

        if attention_mask is not None:
            mask_shape = (batch_size, 1, q_len, k_v_seq_len)
            if attention_mask.size() != mask_shape:
                raise ValueError(
                    f"Attention mask should be of size {mask_shape}, but is {attention_mask.size()}"
                )
            scores = scores + attention_mask

        # Softmax is computed in fp32, then cast back to the working dtype.
        probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32)
        probs = probs.to(query_states.dtype)
        probs = nn.functional.dropout(probs, p=self.dropout, training=self.training)
        attn_output = torch.matmul(probs, value_states)

        output_shape = (batch_size, self.num_heads, q_len, self.head_size)
        if attn_output.size() != output_shape:
            raise ValueError(
                f"`attn_output` should be of size {output_shape}, but is"
                f" {attn_output.size()}"
            )

        # (batch, heads, seq, head_size) -> (batch, seq, heads * head_size)
        merged = attn_output.transpose(1, 2).contiguous()
        merged = merged.reshape(batch_size, q_len, self.embed_dim)
        return self.out_proj(merged)


class Idefics2VisionMLP(nn.Module):
    """Feed-forward block of a vision encoder layer: `fc1`, activation, `fc2`."""

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        # Column-parallel projection followed by a row-parallel one, both with bias.
        self.fc1 = TensorParallelColumnLinear.load(
            prefix=f"{prefix}.fc1", config=config, weights=weights, bias=True
        )
        self.fc2 = TensorParallelRowLinear.load(
            prefix=f"{prefix}.fc2", config=config, weights=weights, bias=True
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return `fc2(activation_fn(fc1(hidden_states)))`."""
        projected = self.fc1(hidden_states)
        activated = self.activation_fn(projected)
        return self.fc2(activated)


class Idefics2EncoderLayer(nn.Module):
    """One pre-norm vision encoder layer: self-attention and MLP, each wrapped in a residual connection."""

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = Idefics2VisionAttention(
            prefix=f"{prefix}.self_attn", config=config, weights=weights
        )
        self.layer_norm1 = nn.LayerNorm.load(
            prefix=f"{prefix}.layer_norm1", eps=config.layer_norm_eps, weights=weights
        )
        self.layer_norm2 = nn.LayerNorm.load(
            prefix=f"{prefix}.layer_norm2", eps=config.layer_norm_eps, weights=weights
        )
        self.mlp = Idefics2VisionMLP(
            prefix=f"{prefix}.mlp", config=config, weights=weights
        )

    # Copied from transformers.models.siglip.modeling_siglip.SiglipEncoderLayer.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        Apply the attention and MLP sub-blocks to `hidden_states`.

        `attention_mask` is handed unchanged to the self-attention module.
        """
        # Attention sub-block: normalise, attend, add the residual.
        attended = self.self_attn(
            hidden_states=self.layer_norm1(hidden_states),
            attention_mask=attention_mask,
        )
        after_attention = hidden_states + attended

        # MLP sub-block: normalise, project, add the residual.
        projected = self.mlp(self.layer_norm2(after_attention))
        return after_attention + projected


class Idefics2Encoder(nn.Module):
    """Stack of `config.num_hidden_layers` `Idefics2EncoderLayer` modules applied in order."""

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.config = config
        encoder_layers = [
            Idefics2EncoderLayer(
                prefix=f"{prefix}.layers.{layer_idx}", config=config, weights=weights
            )
            for layer_idx in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(encoder_layers)

    # Ignore copy
    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        """Feed `inputs_embeds` through every layer, passing the same `attention_mask` to each."""
        output = inputs_embeds
        for layer in self.layers:
            output = layer(output, attention_mask)
        return output


class Idefics2VisionTransformer(nn.Module):
    """Vision tower: patch/position embeddings, the encoder stack and a final layer norm."""

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.config = config
        self.embeddings = Idefics2VisionEmbeddings(
            prefix=f"{prefix}.embeddings", config=config, weights=weights
        )
        self.encoder = Idefics2Encoder(
            prefix=f"{prefix}.encoder", config=config, weights=weights
        )
        self.post_layernorm = nn.LayerNorm.load(
            prefix=f"{prefix}.post_layernorm",
            weights=weights,
            eps=config.layer_norm_eps,
        )

    def forward(
        self,
        pixel_values,
        patch_attention_mask: Optional[torch.BoolTensor] = None,
    ):
        """
        Encode a batch of images.

        Args:
            pixel_values: 4-D image batch; axes 2 and 3 are divided by `config.patch_size` to get
                the patch grid.
            patch_attention_mask: optional boolean mask with one entry per patch,
                (batch, patches_h, patches_w). When omitted, every patch is treated as valid.

        Returns:
            The encoder output after `post_layernorm`.
        """
        batch_size = pixel_values.size(0)
        if patch_attention_mask is None:
            patch_size = self.config.patch_size
            # Allocate the all-valid mask directly as bool on the input's device, rather than
            # creating a float CPU tensor and converting/copying it afterwards.
            patch_attention_mask = torch.ones(
                (
                    batch_size,
                    pixel_values.size(2) // patch_size,
                    pixel_values.size(3) // patch_size,
                ),
                dtype=torch.bool,
                device=pixel_values.device,
            )

        hidden_states = self.embeddings(
            pixel_values=pixel_values, patch_attention_mask=patch_attention_mask
        )

        patch_attention_mask = patch_attention_mask.view(batch_size, -1)
        # When every patch is valid, skip building the 4D mask: passing None to the encoder is
        # equivalent to attending to the whole sequence.
        if not torch.any(~patch_attention_mask):
            patch_attention_mask = None
        else:
            patch_attention_mask = _prepare_4d_attention_mask(
                patch_attention_mask, hidden_states.dtype
            )

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            attention_mask=patch_attention_mask,
        )

        return self.post_layernorm(encoder_outputs)


class Idefics2MLP(nn.Module):
    """Gated MLP: `down_proj(act(gate) * up)`, with the gate and up projections loaded as one fused layer."""

    def __init__(self, prefix, config, weights):
        super().__init__()
        act = config.text_config.hidden_act
        if "gelu" in act:
            # GELU variants go through torch's functional gelu; the "fast" / "pytorch_tanh"
            # names select the tanh approximation.
            approximate = (
                "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
            )
            self.act = lambda x: torch.nn.functional.gelu(x, approximate=approximate)
        else:
            self.act = ACT2FN[act]

        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
            weights=weights,
            dim=0,
            bias=False,
        )
        self.down_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.down_proj",
            weights=weights,
            bias=False,
        )

    def forward(self, hidden_states):
        """Apply the gated MLP over the last axis; all leading axes of the input are restored on the output."""
        leading_shape = hidden_states.shape[:-1]
        fused = self.gate_up_proj(hidden_states)
        half_width = fused.shape[-1] // 2
        # Flatten the leading axes and separate the two halves of the fused projection.
        fused = fused.view(-1, 2, half_width)
        gate = fused[:, 0]
        up = fused[:, 1]
        projected = self.down_proj(self.act(gate) * up)
        return projected.view(*leading_shape, -1)


class Idefics2RMSNorm(nn.Module):
    def __init__(self, prefix, weights, eps):
        """
        Idefics2RMSNorm is equivalent to T5LayerNorm: the input is scaled by the reciprocal root
        of its mean square over the last axis, without mean subtraction or bias.
        """
        super().__init__()
        scale = weights.get_tensor(f"{prefix}.weight")
        self.weight = nn.Parameter(scale, requires_grad=False)
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Normalise `hidden_states` in float32, cast back to the input dtype and multiply by `weight`."""
        original_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        normalized = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(original_dtype)


class Idefics2PerceiverAttention(nn.Module):
    """
    Non-causal attention of the perceiver resampler: queries come from the latents, keys and values
    from the context concatenated with the latents. Key/value heads are repeated to match the
    number of query heads.
    """

    def __init__(self, prefix, config, weights):
        super().__init__()

        self.layer_idx = None
        self.hidden_size = config.text_config.hidden_size
        self.num_heads = config.perceiver_config.resampler_n_heads
        self.head_size = config.perceiver_config.resampler_head_dim
        self.num_key_value_heads = config.perceiver_config.num_key_value_heads
        # Computed from the unsharded head counts, before both are divided by the world size below.
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.attention_dropout = config.perceiver_config.attention_dropout
        self.num_heads = self.num_heads // weights.process_group.size()
        self.num_key_value_heads = (
            self.num_key_value_heads // weights.process_group.size()
        )

        self.q_proj = TensorParallelColumnLinear.load(
            config,
            prefix=f"{prefix}.q_proj",
            weights=weights,
            bias=False,
        )
        # k and v are loaded as one fused column-parallel projection.
        self.kv = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.k_proj", f"{prefix}.v_proj"],
            dim=0,
            weights=weights,
            bias=False,
        )
        self.o_proj = TensorParallelRowLinear.load(
            config=config, prefix=f"{prefix}.o_proj", weights=weights, bias=False
        )

        self.is_causal = False

    def forward(
        self,
        latents: torch.Tensor,
        context: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Attend from `latents` to the concatenation of `context` and `latents`.

        Args:
            latents: (batch, n_latents, features) query sequence.
            context: sequence concatenated in front of the latents along the sequence axis to form
                keys and values.
            attention_mask: optional additive mask of shape
                (batch, 1, n_latents, context_len + n_latents).

        Returns:
            A single tensor: the output of `o_proj` applied to the merged attention heads.

        Raises:
            ValueError: if the attention weights, the mask or the attention output have an unexpected shape.
        """
        bsz, q_len, _ = latents.size()
        kv_seq_len = q_len + context.size()[1]

        # Keys/values see the context followed by the latents themselves.
        hidden_states = torch.concat([context, latents], dim=-2)
        query_states = self.q_proj(latents)
        kv = self.kv(hidden_states)
        key_states, value_states = kv.split(
            [
                self.head_size * self.num_key_value_heads,
                self.head_size * self.num_key_value_heads,
            ],
            dim=2,
        )

        # (batch, seq, heads * head_size) -> (batch, heads, seq, head_size)
        query_states = query_states.view(
            bsz, q_len, self.num_heads, self.head_size
        ).transpose(1, 2)
        key_states = key_states.view(
            bsz, kv_seq_len, self.num_key_value_heads, self.head_size
        ).transpose(1, 2)
        value_states = value_states.view(
            bsz, kv_seq_len, self.num_key_value_heads, self.head_size
        ).transpose(1, 2)

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(
            query_states, key_states.transpose(2, 3)
        ) / math.sqrt(self.head_size)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )

            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(
            attn_weights, dim=-1, dtype=torch.float32
        ).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_size):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_size)}, but is"
                f" {attn_output.size()}"
            )

        # (batch, heads, seq, head_size) -> (batch, seq, heads * head_size)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_size)

        return self.o_proj(attn_output)


class Idefics2PerceiverLayer(nn.Module):
    """One perceiver resampler block: latents attend to the context, then pass through a gated MLP."""

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.hidden_size = config.text_config.hidden_size
        self.n_latents = config.perceiver_config.resampler_n_latents
        self.depth = config.perceiver_config.resampler_depth
        self.rms_norm_eps = config.text_config.rms_norm_eps

        # Latents and context each get their own RMS norm before attention.
        self.input_latents_norm = Idefics2RMSNorm(
            prefix=f"{prefix}.input_latents_norm",
            weights=weights,
            eps=self.rms_norm_eps,
        )
        self.input_context_norm = Idefics2RMSNorm(
            prefix=f"{prefix}.input_context_norm",
            weights=weights,
            eps=self.rms_norm_eps,
        )
        self.self_attn = Idefics2PerceiverAttention(
            prefix=f"{prefix}.self_attn", config=config, weights=weights
        )
        self.post_attention_layernorm = Idefics2RMSNorm(
            prefix=f"{prefix}.post_attention_layernorm",
            weights=weights,
            eps=self.rms_norm_eps,
        )
        self.mlp = Idefics2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights)

    def forward(
        self,
        latents: torch.Tensor,
        context: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        """
        Args:
            latents (`torch.FloatTensor`): latent queries of shape `(batch, n_latents, embed_dim)`
            context (`torch.FloatTensor`): sequence the latents attend to, of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): additive mask forwarded unchanged to the
                attention module, which expects shape `(batch, 1, n_latents, seq_len + n_latents)`.

        Returns:
            The updated latents, same shape as the input latents.
        """
        # Attention sub-block; only the latents carry a residual connection.
        attended = self.self_attn(
            latents=self.input_latents_norm(latents),
            context=self.input_context_norm(context),
            attention_mask=attention_mask,
        )
        after_attention = latents + attended

        # MLP sub-block with its own residual connection.
        projected = self.mlp(self.post_attention_layernorm(after_attention))
        return after_attention + projected


class Idefics2PerceiverResampler(nn.Module):
    """
    Compresses a context sequence into `n_latents` vectors by running a stack of
    `Idefics2PerceiverLayer` blocks over a set of latents loaded from the checkpoint.
    """

    def __init__(self, prefix, config, weights) -> None:
        super().__init__()
        text_config = config.text_config
        perceiver_config = config.perceiver_config
        self.hidden_size = text_config.hidden_size
        self.hidden_act = perceiver_config.hidden_act
        self.n_latents = perceiver_config.resampler_n_latents
        self.depth = perceiver_config.resampler_depth
        self.rms_norm_eps = text_config.rms_norm_eps

        # Latent queries for the perceiver, read as a plain tensor from the checkpoint.
        self.latents = weights.get_tensor(f"{prefix}.latents")

        # Transformer blocks
        blocks = [
            Idefics2PerceiverLayer(
                prefix=f"{prefix}.layers.{idx}", config=config, weights=weights
            )
            for idx in range(self.depth)
        ]
        self.layers = nn.ModuleList(blocks)
        self.norm = Idefics2RMSNorm(
            prefix=f"{prefix}.norm",
            weights=weights,
            eps=text_config.rms_norm_eps,
        )

    def forward(
        self,
        context: torch.Tensor,
        attention_mask,
    ) -> torch.Tensor:
        """
        Resample `context` into the latent sequence.

        Args:
            context: batch of context sequences; axis 0 is the batch.
            attention_mask: 2-D mask over the context positions, one row per batch entry.

        Returns:
            The latents after all perceiver layers and the final norm.
        """
        batch = context.shape[0]
        # seq embed -> bsz seq embed
        latents = self.latents.unsqueeze(0).expand((batch, *self.latents.size()))

        # The latents are appended to the keys/values and are always attendable, so the context
        # mask is extended with ones before being turned into a 4D additive mask.
        latent_mask = torch.ones(
            (attention_mask.size(0), latents.size(1)),
            dtype=attention_mask.dtype,
            device=attention_mask.device,
        )
        full_mask = torch.cat([attention_mask, latent_mask], dim=-1)
        full_mask = _prepare_4d_attention_mask(
            full_mask, latents.dtype, tgt_len=self.n_latents
        )

        compressed_context = latents
        for layer in self.layers:
            compressed_context = layer(
                compressed_context,
                context,
                attention_mask=full_mask,
            )
        return self.norm(compressed_context)


class Idefics2Connector(nn.Module):
    """Bridge applied to the vision output: a modality projection MLP followed by the perceiver resampler."""

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.modality_projection = Idefics2MLP(
            prefix=f"{prefix}.modality_projection", config=config, weights=weights
        )
        self.perceiver_resampler = Idefics2PerceiverResampler(
            prefix=f"{prefix}.perceiver_resampler", config=config, weights=weights
        )

    def forward(self, image_hidden_states, attention_mask):
        """Project `image_hidden_states`, then resample them using `attention_mask` as the context mask."""
        projected = self.modality_projection(image_hidden_states)
        return self.perceiver_resampler(
            context=projected, attention_mask=attention_mask
        )


class Idefics2ForConditionalGeneration(nn.Module):
    """
    Idefics2 for generation: a vision tower and connector produce image features that are written
    over the image-token positions of the text embeddings before the text model runs.
    """

    def __init__(self, prefix, config, weights):
        super().__init__()
        # Propagate the top-level quantization / speculation settings to both sub-configs.
        config.vision_config.quantize = config.quantize
        config.vision_config.speculator = config.speculator
        config.text_config.quantize = config.quantize
        config.text_config.speculator = config.speculator

        vision_config = config.vision_config
        self.text_model = load_text_model(
            prefix="model" if not prefix else f"{prefix}.model",
            config=config.text_config,
            weights=weights,
            name="text_model",
        )
        self.dtype = weights.dtype
        self.vision_model = Idefics2VisionTransformer(
            prefix=f"{prefix}.model.vision_model" if prefix else "model.vision_model",
            config=vision_config,
            weights=weights,
        )
        self.connector = Idefics2Connector(
            prefix=f"{prefix}.model.connector" if prefix else "model.connector",
            config=config,
            weights=weights,
        )
        self.config = config
        self.image_seq_len = config.perceiver_config.resampler_n_latents
        self.image_token_id = config.image_token_id
        self.pad_token_id = (
            config.pad_token_id if config.pad_token_id is not None else -1
        )

    def _merge_input_ids_with_image_features(
        self,
        input_ids: torch.Tensor,
        inputs_embeds: torch.Tensor,
        image_features: torch.Tensor,
    ):
        """In place merges in vision_embeddings with inputs_embeds.

        Every position of `input_ids` equal to `config.image_token_id` is overwritten with one row
        of `image_features` (flattened to rows of its last dimension). The number of image tokens
        must match the number of feature rows.
        """
        mask = input_ids == self.config.image_token_id
        inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
        return inputs_embeds

    def _encode_images(
        self,
        pixel_values: torch.Tensor,
        pixel_attention_mask: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Run the vision tower and connector sample by sample and concatenate the image features.

        `pixel_values` is (batch, num_images, channels, height, width); `pixel_attention_mask`,
        when given, is indexed the same way per sample and per image. All-zero images are treated
        as padding and dropped.
        """
        batch_size, num_images, _, _, _ = pixel_values.shape
        patch_size = self.config.vision_config.patch_size
        # The cast does not depend on the sample, so do it once (fp16 compatibility).
        all_pixel_values = pixel_values.to(dtype=self.dtype)

        all_states = []
        for i in range(batch_size):
            images = all_pixel_values[i]

            # Remove padding images - padding images are full 0.
            nb_values_per_image = images.shape[1:].numel()
            real_images_inds = (images == 0.0).sum(
                dim=(-1, -2, -3)
            ) != nb_values_per_image
            images = images[real_images_inds].contiguous()

            # Handle the vision attention mask. The check is on the caller-provided mask, not on a
            # per-iteration variable, so a missing mask is handled for every sample of the batch.
            if pixel_attention_mask is None:
                image_mask = torch.ones(
                    size=(images.size(0), images.size(2), images.size(3)),
                    dtype=torch.bool,
                    device=images.device,
                )
            else:
                # Remove padding images from the mask
                image_mask = pixel_attention_mask[i][real_images_inds].contiguous()

            # A patch is attended to as soon as at least one of its pixels is unmasked.
            patches_subgrid = image_mask.unfold(
                dimension=1, size=patch_size, step=patch_size
            )
            patches_subgrid = patches_subgrid.unfold(
                dimension=2, size=patch_size, step=patch_size
            )
            patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()

            # Get sequence from the vision encoder
            image_hidden_states = self.vision_model(
                pixel_values=images,
                patch_attention_mask=patch_attention_mask,
            )

            # Modality projection & resampling
            image_hidden_states = self.connector(
                image_hidden_states,
                attention_mask=patch_attention_mask.view(images.size(0), -1),
            )
            all_states.append(image_hidden_states)

        # Concatenate along the image axis: same flattened order as stacking, but also valid when
        # samples contain different numbers of real images.
        return torch.cat(all_states, dim=0)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        input_lengths: torch.Tensor,
        max_s: int,
        prefill_cache_indices: Optional[torch.Tensor],
        lm_head_indices: Optional[torch.Tensor] = None,
        pixel_values: torch.FloatTensor = None,
        pixel_attention_mask: Optional[torch.BoolTensor] = None,
        # Unused here
        image_sizes: Optional[torch.Tensor] = None,
    ):
        """
        Embed the tokens, merge in image features when `pixel_values` is given, and run the text model.

        Args:
            input_ids: token ids; positions equal to `config.image_token_id` receive image features.
            position_ids, cu_seqlen_prefill, kv_cache, block_tables, slots, input_lengths, max_s:
                forwarded unchanged to the text model (`max_s` is also passed as `true_max_s`).
            prefill_cache_indices: accepted for interface compatibility; not forwarded.
            lm_head_indices: optional indices selecting the hidden states fed to the LM head.
            pixel_values: optional (batch, num_images, channels, height, width) images.
            pixel_attention_mask: optional pixel-level mask matching `pixel_values`.
            image_sizes: unused.

        Returns:
            `(logits, speculative_logits)` as produced by the text model's LM head.
        """
        inputs_embeds = self.text_model.embed_tokens(input_ids)
        if pixel_values is not None:
            image_hidden_states = self._encode_images(pixel_values, pixel_attention_mask)
            # When we generate, we don't want to replace the potential image_token_id that we generated by images
            # that simply don't exist
            inputs_embeds = self._merge_input_ids_with_image_features(
                input_ids, inputs_embeds, image_hidden_states
            )

        hidden_states = self.text_model.model(
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            kv_cache=kv_cache,
            block_tables=block_tables,
            slots=slots,
            input_lengths=input_lengths,
            max_s=max_s,
            true_max_s=max_s,
            # NOTE(review): the `prefill_cache_indices` argument is deliberately left as in the
            # original and None is always passed here — confirm this is intended.
            prefill_cache_indices=None,
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        logits, speculative_logits = self.text_model.lm_head(hidden_states)
        return logits, speculative_logits