# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Qwen2 VL model."""

from typing import Optional, Tuple, List

import torch
import torch.utils.checkpoint
from torch import nn
from text_generation_server.utils.import_utils import SYSTEM

if SYSTEM == "ipex":
    import intel_extension_for_pytorch as ipex
else:
    import flash_attn_2_cuda

import numpy as np

from transformers.activations import ACT2FN
import torch.nn.functional as F

from text_generation_server.layers.layernorm import FastLayerNorm, FastRMSNorm
from text_generation_server.layers import (
    TensorParallelColumnLinear,
    TensorParallelRowLinear,
    TensorParallelEmbedding,
    SpeculativeHead,
)
from text_generation_server.layers.attention import (
    Seqlen,
)
from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
    Qwen2Model,
)


# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Split the last dimension of ``x`` into halves ``(a, b)`` and return ``(-b, a)``.

    For an odd-sized last dimension the first half is the shorter one
    (``size // 2`` elements).
    """
    split_at = x.shape[-1] // 2
    front, back = x[..., :split_at], x[..., split_at:]
    return torch.cat((back.neg(), front), dim=-1)


def apply_rotary_pos_emb_vision(
    tensor: torch.Tensor, freqs: torch.Tensor
) -> torch.Tensor:
    """Rotate ``tensor`` by the angles in ``freqs`` (rotary position embedding).

    The math is done in float32 and the result is cast back to the input dtype.
    Assumes ``freqs`` is ``(seq, d / 2)`` and ``tensor`` is ``(1, seq, heads, d)``
    (which is how ``Qwen2VLAttention`` calls it) — TODO confirm for other callers.
    """
    input_dtype = tensor.dtype
    x = tensor.float()
    # (seq, d/2) -> (1, seq, 1, d): duplicate each table along the last dim and
    # add singleton batch / head axes so it broadcasts against ``x``.
    cos, sin = (
        table.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
        for table in (freqs.cos(), freqs.sin())
    )
    rotated = (x * cos) + (rotate_half(x) * sin)
    return rotated.to(input_dtype)


class Qwen2VLAttention(nn.Module):
    """Multi-head self-attention of the Qwen2-VL vision tower.

    Heads are split across the tensor-parallel group (each rank keeps
    ``num_heads // world_size`` heads). Attention is non-causal and restricted
    to the variable-length segments described by ``cu_seqlens``.
    """

    def __init__(self, *, prefix, config, weights):
        super().__init__()
        world_size = weights.process_group.size()
        self.embed_dim = config.embed_dim // world_size
        self.num_heads = config.num_heads // world_size
        # Width of one head, derived from the same quantities `forward` uses to
        # reshape q/k/v. (The previous `config.hidden_size // config.num_heads`
        # disagreed with that reshape whenever hidden_size != embed_dim.)
        self.head_dim = self.embed_dim // self.num_heads

        self.qkv = TensorParallelColumnLinear.load_qkv(
            config,
            prefix=f"{prefix}.qkv",
            weights=weights,
            bias=False,
            num_heads=self.num_heads,
            num_key_value_heads=self.num_heads,
        )
        self.qkv.linear.bias = self._load_packed_qkv_bias(prefix, weights)
        self.proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.proj",
            weights=weights,
            bias=True,
        )
        self.softmax_scale = 1.0 / np.sqrt(self.head_dim)

    @staticmethod
    def _load_packed_qkv_bias(prefix, weights):
        """Return this rank's slice of the fused ``[q; k; v]`` qkv bias.

        ``load_qkv`` is given per-projection head counts, which presumably means
        the fused weight is sharded per projection (q, k and v each split across
        ranks). The bias must be split the same way: a plain
        ``get_sharded(..., dim=0)`` would give rank ``r`` the ``r``-th contiguous
        slice of the fused vector and misalign the q/k/v biases whenever the
        tensor-parallel world size is > 1. With one rank the full bias is
        returned, as before.
        """
        world_size = weights.process_group.size()
        rank = weights.process_group.rank()
        fused = weights.get_tensor(f"{prefix}.qkv.bias")
        if fused.shape[0] % (3 * world_size) != 0:
            raise ValueError(
                f"qkv bias of size {fused.shape[0]} cannot be split into q/k/v "
                f"shards for {world_size} ranks"
            )
        return torch.cat([part.chunk(world_size)[rank] for part in fused.chunk(3)])

    def forward(
        self,
        hidden_state: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor,
        max_seqlen: int,
    ) -> torch.Tensor:
        """Attend within each segment of ``hidden_state`` and project the result.

        Args:
            hidden_state: ``(total_tokens, embed_dim)`` tokens of all segments,
                concatenated.
            cu_seqlens: cumulative segment boundaries (``[0, ..., total_tokens]``).
            rotary_pos_emb: per-token rotary angles applied to queries and keys.
            max_seqlen: length of the longest segment.

        Returns:
            Tensor of shape ``(total_tokens, proj output dim)``.
        """
        num_tokens = hidden_state.shape[0]

        # project to fused q/k/v, then split into the three projections
        qkv = self.qkv(hidden_state)
        query, key, value = qkv.split([self.embed_dim] * 3, dim=1)

        # (tokens, embed_dim) -> (tokens, heads, head_dim)
        head_shape = (num_tokens, self.num_heads, self.head_dim)
        query = query.view(*head_shape)
        key = key.view(*head_shape)
        value = value.view(*head_shape)

        # apply rotary positional embeddings (the helper expects a leading batch axis)
        query = apply_rotary_pos_emb_vision(query.unsqueeze(0), rotary_pos_emb).squeeze(
            0
        )
        key = apply_rotary_pos_emb_vision(key.unsqueeze(0), rotary_pos_emb).squeeze(0)

        # the attention kernels below are given contiguous inputs
        query = query.contiguous()
        key = key.contiguous()
        value = value.contiguous()
        causal = False

        # execute flash attention
        if SYSTEM == "ipex":
            attn_output = torch.empty_like(query)
            ipex.llm.functional.varlen_attention(
                query,
                key,
                value,
                attn_output,
                cu_seqlens,  # cu_seqlens_q
                cu_seqlens,  # cu_seqlens_k
                max_seqlen,  # max_seqlen_q
                max_seqlen,  # max_seqlen_k
                0.0,
                self.softmax_scale,
                False,
                causal,
                False,
                None,
            )
        else:
            # Positional arguments follow flash-attn 2's `mha_varlen_fwd`
            # binding; labels are per that signature (v2.6).
            # NOTE(review): confirm against the installed flash_attn_2_cuda.
            attn_output = flash_attn_2_cuda.varlen_fwd(
                query,
                key,
                value,
                None,  # out (presumably allocated by the kernel when None)
                cu_seqlens,  # cu_seqlens_q
                cu_seqlens,  # cu_seqlens_k
                None,  # seqused_k
                None,  # leftpad_k
                None,  # block_table
                None,  # alibi_slopes
                max_seqlen,  # max_seqlen_q
                max_seqlen,  # max_seqlen_k
                0.0,  # dropout_p
                self.softmax_scale,
                False,  # zero_tensors
                causal,  # is_causal
                -1,  # window_size_left
                -1,  # window_size_right
                0.0,  # softcap
                False,  # return_softmax
                None,  # generator
            )[0]

        # (tokens, heads, head_dim) -> (tokens, heads * head_dim), then project
        attn_output = attn_output.reshape(num_tokens, -1)
        attn_output = self.proj(attn_output)
        return attn_output


class Qwen2VLVisionMLP(nn.Module):
    """Feed-forward network of a vision block: ``fc2(act(fc1(x)))``.

    ``fc1`` is loaded as a column-parallel linear and ``fc2`` as a row-parallel
    linear, both with bias; the activation comes from ``config.hidden_act``.
    """

    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.activation_fn = ACT2FN[config.hidden_act]
        shared = {"weights": weights, "config": config, "bias": True}
        self.fc1 = TensorParallelColumnLinear.load(prefix=f"{prefix}.fc1", **shared)
        self.fc2 = TensorParallelRowLinear.load(prefix=f"{prefix}.fc2", **shared)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.fc2(self.activation_fn(self.fc1(hidden_states)))


class Qwen2VLVisionBlock(nn.Module):
    """Pre-norm vision transformer block.

    ``x = x + attn(norm1(x))`` followed by ``x = x + mlp(norm2(x))``; both
    layer norms use ``eps=1e-6``.
    """

    def __init__(self, prefix, config, weights):
        super().__init__()

        def load_norm(name):
            return FastLayerNorm.load(
                prefix=f"{prefix}.{name}", weights=weights, eps=1e-6
            )

        self.attn = Qwen2VLAttention(
            prefix=f"{prefix}.attn", config=config, weights=weights
        )
        self.norm1 = load_norm("norm1")
        self.norm2 = load_norm("norm2")
        self.mlp = Qwen2VLVisionMLP(
            prefix=f"{prefix}.mlp", config=config, weights=weights
        )

    def forward(
        self, hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen
    ) -> torch.Tensor:
        # The norm returns (normalized, residual); with no residual passed in,
        # the residual it returns is the block input itself.
        normed, shortcut = self.norm1(hidden_states)
        hidden_states = (
            self.attn(normed, cu_seqlens, rotary_pos_emb, max_seqlen) + shortcut
        )
        normed, _unused_residual = self.norm2(hidden_states)
        return hidden_states + self.mlp(normed)


class Qwen2VLPatchMerger(nn.Module):
    """Merge groups of patch embeddings and project them with a 2-layer MLP.

    After the ``ln_q`` layer norm, every ``spatial_merge_size ** 2`` consecutive
    rows are concatenated into one row of width
    ``embed_dim * spatial_merge_size ** 2``, then passed through
    ``fc2(gelu(fc1(.)))``.
    """

    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.hidden_size = config.embed_dim * (config.spatial_merge_size**2)
        self.patch_merger_ln_q = FastLayerNorm.load(
            prefix=f"{prefix}.ln_q", weights=weights, eps=1e-6
        )
        # checkpoint names the MLP as an nn.Sequential: index 0 and 2 are the linears
        self.fc1 = TensorParallelColumnLinear.load(
            prefix=f"{prefix}.mlp.0", weights=weights, config=config, bias=True
        )
        self.fc2 = TensorParallelRowLinear.load(
            prefix=f"{prefix}.mlp.2", weights=weights, config=config, bias=True
        )

    def forward(self, hidden_states) -> torch.Tensor:
        normed, _residual = self.patch_merger_ln_q(hidden_states)
        merged = normed.view(-1, self.hidden_size)
        return self.fc2(F.gelu(self.fc1(merged)))


class Qwen2VisionModel(nn.Module):
    """Qwen2-VL vision tower.

    Pipeline: 3-D patch embedding -> ``config.depth`` vision blocks with 2-D
    rotary positions and per-image attention segments -> patch merger.
    """

    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.spatial_merge_size = config.spatial_merge_size
        # Kernel == stride: every (temporal, patch, patch) cube is projected
        # independently to `embed_dim`.
        kernel_size = [config.temporal_patch_size, config.patch_size, config.patch_size]
        self.patch_embedding = nn.Conv3d(
            in_channels=config.in_chans,
            out_channels=config.embed_dim,
            kernel_size=kernel_size,
            stride=kernel_size,
            bias=False,
        )
        self.patch_embedding.weight = nn.Parameter(
            weights.get_tensor(f"{prefix}.patch_embed.proj.weight"), requires_grad=False
        )
        head_dim = config.embed_dim // config.num_heads
        # TODO: replace with static positional embeddings once implemented
        # Rotary frequencies for one spatial axis. `forward` looks up a row and a
        # column table and concatenates them (head_dim // 2 values), and
        # apply_rotary_pos_emb_vision duplicates that up to head_dim.
        theta = 10000.0
        dim = head_dim // 2
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        self.blocks = nn.ModuleList(
            [
                Qwen2VLVisionBlock(
                    prefix=f"{prefix}.blocks.{i}",
                    config=config,
                    weights=weights,
                )
                for i in range(config.depth)
            ]
        )
        self.merger = Qwen2VLPatchMerger(
            prefix=f"{prefix}.merger",
            config=config,
            weights=weights,
        )

        # NOTE(review): the Conv3d above reads `in_chans` / `patch_size`, while
        # `forward` reshapes with `in_channels` / `spatial_patch_size`; both pairs
        # presumably hold the same values in Qwen2-VL configs — confirm.
        self.temporal_patch_size = config.temporal_patch_size
        self.spatial_patch_size = config.spatial_patch_size
        self.in_channels = config.in_channels
        self.embed_dim = config.embed_dim

    def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor:
        """Prepend a class embedding to each sequence in ``hidden_state``.

        NOTE(review): unused and non-functional as written — this class never
        defines ``self.class_embedding``, so calling it raises AttributeError.
        Looks like a leftover; confirm no caller relies on it before removing.
        """
        batch_size, _, hidden_size = hidden_state.shape
        class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size)
        hidden_state = torch.cat([class_embedding, hidden_state], dim=1)
        return hidden_state

    def forward(
        self,
        pixel_values: torch.Tensor,
        grid_thw: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        """Encode flattened image patches into merged visual embeddings.

        Args:
            pixel_values: patch pixels, reshaped to ``(-1, in_channels,
                temporal_patch_size, spatial_patch_size, spatial_patch_size)``.
            grid_thw: ``(num_images, 3)`` patch-grid sizes ``(t, h, w)`` per
                image, in patch units. Required despite the ``None`` default.

        Returns:
            The patch merger's output for all patches of all images.

        Raises:
            ValueError: if ``grid_thw`` is ``None``.
        """
        if grid_thw is None:
            raise ValueError(
                "Qwen2VisionModel.forward requires `grid_thw` to locate the "
                "patches of each image"
            )

        # reshape the input tensor for processing
        shape = (
            -1,
            self.in_channels,
            self.temporal_patch_size,
            self.spatial_patch_size,
            self.spatial_patch_size,
        )
        pixel_values = pixel_values.view(shape).to(self.patch_embedding.weight.dtype)
        hidden_states = self.patch_embedding(pixel_values).view(-1, self.embed_dim)
        # TODO: revisit to see if we can avoid some of these reshapes

        # (row, col) of every patch, ordered so that each
        # spatial_merge_size x spatial_merge_size window is contiguous; the
        # layout is repeated for each of the t frames.
        pos_ids = []
        for t, h, w in grid_thw:
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            hpos_ids = hpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
            hpos_ids = hpos_ids.flatten()

            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            wpos_ids = wpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
            wpos_ids = wpos_ids.flatten()
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))

        pos_ids = torch.cat(pos_ids, dim=0)
        max_grid_size = grid_thw[:, 1:].max()

        # rotary angles: look up a row table and a column table per patch
        seq = torch.arange(
            max_grid_size, device=self.inv_freq.device, dtype=self.inv_freq.dtype
        )
        rotary_pos_emb_full = torch.outer(seq, self.inv_freq)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        rotary_pos_emb = rotary_pos_emb.to(hidden_states.device, hidden_states.dtype)

        # one attention segment per frame (h * w patches), t frames per image
        cu_seqlens = torch.repeat_interleave(
            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
        ).cumsum(dim=0, dtype=torch.int32)
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
        # Convert once to a Python int (the attention layers take `max_seqlen: int`)
        # instead of letting every block convert the 0-d tensor again.
        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()

        # iteratively apply the blocks to the hidden states
        for block in self.blocks:
            hidden_states = block(hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen)

        # apply the final patch merger to the hidden states
        hidden_states = self.merger(hidden_states)
        return hidden_states


class Qwen2VLForConditionalGeneration(nn.Module):
    """Qwen2-VL model: vision tower, Qwen2 text decoder and LM head.

    Embeddings of image placeholder tokens (``config.image_token_id``) are
    overwritten with the vision tower's output before the text model runs.
    """

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.config = config
        config.vision_config.quantize = None
        config.vision_config.speculator = config.speculator
        # AutoConfig.from_pretrained currently reports rope_scaling["type"] ==
        # "default" for Qwen2-VL. Add rope_type = "mrope" (the "type" key is left
        # unchanged); presumably this is the key the text model's rotary setup
        # reads — confirm in Qwen2Model.
        if (
            hasattr(config, "rope_scaling")
            and config.rope_scaling is not None
            and config.rope_scaling.get("type", None) == "default"
        ):
            config.rope_scaling.update({"rope_type": "mrope"})
        self.hidden_size = config.hidden_size
        self.vision_start_token_id = config.vision_start_token_id
        self.vision_end_token_id = config.vision_end_token_id
        self.image_token_id = config.image_token_id
        self.video_token_id = config.video_token_id
        self.spatial_merge_size = config.vision_config.spatial_merge_size
        self.embed_tokens = TensorParallelEmbedding(
            prefix="model.embed_tokens", weights=weights
        )
        self.visual = Qwen2VisionModel(
            prefix="visual", config=config.vision_config, weights=weights
        )
        self.text_model = Qwen2Model(prefix=None, config=config, weights=weights)
        # with tied embeddings the LM head is loaded from the input embedding weights
        if config.tie_word_embeddings:
            suffix = "model.embed_tokens"
        else:
            suffix = "lm_head"

        self.lm_head = SpeculativeHead.load(
            config,
            prefix=suffix if not prefix else f"{prefix}.{suffix}",
            weights=weights,
        )
        # NOTE(review): `self.norm` is loaded but not used by `forward` below;
        # presumably the final norm is applied inside `Qwen2Model` — confirm
        # before removing.
        self.norm = FastRMSNorm.load(
            prefix="model.norm",
            weights=weights,
            eps=config.rms_norm_eps,
        )
        self.device = weights.device

    @staticmethod
    def _sequential_position_ids(input_ids: torch.Tensor) -> torch.Tensor:
        """Return ``arange(len(input_ids))`` repeated on all three rotary axes."""
        return (
            torch.arange(input_ids.shape[0], device=input_ids.device)
            .unsqueeze(1)
            .repeat(1, 3)
        )

    # based on https://github.com/huggingface/transformers/blob/e284c7e954abe12c34b50461c17f8115a0afe115/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1391
    # modified to first find all vision segments and then build the position ids
    # of every text / vision segment from per-segment offsets.
    def get_position_ids(
        self,
        input_ids: torch.Tensor,
        image_grid_thw: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Build (temporal, height, width) rotary position ids for one prompt.

        Text tokens get the same position on all three axes. The tokens of each
        image get its (t, h, w) grid coordinates, offset past the preceding
        text. The text after an image resumes one past the largest position the
        image used, i.e. ``max(t, h, w)`` past the image's offset.

        Args:
            input_ids: 1-D token ids of a single sequence.
            image_grid_thw: ``(num_images, 3)`` patch grid ``(t, h, w)`` per
                image in vision-tower (pre-merge) units, one row per vision span
                of ``input_ids`` in order. ``None`` for text-only input.

        Returns:
            ``(len(input_ids), 3)`` tensor of position ids.
        """
        if image_grid_thw is None:
            return self._sequential_position_ids(input_ids)

        vision_starts = torch.where(input_ids == self.vision_start_token_id)[0]
        vision_ends = torch.where(input_ids == self.vision_end_token_id)[0]
        if vision_starts.numel() == 0:
            # No vision span in this prompt: nothing to place on an image grid.
            return self._sequential_position_ids(input_ids)

        spatial_merge_size = self.spatial_merge_size
        device = input_ids.device
        dtype = input_ids.dtype
        input_ids_len = input_ids.shape[0]

        vision_segments = torch.stack((vision_starts, vision_ends), dim=1)
        num_segments = vision_segments.shape[0]
        # Text tokens before each vision span: from the previous vision-end token
        # (or the sequence start) up to and including this span's vision-start.
        prev_vision_end = torch.cat(
            [torch.zeros(1, device=vision_ends.device, dtype=dtype), vision_ends[:-1]]
        )
        text_lengths_between_vision = vision_segments[:, 0] - prev_vision_end + 1

        # Per-image grid on the LLM side: t, and h / w after spatial merging.
        llm_grid_thw = torch.stack(
            [
                image_grid_thw[:, 0],
                image_grid_thw[:, 1] // spatial_merge_size,
                image_grid_thw[:, 2] // spatial_merge_size,
            ],
            dim=1,
        )
        # An image spans max(t, h, w) positions on its most extended axis, so the
        # text after it starts max(t, h, w) past the image offset. (Using the
        # width alone made that text reuse the image's positions whenever h or t
        # exceeded w, e.g. for portrait images.)
        prev_vision_extents = torch.cat(
            [
                torch.zeros(1, device=image_grid_thw.device, dtype=dtype),
                llm_grid_thw[:-1].amax(dim=1),
            ]
        )
        # position offset at which each image starts ...
        vision_segment_lengths = (
            prev_vision_extents + text_lengths_between_vision
        ).cumsum(dim=0)
        # ... and at which the text preceding each image starts
        text_segment_lengths = vision_segment_lengths - text_lengths_between_vision

        # create position ids for each vision segment based on its grid
        llm_pos_ids_list = []
        for i in range(num_segments):
            t, h, w = llm_grid_thw[i]
            t_indices = torch.arange(t, device=device).repeat_interleave(h * w)
            h_indices = torch.arange(h, device=device).repeat_interleave(w).repeat(t)
            w_indices = torch.arange(w, device=device).repeat(t * h)
            image_position_ids = torch.stack([t_indices, h_indices, w_indices], dim=0)
            llm_pos_ids_list.append(image_position_ids + vision_segment_lengths[i])

        # create position ids for each text segment (same value on all 3 axes)
        text_ranges = [
            torch.arange(seq_len, device=device).view(1, -1).expand(3, -1)
            + text_segment_lengths[i]
            for i, seq_len in enumerate(text_lengths_between_vision)
        ]

        # interleave: text_0, image_0, text_1, image_1, ...
        full_llm_pos_ids_list = [
            item for sublist in zip(text_ranges, llm_pos_ids_list) for item in sublist
        ]
        # trailing text (from the last vision-end token onward) continues after
        # the last image's largest position
        max_s = full_llm_pos_ids_list[-1].max() + 1
        final_text_len = input_ids_len - vision_ends[-1]
        if final_text_len > 0:
            m = torch.arange(final_text_len, device=device).view(1, -1).expand(3, -1)
            full_llm_pos_ids_list.append(m + max_s)

        # (3, input_ids_len) -> (input_ids_len, 3)
        position_ids = (
            torch.cat(full_llm_pos_ids_list, dim=1).reshape(3, -1).transpose(0, 1)
        )
        return position_ids

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        seqlen: Seqlen,
        max_s: int,
        prefill_cache_indices: Optional[torch.Tensor],
        lm_head_indices: Optional[torch.Tensor],
        pixel_values: torch.FloatTensor = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        pixel_attention_mask=None,
        image_sizes: Optional[torch.LongTensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
        cross_attention_states: Optional[torch.Tensor] = None,
        image_indices=None,
    ):
        """Embed tokens, splice in image features, run the text model and LM head.

        Returns:
            ``(logits, speculative_logits)`` from ``self.lm_head``. When
            ``lm_head_indices`` is given, only those hidden states are scored.
        """
        inputs_embeds = self.embed_tokens(input_ids)

        # replace the embeddings of image placeholder tokens with visual features
        if pixel_values is not None and len(pixel_values) > 0:
            image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).squeeze(0)
            inputs_embeds[input_ids == self.image_token_id] = image_embeds

        hidden_states = self.text_model(
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            kv_cache=kv_cache,
            block_tables=block_tables,
            slots=slots,
            seqlen=seqlen,
            max_s=max_s,
            true_max_s=max_s,
            prefill_cache_indices=prefill_cache_indices,
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        logits, speculative_logits = self.lm_head(hidden_states)
        return logits, speculative_logits