"""A simple, flexible implementation of a GPT model.

Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
"""

import math
import warnings
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
)
from einops import rearrange
from packaging import version
from text_generation_server.layers import (
    TensorParallelEmbedding,
    TensorParallelColumnLinear,
    TensorParallelRowLinear,
    SpeculativeHead,
    get_linear,
)

# Epsilon passed to every LayerNorm loaded from checkpoint weights in this
# file (norm_1/norm_2 in MPTBlock and the final norm_f in MPTModel).
EPS = 1e-5


def load_col(config, prefix, weights, bias):
    """Load a fused ``Wqkv`` projection and shard it column-wise across ranks.

    The checkpoint stores Q, K and V stacked along dim 0. Each rank takes its
    contiguous ``1 / world_size`` block out of each of the three parts and
    concatenates them, so the local tensor is laid out as
    ``[q_shard; k_shard; v_shard]``. The bias (if any) is sharded the same way.

    Args:
        config: model config; only ``config.quantize`` is read here.
        prefix: weight-name prefix, e.g. ``"...attn.Wqkv"``.
        weights: weight loader exposing ``_get_slice``, ``process_group``,
            ``dtype`` and ``device``.
        bias: whether ``f"{prefix}.bias"`` should be loaded.

    Returns:
        A ``TensorParallelColumnLinear`` wrapping this rank's shard.

    Raises:
        NotImplementedError: GPTQ quantization was requested.
    """
    # Explicit raise instead of ``assert``: assertions are stripped under
    # ``python -O``, and the old assert raised AssertionError (with the
    # NotImplementedError class merely as its message).
    if config.quantize == "gptq":
        raise NotImplementedError("load_col does not support gptq quantization")

    rank = weights.process_group.rank()
    world_size = weights.process_group.size()

    slice_ = weights._get_slice(f"{prefix}.weight")
    # The fused weight's second dim (h) is the length of each Q/K/V part
    # along dim 0.
    _, h = slice_.get_shape()
    weight = _place_qkv_shard(_take_qkv_shard(slice_, h, rank, world_size), weights)

    if bias:
        bias_slice = weights._get_slice(f"{prefix}.bias")
        # The fused bias holds Q, K and V back to back, so each part is a
        # third of its length. Using the full length here (as before) made
        # every shard mix Q/K/V entries whenever world_size > 1.
        bias_part_len = bias_slice.get_shape()[0] // 3
        bias = _place_qkv_shard(
            _take_qkv_shard(bias_slice, bias_part_len, rank, world_size), weights
        )
    else:
        bias = None
    linear = get_linear(weight, bias)
    return TensorParallelColumnLinear(linear)


def _take_qkv_shard(slice_, part_len, rank, world_size):
    """Concatenate this rank's block of each of the Q, K and V parts of ``slice_``."""
    block_size = part_len // world_size
    start = rank * block_size
    stop = start + block_size
    # Q, K and V start at offsets 0, part_len and 2 * part_len along dim 0.
    parts = [
        slice_[offset + start : offset + stop]
        for offset in (0, part_len, 2 * part_len)
    ]
    return torch.cat(parts, dim=0)


def _place_qkv_shard(tensor, weights):
    """Cast ``tensor`` to the loader's dtype (int32 left as-is) and device."""
    # int32 tensors are kept in their dtype rather than cast to weights.dtype.
    if tensor.dtype != torch.int32:
        tensor = tensor.to(dtype=weights.dtype)
    return tensor.to(device=weights.device)


def _reset_is_causal(
    num_query_tokens: int, num_key_tokens: int, original_is_causal: bool
):
    if original_is_causal and num_query_tokens != num_key_tokens:
        if num_query_tokens != 1:
            raise NotImplementedError(
                "MPT does not support query and key with different number of tokens, unless number of query tokens is 1."
            )
        else:
            return False
    return original_is_causal


def scaled_multihead_dot_product_attention(
    query,
    key,
    value,
    n_heads,
    past_key_value=None,
    softmax_scale=None,
    attn_bias=None,
    key_padding_mask=None,
    is_causal=False,
    dropout_p=0.0,
    training=False,
    needs_weights=False,
    multiquery=False,
):
    """Reference (pure PyTorch) scaled dot-product attention.

    Args:
        query: ``(b, s_q, n_heads * d)`` tensor.
        key, value: ``(b, s_k, kv_heads * d)`` tensors, where ``kv_heads`` is 1
            when ``multiquery`` and ``n_heads`` otherwise.
        n_heads: number of query heads.
        past_key_value: ``None`` disables caching. An empty sequence starts a
            cache. Otherwise it is ``(k, v)`` from a previous call, with ``k``
            stored as ``(b, h, d, s)`` and ``v`` as ``(b, h, s, d)``.
        softmax_scale: multiplier applied to ``q @ k``. Defaults to
            ``1 / sqrt(d)`` when ``None``.
        attn_bias: optional additive bias. It is cropped to its last ``s_q``
            rows and ``s_k`` columns and broadcast against the scores.
        key_padding_mask: optional boolean mask with ``b * s_k`` elements,
            ``True`` for keys to keep.
        is_causal: apply a causal mask (skipped when there is one query token).
        dropout_p, training: dropout applied to the attention probabilities.
        needs_weights: also return the attention probabilities.
        multiquery: keys/values carry a single head shared by all query heads.

    Returns:
        ``(out, attn_weight_or_None, past_key_value)``, where ``out`` is
        ``(b, s_q, n_heads * d)``.

    Raises:
        RuntimeError: ``attn_bias`` cannot broadcast to the score shape.
    """
    q = rearrange(query, "b s (h d) -> b h s d", h=n_heads)
    kv_n_heads = 1 if multiquery else n_heads
    # Keys are kept pre-transposed as (b, h, d, s) so the scores are q @ k.
    k = rearrange(key, "b s (h d) -> b h d s", h=kv_n_heads)
    v = rearrange(value, "b s (h d) -> b h s d", h=kv_n_heads)
    if past_key_value is not None:
        if len(past_key_value) != 0:
            # The sequence axis is dim 3 for the transposed keys, dim 2 for values.
            k = torch.cat([past_key_value[0], k], dim=3)
            v = torch.cat([past_key_value[1], v], dim=2)
        past_key_value = (k, v)
    (b, _, s_q, d) = q.shape
    s_k = k.size(-1)
    if softmax_scale is None:
        # Previously the ``None`` default reached the multiply below and
        # raised TypeError.
        softmax_scale = 1 / math.sqrt(d)
    attn_weight = q.matmul(k) * softmax_scale
    if attn_bias is not None:
        # Keep only the bottom-right (s_q, s_k) window of a larger bias.
        _s_q = max(0, attn_bias.size(2) - s_q)
        _s_k = max(0, attn_bias.size(3) - s_k)
        attn_bias = attn_bias[:, :, _s_q:, _s_k:]
        if (
            attn_bias.size(-1) != 1
            and attn_bias.size(-1) != s_k
            or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q)
        ):
            raise RuntimeError(
                f"attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}."
            )
        attn_weight = attn_weight + attn_bias
    min_val = torch.finfo(q.dtype).min
    if key_padding_mask is not None:
        if attn_bias is not None:
            warnings.warn(
                "Propogating key_padding_mask to the attention module "
                + "and applying it within the attention module can cause "
                + "unneccessary computation/memory usage. Consider integrating "
                + "into attn_bias once and passing that to each attention "
                + "module instead."
            )
        attn_weight = attn_weight.masked_fill(
            ~key_padding_mask.view((b, 1, 1, s_k)), min_val
        )
    if is_causal and (not q.size(2) == 1):
        s = max(s_q, s_k)
        causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16)
        causal_mask = causal_mask.tril()
        causal_mask = causal_mask.to(torch.bool)
        causal_mask = ~causal_mask
        # Align the mask to the most recent s_q queries and s_k keys.
        causal_mask = causal_mask[-s_q:, -s_k:]
        attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k), min_val)
    attn_weight = torch.softmax(attn_weight, dim=-1)
    if dropout_p:
        attn_weight = torch.nn.functional.dropout(
            attn_weight, p=dropout_p, training=training, inplace=True
        )
    out = attn_weight.to(v.dtype).matmul(v)
    out = rearrange(out, "b h s d -> b s (h d)")
    if needs_weights:
        return (out, attn_weight, past_key_value)
    return (out, None, past_key_value)


def check_valid_inputs(*tensors, valid_dtypes=None):
    """Validate that every tensor is a CUDA tensor with an allowed dtype.

    Used by the flash and triton attention paths in this module.

    Args:
        *tensors: tensors to check, in order.
        valid_dtypes: allowed dtypes. Defaults to
            ``[torch.float16, torch.bfloat16]``.

    Raises:
        TypeError: on the first tensor whose dtype is not allowed or which is
            not a CUDA tensor.
    """
    # Build the default per call instead of sharing one mutable list across
    # calls; the list type keeps the error message text unchanged.
    if valid_dtypes is None:
        valid_dtypes = [torch.float16, torch.bfloat16]
    for tensor in tensors:
        if tensor.dtype not in valid_dtypes:
            raise TypeError(
                f"tensor.dtype={tensor.dtype!r} must be in valid_dtypes={valid_dtypes!r}."
            )
        if not tensor.is_cuda:
            raise TypeError(
                f"Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r})."
            )


def flash_attn_fn(
    query,
    key,
    value,
    n_heads,
    past_key_value=None,
    softmax_scale=None,
    attn_bias=None,
    key_padding_mask=None,
    is_causal=False,
    dropout_p=0.0,
    training=False,
    needs_weights=False,
    multiquery=False,
):
    """Attention via flash-attn's variable-length (unpadded) kernel.

    Padding is stripped with ``bert_padding.unpad_input`` before the kernel
    call and restored with ``bert_padding.pad_input`` afterwards.

    Args:
        query: ``(b, s_q, n_heads * d)`` tensor.
        key, value: ``(b, s_k, kv_heads * d)`` tensors. ``kv_heads`` is 1 when
            ``multiquery``, and that single head is expanded to ``n_heads``.
        n_heads: number of query heads.
        past_key_value: ``None`` disables caching. An empty sequence starts a
            cache. Otherwise the cached ``(key, value)`` are prepended along
            dim 1.
        softmax_scale: forwarded to the kernel.
        attn_bias: must be ``None``; additive bias is not supported here.
        key_padding_mask: optional ``(b, s_k)`` boolean mask (``True`` = keep),
            defaulting to all ``True``. The query mask is its last ``s_q``
            columns.
        is_causal: request a causal mask (adjusted by ``_reset_is_causal``).
        dropout_p: attention dropout, applied only when ``training``.
        needs_weights: must be ``False``; this path cannot return weights.
        multiquery: keys/values carry a single shared head.

    Returns:
        ``(output, None, past_key_value)``, with ``output`` re-padded to
        ``(b, s_q, n_heads * d)``.

    Raises:
        RuntimeError: flash-attn could not be imported.
        NotImplementedError: ``attn_bias`` or ``needs_weights`` was given.
    """
    try:
        from flash_attn import bert_padding, flash_attn_interface
    except Exception as err:
        raise RuntimeError("Please install flash-attn==1.0.3.post0") from err
    check_valid_inputs(query, key, value)
    if past_key_value is not None:
        if len(past_key_value) != 0:
            key = torch.cat([past_key_value[0], key], dim=1)
            value = torch.cat([past_key_value[1], value], dim=1)
        past_key_value = (key, value)
    if attn_bias is not None:
        raise NotImplementedError("attn_bias not implemented for flash attn.")
    if needs_weights:
        # With return_attn_probs=True the kernel returns a tuple, which the
        # rearrange/pad_input steps below cannot handle; fail clearly instead.
        raise NotImplementedError("attn_impl: flash cannot return attn weights.")
    (batch_size, seqlen) = query.shape[:2]
    if key_padding_mask is None:
        key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)
    # Queries are assumed to be the most recent positions of the key sequence.
    query_padding_mask = key_padding_mask[:, -query.size(1) :]
    (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input(
        query, query_padding_mask
    )
    query_unpad = rearrange(query_unpad, "nnz (h d) -> nnz h d", h=n_heads)
    (key_unpad, _, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input(
        key, key_padding_mask
    )
    key_unpad = rearrange(
        key_unpad, "nnz (h d) -> nnz h d", h=1 if multiquery else n_heads
    )
    (value_unpad, _, _, _) = bert_padding.unpad_input(value, key_padding_mask)
    value_unpad = rearrange(
        value_unpad, "nnz (h d) -> nnz h d", h=1 if multiquery else n_heads
    )
    if multiquery:
        # Broadcast the single K/V head to every query head (no copy).
        key_unpad = key_unpad.expand(key_unpad.size(0), n_heads, key_unpad.size(-1))
        value_unpad = value_unpad.expand(
            value_unpad.size(0), n_heads, value_unpad.size(-1)
        )
    dropout_p = dropout_p if training else 0.0
    reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
    output_unpad = flash_attn_interface.flash_attn_unpadded_func(
        query_unpad,
        key_unpad,
        value_unpad,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale=softmax_scale,
        causal=reset_is_causal,
        return_attn_probs=False,
    )
    output = bert_padding.pad_input(
        rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices_q, batch_size, seqlen
    )
    return (output, None, past_key_value)


def triton_flash_attn_fn(
    query,
    key,
    value,
    n_heads,
    past_key_value=None,
    softmax_scale=None,
    attn_bias=None,
    key_padding_mask=None,
    is_causal=False,
    dropout_p=0.0,
    training=False,
    needs_weights=False,
    multiquery=False,
):
    """Attention via a Triton flash-attention kernel that accepts an additive bias.

    Args:
        query: ``(b, s_q, n_heads * d)`` tensor.
        key, value: ``(b, s_k, kv_heads * d)`` tensors. ``kv_heads`` is 1 when
            ``multiquery``, and that single head is then expanded to ``n_heads``.
        n_heads: number of query heads.
        past_key_value: ``None`` disables caching. An empty sequence starts a
            cache. Otherwise the cached ``(key, value)`` are prepended along
            dim 1.
        softmax_scale: forwarded to ``flash_attn_func``.
        attn_bias: optional additive bias, cropped to its last ``s_q`` rows and
            ``s_k`` columns.
        key_padding_mask: optional ``(b, s_k)`` boolean mask (``True`` = keep).
            It is folded into ``attn_bias`` by filling masked keys with the
            dtype's minimum value.
        is_causal: request a causal mask (adjusted by ``_reset_is_causal``).
        dropout_p: must be falsy. This is checked regardless of ``training``.
        training: not used by this implementation.
        needs_weights: must be ``False``.
        multiquery: keys/values carry a single shared head.

    Returns:
        ``(output, None, past_key_value)``, where ``output`` is the kernel
        output flattened to ``(b, s_q, -1)``.

    Raises:
        RuntimeError: no usable ``flash_attn_func`` could be imported.
        NotImplementedError: non-zero ``dropout_p`` or ``needs_weights``.
    """
    try:
        # NOTE(review): relative import of a kernel module expected next to
        # this file; it is not visible here — confirm it ships with the package.
        from .flash_attn_triton import flash_attn_func
    except Exception:
        _installed = False
        # Fallback to flash_attn's bundled Triton kernel, attempted only on
        # torch < 2.0; on newer torch this path always raises below.
        if version.parse(torch.__version__) < version.parse("2.0.0"):
            _installed = True
            try:
                from flash_attn.flash_attn_triton import flash_attn_func
            except Exception:
                _installed = False
        if not _installed:
            raise RuntimeError(
                "Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU and `pip install .[gpu]` if installing from llm-foundry source or `pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). Note: (1) requires you have CMake and PyTorch already installed."
            )
    check_valid_inputs(query, key, value)
    if past_key_value is not None:
        if len(past_key_value) != 0:
            key = torch.cat([past_key_value[0], key], dim=1)
            value = torch.cat([past_key_value[1], value], dim=1)
        past_key_value = (key, value)
    if attn_bias is not None:
        # Keep only the bottom-right (s_q, s_k) window of a larger bias.
        _s_q = max(0, attn_bias.size(2) - query.size(1))
        _s_k = max(0, attn_bias.size(3) - key.size(1))
        attn_bias = attn_bias[:, :, _s_q:, _s_k:]
    if dropout_p:
        raise NotImplementedError("Dropout not implemented for attn_impl: triton.")
    if needs_weights:
        raise NotImplementedError("attn_impl: triton cannot return attn weights.")
    if key_padding_mask is not None:
        warnings.warn(
            "Propagating key_padding_mask to the attention module "
            + "and applying it within the attention module can cause "
            + "unnecessary computation/memory usage. Consider integrating "
            + "into attn_bias once and passing that to each attention "
            + "module instead."
        )
        (b_size, s_k) = key_padding_mask.shape[:2]
        if attn_bias is None:
            attn_bias = query.new_zeros(b_size, 1, 1, s_k)
        # Padded keys get the most negative finite value so softmax drives
        # their probability to ~0.
        attn_bias = attn_bias.masked_fill(
            ~key_padding_mask.view((b_size, 1, 1, s_k)), torch.finfo(query.dtype).min
        )
    query = rearrange(query, "b s (h d) -> b s h d", h=n_heads)
    key = rearrange(key, "b s (h d) -> b s h d", h=1 if multiquery else n_heads)
    value = rearrange(value, "b s (h d) -> b s h d", h=1 if multiquery else n_heads)
    if multiquery:
        # Broadcast the single K/V head to every query head (no copy).
        key = key.expand(*key.shape[:2], n_heads, key.size(-1))
        value = value.expand(*value.shape[:2], n_heads, value.size(-1))
    reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
    attn_output = flash_attn_func(
        query, key, value, attn_bias, reset_is_causal, softmax_scale
    )
    output = attn_output.view(*attn_output.shape[:2], -1)
    return (output, None, past_key_value)


class MultiheadAttention(nn.Module):
    """Multi-head self attention, sharded over heads across the process group.

    Using torch or triton attention implementation enables user to also use
    additive bias.

    Each rank holds ``n_heads // world_size`` heads. ``Wqkv`` is loaded with
    ``load_col`` (column-parallel over the fused Q/K/V projection) and
    ``out_proj`` as a row-parallel linear.
    """

    def __init__(
        self,
        config,
        prefix,
        weights,
    ):
        super().__init__()
        attn_impl = config.attn_config.attn_impl
        self.attn_impl = attn_impl
        self.clip_qkv = config.attn_config.clip_qkv
        self.qk_ln = config.attn_config.qk_ln
        self.d_model = config.d_model
        self.n_heads = config.n_heads
        self.softmax_scale = config.attn_config.softmax_scale
        if self.softmax_scale is None:
            # 1 / sqrt(head_dim), computed from the global (unsharded) head count.
            self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
        self.attn_dropout_p = config.attn_config.attn_pdrop

        world_size = weights.process_group.size()
        if self.n_heads % world_size != 0:
            raise ValueError(
                f"`n_heads` must be divisible by `num_shards` (got `n_heads`: {self.n_heads} "
                f"and `num_shards`: {world_size}"
            )
        # From here on, n_heads is the number of heads local to this rank.
        self.n_heads = self.n_heads // world_size
        bias = not config.no_bias
        self.Wqkv = load_col(
            config, prefix=f"{prefix}.Wqkv", weights=weights, bias=bias
        )
        if self.qk_ln:
            # LPLayerNorm replaces the normalized shape passed here with the
            # shape of the loaded (sharded) weight.
            # NOTE(review): normalizing over this rank's hidden slice only
            # likely differs from the unsharded model when world_size > 1.
            self.q_ln = LPLayerNorm(
                self.d_model, bias=bias, prefix=f"{prefix}.q_ln", weights=weights
            )
            # k_ln previously ignored config.no_bias and always tried to load
            # a bias tensor; it now follows the same setting as q_ln.
            self.k_ln = LPLayerNorm(
                self.d_model, bias=bias, prefix=f"{prefix}.k_ln", weights=weights
            )
        if self.attn_impl == "flash":
            self.attn_fn = flash_attn_fn
        elif self.attn_impl == "triton":
            self.attn_fn = triton_flash_attn_fn
        elif self.attn_impl == "torch":
            self.attn_fn = scaled_multihead_dot_product_attention
        else:
            raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
        self.out_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.out_proj",
            weights=weights,
            bias=bias,
        )

    def forward(
        self,
        x,
        past_key_value=None,
        attn_bias=None,
        attention_mask=None,
        is_causal=True,
        needs_weights=False,
    ):
        """Project ``x`` to Q/K/V, attend, and apply the output projection.

        ``attention_mask`` is passed to the attention function as its
        ``key_padding_mask``.

        Returns:
            ``(out, attn_weights, past_key_value)``.
        """
        qkv = self.Wqkv(x)
        if self.clip_qkv:
            # In-place clamp of the fused projection output.
            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
        (query, key, value) = qkv.chunk(3, dim=2)

        key_padding_mask = attention_mask
        if self.qk_ln:
            dtype = query.dtype
            query = self.q_ln(query).to(dtype)
            key = self.k_ln(key).to(dtype)
        (context, attn_weights, past_key_value) = self.attn_fn(
            query,
            key,
            value,
            self.n_heads,
            past_key_value=past_key_value,
            softmax_scale=self.softmax_scale,
            attn_bias=attn_bias,
            key_padding_mask=key_padding_mask,
            is_causal=is_causal,
            dropout_p=self.attn_dropout_p,
            training=self.training,
            needs_weights=needs_weights,
        )
        out = self.out_proj(context)
        return (out, attn_weights, past_key_value)


class MultiQueryAttention(nn.Module):
    """Multi-Query self attention.

    Using torch or triton attention implementation enables user to also use
    additive bias.

    All ``n_heads`` query heads share a single key head and a single value
    head of size ``head_dim = d_model // n_heads``. The fused ``Wqkv`` output
    is split as ``[d_model, head_dim, head_dim]``.
    """

    def __init__(self, config, prefix, weights, verbose=False):
        super().__init__()
        attn_impl = config.attn_config.attn_impl
        self.attn_impl = attn_impl
        self.clip_qkv = config.attn_config.clip_qkv
        self.qk_ln = config.attn_config.qk_ln
        self.d_model = config.d_model
        self.n_heads = config.n_heads
        # head_dim was previously never assigned, so the default softmax
        # scale below and the qkv split in forward() raised AttributeError.
        self.head_dim = self.d_model // self.n_heads
        self.softmax_scale = config.attn_config.softmax_scale
        if self.softmax_scale is None:
            self.softmax_scale = 1 / math.sqrt(self.head_dim)
        self.attn_dropout_p = config.attn_config.attn_pdrop
        # NOTE(review): unlike MultiheadAttention, neither n_heads nor the
        # q/k/v split is adjusted for tensor parallelism here, so this looks
        # correct only for a process group of size 1. MPTBlock currently
        # rejects multiquery attention, so verify this before enabling it.
        self.Wqkv = TensorParallelColumnLinear.load(
            config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias
        )
        if self.qk_ln:
            raise NotImplementedError("qk_ln not supported")
        if self.attn_impl == "flash":
            self.attn_fn = flash_attn_fn
        elif self.attn_impl == "triton":
            self.attn_fn = triton_flash_attn_fn
            if verbose:
                warnings.warn(
                    "While `attn_impl: triton` can be faster than `attn_impl: flash` "
                    + "it uses more memory. When training larger models this can trigger "
                    + "alloc retries which hurts performance. If encountered, we recommend "
                    + "using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`."
                )
        elif self.attn_impl == "torch":
            self.attn_fn = scaled_multihead_dot_product_attention
            if torch.cuda.is_available() and verbose:
                warnings.warn(
                    "Using `attn_impl: torch`. If your model does not use `alibi` or "
                    + "`prefix_lm` we recommend using `attn_impl: flash` otherwise "
                    + "we recommend using `attn_impl: triton`."
                )
        else:
            raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
        self.out_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.out_proj",
            weights=weights,
            bias=not config.no_bias,
        )

    def forward(
        self,
        x,
        past_key_value=None,
        attn_bias=None,
        attention_mask=None,
        is_causal=True,
        needs_weights=False,
    ):
        """Project ``x`` to Q and shared K/V, attend, and apply ``out_proj``.

        ``attention_mask`` is passed to the attention function as its
        ``key_padding_mask``. qk_ln is rejected in ``__init__``, so no
        query/key normalization happens here.

        Returns:
            ``(out, attn_weights, past_key_value)``.
        """
        qkv = self.Wqkv(x)
        if self.clip_qkv:
            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
        (query, key, value) = qkv.split(
            [self.d_model, self.head_dim, self.head_dim], dim=2
        )
        key_padding_mask = attention_mask
        (context, attn_weights, past_key_value) = self.attn_fn(
            query,
            key,
            value,
            self.n_heads,
            past_key_value=past_key_value,
            softmax_scale=self.softmax_scale,
            attn_bias=attn_bias,
            key_padding_mask=key_padding_mask,
            is_causal=is_causal,
            dropout_p=self.attn_dropout_p,
            training=self.training,
            needs_weights=needs_weights,
            multiquery=True,
        )
        return (self.out_proj(context), attn_weights, past_key_value)


def attn_bias_shape(
    attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id
):
    """Return the shape of the additive attention-bias buffer, or ``None``.

    ``flash`` never takes a bias. For ``torch``/``triton``:

    * with ALiBi, the bias is per head. It must be a full ``seq_len x
      seq_len`` matrix for prefix LM, sequence-id masking, or non-causal
      models, and otherwise a single broadcast row;
    * without ALiBi, a ``(1, 1, seq_len, seq_len)`` bias is needed only for
      prefix LM or sequence-id masking.

    Raises:
        ValueError: unknown ``attn_impl``.
    """
    if attn_impl == "flash":
        return None
    if attn_impl not in ("torch", "triton"):
        raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
    needs_square = prefix_lm or use_sequence_id
    if alibi:
        if needs_square or not causal:
            return (1, n_heads, seq_len, seq_len)
        return (1, n_heads, 1, seq_len)
    return (1, 1, seq_len, seq_len) if needs_square else None


def build_attn_bias(
    attn_impl, attn_bias, n_heads, seq_len, causal=False, alibi=False, alibi_bias_max=8
):
    """Return ``attn_bias`` with the ALiBi term added when ``alibi`` is set.

    ``flash`` never uses a bias, so ``None`` is returned for it. For
    ``torch``/``triton`` without ALiBi, ``attn_bias`` is returned unchanged.
    With ALiBi, the bias is built on ``attn_bias``'s device and dtype. It is a
    full (symmetric) matrix when the model is not causal.

    Raises:
        ValueError: unknown ``attn_impl``.
    """
    if attn_impl == "flash":
        return None
    if attn_impl not in ("torch", "triton"):
        raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
    if not alibi:
        return attn_bias
    alibi_term = build_alibi_bias(
        n_heads,
        seq_len,
        full=not causal,
        alibi_bias_max=alibi_bias_max,
        device=attn_bias.device,
        dtype=attn_bias.dtype,
    )
    return attn_bias.add(alibi_term)


def gen_slopes(n_heads, alibi_bias_max=8, device=None):
    """Return the per-head ALiBi slopes, shaped ``(1, n_heads, 1, 1)``.

    With ``N`` equal to ``n_heads`` rounded up to a power of two, the slopes
    are ``1 / 2 ** (i * alibi_bias_max / N)`` for ``i = 1..N``. When
    ``n_heads`` is not a power of two, the slopes at indices 1, 3, 5, ... come
    first, then those at 0, 2, 4, ..., truncated to ``n_heads`` entries.
    """
    padded_heads = 2 ** math.ceil(math.log2(n_heads))
    exponents = torch.arange(
        1, padded_heads + 1, dtype=torch.float32, device=device
    ).mul(alibi_bias_max / padded_heads)
    slopes = 1.0 / torch.pow(2, exponents)
    if padded_heads == n_heads:
        return slopes.view(1, n_heads, 1, 1)
    interleaved = torch.cat((slopes[1::2], slopes[::2]))
    return interleaved[:n_heads].view(1, n_heads, 1, 1)


def build_alibi_bias(
    n_heads, seq_len, full=False, alibi_bias_max=8, device=None, dtype=None
):
    """Build the ALiBi additive attention bias.

    With ``full=False`` the result is ``(1, n_heads, 1, seq_len)`` and entry
    ``j`` is ``slope * (j - (seq_len - 1))``, a single row that broadcasts
    over queries. With ``full=True`` the result is
    ``(1, n_heads, seq_len, seq_len)`` and entry ``(i, j)`` is
    ``-slope * |j - i|``. Slopes come from ``gen_slopes``, and the result is
    cast to ``dtype``.
    """
    positions = torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device)
    distance = positions.view(1, 1, 1, seq_len)
    if full:
        pairwise = distance - positions.view(1, 1, seq_len, 1)
        distance = torch.neg(pairwise.abs())
    head_slopes = gen_slopes(n_heads, alibi_bias_max, device=device)
    return (distance * head_slopes).to(dtype=dtype)


# Maps an attention-type name (the values MPTBlock compares
# ``config.attn_config.attn_type`` against) to its module class.
# NOTE(review): MPTBlock in this file builds MultiheadAttention directly and
# rejects other types; this registry is not consulted in the visible code.
ATTN_CLASS_REGISTRY = {
    "multihead_attention": MultiheadAttention,
    "multiquery_attention": MultiQueryAttention,
}

"""GPT Blocks used for the GPT Model."""


class MPTMLP(nn.Module):
    """Tensor-parallel feed-forward block: ``down_proj(GELU(up_proj(x)))``.

    ``up_proj`` is column-parallel and ``down_proj`` is row-parallel. Both
    load a bias unless ``config.no_bias`` is set.
    """

    def __init__(self, config, prefix, weights):
        super().__init__()
        use_bias = not config.no_bias
        self.up_proj = TensorParallelColumnLinear.load(
            config, prefix=f"{prefix}.up_proj", weights=weights, bias=use_bias
        )
        # Exact (erf-based) GELU rather than the tanh approximation.
        self.act = nn.GELU(approximate="none")
        self.down_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.down_proj",
            weights=weights,
            bias=use_bias,
        )

    def forward(self, x):
        hidden = self.up_proj(x)
        hidden = self.act(hidden)
        return self.down_proj(hidden)


class MPTBlock(nn.Module):
    """One MPT transformer layer with pre-norm residual branches.

    Computes ``x + attn(norm_1(x))`` followed by ``x + ffn(norm_2(x))``, with
    dropout (``config.resid_pdrop``) on each residual branch. Only
    ``attn_type == "multihead_attention"`` is supported.
    """

    def __init__(self, config, prefix, weights):
        super().__init__()
        self.prefix = prefix
        if config.attn_config.attn_type != "multihead_attention":
            raise NotImplementedError(
                f"""Not implemented attn {config.attn_config.attn_type}"""
            )
        resid_pdrop = config.resid_pdrop
        # ``load`` / ``load_no_bias`` are loader hooks on nn.LayerNorm that are
        # not defined in this file (presumably attached by the project's
        # layers package). The bias-free variant is used when config.no_bias
        # is set.
        load_norm = (
            nn.LayerNorm.load_no_bias if config.no_bias else nn.LayerNorm.load
        )
        self.norm_1 = load_norm(prefix=f"{prefix}.norm_1", weights=weights, eps=EPS)
        self.norm_2 = load_norm(prefix=f"{prefix}.norm_2", weights=weights, eps=EPS)
        self.attn = MultiheadAttention(config, prefix=f"{prefix}.attn", weights=weights)
        self.ffn = MPTMLP(config, prefix=f"{prefix}.ffn", weights=weights)
        self.resid_attn_dropout = nn.Dropout(resid_pdrop)
        self.resid_ffn_dropout = nn.Dropout(resid_pdrop)

    def forward(
        self,
        x: torch.Tensor,
        past_key_value: Optional[Tuple[torch.Tensor, ...]] = None,
        attn_bias: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.ByteTensor] = None,
        is_causal: bool = True,
    ) -> Tuple[
        torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, ...]]
    ]:
        """Run the layer on ``x``.

        Returns:
            ``(x, attn_weights, past_key_value)``. ``attn_weights`` and
            ``past_key_value`` are passed through from the attention module.
            The previous annotation listed only two elements.
        """
        a = self.norm_1(x)
        (b, attn_weights, past_key_value) = self.attn(
            a,
            past_key_value=past_key_value,
            attn_bias=attn_bias,
            attention_mask=attention_mask,
            is_causal=is_causal,
        )
        x = x + self.resid_attn_dropout(b)
        m = self.norm_2(x)
        n = self.ffn(m)
        x = x + self.resid_ffn_dropout(n)
        return (x, attn_weights, past_key_value)


def _cast_if_autocast_enabled(tensor):
    if torch.is_autocast_enabled():
        if tensor.device.type == "cuda":
            dtype = torch.get_autocast_gpu_dtype()
        elif tensor.device.type == "cpu":
            dtype = torch.get_autocast_cpu_dtype()
        else:
            raise NotImplementedError()
        return tensor.to(dtype=dtype)
    return tensor


class LPLayerNorm(torch.nn.LayerNorm):
    def __init__(
        self,
        normalized_shape,
        eps=1e-05,
        elementwise_affine=True,
        device=None,
        dtype=None,
        bias: Optional[bool] = True,
        prefix=None,
        weights=None,
    ):
        super().__init__(
            normalized_shape=normalized_shape,
            eps=eps,
            elementwise_affine=elementwise_affine,
            device=device,
            dtype=dtype,
            bias=bias,
        )
        if weights is not None:
            self.weight = nn.Parameter(weights.get_sharded(f"{prefix}.weight", dim=0))
            if bias:
                self.bias = nn.Parameter(weights.get_sharded(f"{prefix}.bias", dim=0))
            self.normalized_shape = self.weight.shape

    def forward(self, x):
        module_device = x.device
        downcast_x = _cast_if_autocast_enabled(x)
        downcast_weight = (
            _cast_if_autocast_enabled(self.weight)
            if self.weight is not None
            else self.weight
        )
        downcast_bias = (
            _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias
        )
        with torch.autocast(enabled=False, device_type=module_device.type):
            return torch.nn.functional.layer_norm(
                downcast_x,
                self.normalized_shape,
                downcast_weight,
                downcast_bias,
                self.eps,
            )


def rms_norm(x, weight=None, eps=1e-05):
    """Root-mean-square normalize ``x`` over its last dimension.

    Computes ``x / sqrt(mean(x ** 2, dim=-1) + eps)`` and, when ``weight`` is
    given, multiplies the result by it.
    """
    inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    normalized = x * inv_rms
    return normalized if weight is None else normalized * weight


class RMSNorm(torch.nn.Module):
    """RMS normalization computed in float32 and cast back to the input dtype.

    With ``weight=True`` a learnable scale initialized to ones is created.
    Otherwise ``weight`` is registered as ``None`` and no scaling is applied.
    """

    def __init__(
        self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None
    ):
        super().__init__()
        self.eps = eps
        if not weight:
            self.register_parameter("weight", None)
            return
        initial_scale = torch.ones(normalized_shape, dtype=dtype, device=device)
        self.weight = torch.nn.Parameter(initial_scale)

    def forward(self, x):
        normalized = rms_norm(x.float(), self.weight, self.eps)
        return normalized.to(dtype=x.dtype)


class LPRMSNorm(RMSNorm):
    """RMSNorm variant that stays in the autocast dtype.

    Instead of upcasting to float32, the input and weight are cast to the
    active autocast dtype (if any). The norm is computed with autocast
    disabled, and the result is cast back to the input's dtype.
    """

    def __init__(
        self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None
    ):
        super().__init__(
            normalized_shape, eps=eps, weight=weight, dtype=dtype, device=device
        )

    def forward(self, x):
        x_low = _cast_if_autocast_enabled(x)
        if self.weight is None:
            weight_low = None
        else:
            weight_low = _cast_if_autocast_enabled(self.weight)
        with torch.autocast(enabled=False, device_type=x.device.type):
            normalized = rms_norm(x_low, weight_low, self.eps)
            return normalized.to(dtype=x.dtype)


# Maps a lower-cased ``config.norm_type`` to a normalization class. MPTModel
# checks the configured type against these keys, but then accepts only
# "low_precision_layernorm".
NORM_CLASS_REGISTRY = {
    "layernorm": torch.nn.LayerNorm,
    "low_precision_layernorm": LPLayerNorm,
    "rmsnorm": RMSNorm,
    "low_precision_rmsnorm": LPRMSNorm,
}

# Type alias for either a slow or a fast Hugging Face tokenizer.
Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]


class MPTPreTrainedModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` base class for the MPT models in this file."""

    # Presumably the transformers hooks for the base-model attribute name and
    # for modules that must not be split across devices — confirm against the
    # transformers PreTrainedModel docs.
    base_model_prefix = "model"
    _no_split_modules = ["MPTBlock"]


class MPTModel(MPTPreTrainedModel):
    """MPT decoder backbone: token (and optional position) embeddings, a stack
    of ``MPTBlock`` layers and a final LayerNorm.

    All weights are loaded from ``weights`` under ``prefix``. Attention-bias
    heads are sliced per tensor-parallel rank of ``weights.process_group``.
    """

    def __init__(self, prefix: str, config, weights):
        """Build the backbone from ``config`` and load its weights under ``prefix``.

        Raises:
            NotImplementedError: if ``config.norm_type`` is not a key of
                ``NORM_CLASS_REGISTRY`` or is not ``"low_precision_layernorm"``.
        """
        # config._validate_config()
        super().__init__(config)
        self.world_size = weights.process_group.size()
        self.rank = weights.process_group.rank()
        self.n_heads = config.n_heads
        self.attn_impl = config.attn_config.attn_impl
        self.prefix_lm = config.attn_config.prefix_lm
        self.attn_uses_sequence_id = config.attn_config.attn_uses_sequence_id
        self.alibi = config.attn_config.alibi
        self.alibi_bias_max = config.attn_config.alibi_bias_max
        if config.init_device == "mixed":
            # TODO: reimplement mixed device initialization (local rank 0 ->
            # "cpu", other ranks -> "meta"). Until then every rank uses "cpu".
            config.init_device = "cpu"
        if config.norm_type.lower() not in NORM_CLASS_REGISTRY.keys():
            norm_options = " | ".join(NORM_CLASS_REGISTRY.keys())
            raise NotImplementedError(
                f"Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options})."
            )
        if config.norm_type.lower() != "low_precision_layernorm":
            raise NotImplementedError(
                f"Requested norm type ({config.norm_type}) is not implemented within this repo."
            )

        self.wte = TensorParallelEmbedding(f"{prefix}.wte", weights)

        if not self.alibi:
            # Learned absolute position embeddings are only used without ALiBi.
            self.wpe = TensorParallelEmbedding(f"{prefix}.wpe", weights)
        self.blocks = nn.ModuleList(
            [
                MPTBlock(config, prefix=f"{prefix}.blocks.{i}", weights=weights)
                for i in range(config.n_layers)
            ]
        )
        # Load the final norm under ``prefix`` like every other weight here, so
        # it is found when the backbone lives under e.g. "<prefix>.transformer".
        norm_prefix = f"{prefix}.norm_f"
        if config.no_bias:
            self.norm_f = nn.LayerNorm.load_no_bias(
                prefix=norm_prefix, weights=weights, eps=EPS
            )
        else:
            self.norm_f = nn.LayerNorm.load(
                prefix=norm_prefix, weights=weights, eps=EPS
            )
        self.is_causal = not self.prefix_lm
        # The attention bias is built lazily on the first _attn_bias() call.
        self._attn_bias_initialized = False
        self.attn_bias = None
        self.attn_bias_shape = attn_bias_shape(
            self.attn_impl,
            config.n_heads,
            config.max_seq_len,
            self.alibi,
            prefix_lm=self.prefix_lm,
            causal=self.is_causal,
            use_sequence_id=self.attn_uses_sequence_id,
        )
        # Read ``verbose`` once so a config without it does not crash below.
        verbose = getattr(config, "verbose", 0)
        if config.no_bias:
            for module in self.modules():
                if hasattr(module, "bias") and isinstance(module.bias, nn.Parameter):
                    if verbose:
                        warnings.warn(f"Removing bias ({module.bias}) from {module}.")
                    module.register_parameter("bias", None)
        if verbose and verbose > 2:
            print(self)
        if "verbose" not in self.config.init_config:
            self.config.init_config["verbose"] = verbose
        if self.config.init_config["verbose"] > 1:
            init_fn_name = self.config.init_config["name"]
            warnings.warn(f"Using {init_fn_name} initialization.")

    @torch.no_grad()
    def _attn_bias(
        self,
        device,
        dtype,
        attention_mask: Optional[torch.ByteTensor] = None,
        prefix_mask: Optional[torch.ByteTensor] = None,
        sequence_id: Optional[torch.LongTensor] = None,
    ):
        """Return ``(attn_bias, attention_mask)`` to hand to each block.

        On the first call, the bias of shape ``self.attn_bias_shape`` (if any)
        is built and sliced along dim 1 to this rank's share of heads. For
        ``attn_impl == "flash"``, the cached bias and the untouched
        ``attention_mask`` are returned. Otherwise the prefix-LM, sequence-id
        and padding masks are folded into the bias (disallowed positions get
        the dtype's minimum), and the returned mask is ``None``.
        """
        if not self._attn_bias_initialized:
            if self.attn_bias_shape:
                self.attn_bias = torch.zeros(
                    self.attn_bias_shape, device=device, dtype=dtype
                )
                self.attn_bias = build_attn_bias(
                    self.attn_impl,
                    self.attn_bias,
                    self.config.n_heads,
                    self.config.max_seq_len,
                    causal=self.is_causal,
                    alibi=self.alibi,
                    alibi_bias_max=self.alibi_bias_max,
                )
                # Keep only this rank's heads (dim 1 is assumed to be heads).
                assert self.n_heads % self.world_size == 0
                block_size = self.n_heads // self.world_size
                self.attn_bias = self.attn_bias[
                    :, self.rank * block_size : (self.rank + 1) * block_size
                ]
            self._attn_bias_initialized = True
        if self.attn_impl == "flash":
            return (self.attn_bias, attention_mask)
        if self.attn_bias is not None:
            self.attn_bias = self.attn_bias.to(dtype=dtype, device=device)
        attn_bias = self.attn_bias
        if self.prefix_lm:
            assert isinstance(attn_bias, torch.Tensor)
            assert isinstance(prefix_mask, torch.Tensor)
            attn_bias = self._apply_prefix_mask(attn_bias, prefix_mask)
        if self.attn_uses_sequence_id and sequence_id is not None:
            assert isinstance(attn_bias, torch.Tensor)
            attn_bias = self._apply_sequence_id(attn_bias, sequence_id)
        if attention_mask is not None:
            s_k = attention_mask.shape[-1]
            if attn_bias is None:
                attn_bias = torch.zeros((1, 1, 1, s_k), device=device, dtype=dtype)
            else:
                # Keep the last s_k key positions of the bias.
                _s_k = max(0, attn_bias.size(-1) - s_k)
                attn_bias = attn_bias[:, :, :, _s_k:]
            if prefix_mask is not None and attention_mask.shape != prefix_mask.shape:
                raise ValueError(
                    f"attention_mask shape={attention_mask.shape} "
                    + f"and prefix_mask shape={prefix_mask.shape} are not equal."
                )
            min_val = torch.finfo(attn_bias.dtype).min
            attn_bias = attn_bias.masked_fill(
                ~attention_mask.view(-1, 1, 1, s_k), min_val
            )
        return (attn_bias, None)

    def _apply_prefix_mask(self, attn_bias: torch.Tensor, prefix_mask: torch.Tensor):
        """Mask ``attn_bias`` so each query sees keys at or before it plus any
        key flagged in ``prefix_mask``.

        Raises:
            ValueError: if the last two dims of ``attn_bias`` are not both
                ``max_seq_len``, or ``prefix_mask`` is longer than ``max_seq_len``.
        """
        (s_q, s_k) = attn_bias.shape[-2:]
        if s_k != self.config.max_seq_len or s_q != self.config.max_seq_len:
            raise ValueError(
                "attn_bias does not match the expected shape. "
                + f"The last two dimensions should both be {self.config.max_seq_len} "
                + f"but are {s_q} and {s_k}."
            )
        seq_len = prefix_mask.shape[-1]
        if seq_len > self.config.max_seq_len:
            raise ValueError(
                f"prefix_mask sequence length cannot exceed max_seq_len={self.config.max_seq_len}"
            )
        attn_bias = attn_bias[..., :seq_len, :seq_len]
        causal = torch.tril(
            torch.ones((seq_len, seq_len), dtype=torch.bool, device=prefix_mask.device)
        ).view(1, 1, seq_len, seq_len)
        prefix = prefix_mask.view(-1, 1, 1, seq_len)
        cannot_attend = ~torch.logical_or(causal, prefix.bool())
        min_val = torch.finfo(attn_bias.dtype).min
        attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
        return attn_bias

    def _apply_sequence_id(
        self, attn_bias: torch.Tensor, sequence_id: torch.LongTensor
    ):
        """Mask attention between positions whose ``sequence_id`` differs.

        Raises:
            ValueError: if ``sequence_id`` is longer than ``max_seq_len``.
        """
        seq_len = sequence_id.shape[-1]
        if seq_len > self.config.max_seq_len:
            raise ValueError(
                f"sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}"
            )
        attn_bias = attn_bias[..., :seq_len, :seq_len]
        cannot_attend = torch.logical_not(
            torch.eq(sequence_id.view(-1, seq_len, 1), sequence_id.view(-1, 1, seq_len))
        ).unsqueeze(1)
        min_val = torch.finfo(attn_bias.dtype).min
        attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
        return attn_bias

    def forward(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
        attention_mask: Optional[torch.ByteTensor] = None,
        prefix_mask: Optional[torch.ByteTensor] = None,
        sequence_id: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        use_cache: Optional[bool] = None,
    ):
        """Run the embeddings, all blocks and the final norm.

        Returns:
            ``BaseModelOutputWithPast`` with the normed last hidden state,
            updated ``past_key_values`` and the optional per-layer hidden
            states / attention weights.

        Raises:
            NotImplementedError: for ``return_dict=False``, ``output_attentions``
                with a non-"torch" attention impl, or left padding in training.
            ValueError: for a missing ``prefix_mask``/``sequence_id`` when
                required, a wrong number of ``past_key_values``, or a total
                (past + current) length above ``max_seq_len``.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.return_dict
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        if attention_mask is not None:
            attention_mask = attention_mask.bool()
        if prefix_mask is not None:
            prefix_mask = prefix_mask.bool()
        if not return_dict:
            raise NotImplementedError(
                "return_dict False is not implemented yet for MPT"
            )
        if output_attentions:
            if self.attn_impl != "torch":
                raise NotImplementedError(
                    "output_attentions is not implemented for MPT when using attn_impl `flash` or `triton`."
                )
        # A row whose first position is masked out indicates left padding.
        if (
            attention_mask is not None
            and attention_mask[:, 0].sum() != attention_mask.shape[0]
            and self.training
        ):
            raise NotImplementedError(
                "MPT does not support training with left padding."
            )
        if self.prefix_lm and prefix_mask is None:
            raise ValueError(
                "prefix_mask is a required argument when MPT is configured with prefix_lm=True."
            )
        if self.training:
            if self.attn_uses_sequence_id and sequence_id is None:
                raise ValueError(
                    "sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True "
                    + "and the model is in train mode."
                )
            elif self.attn_uses_sequence_id is False and sequence_id is not None:
                warnings.warn(
                    "MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. "
                    + "This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True."
                )
        S = input_ids.size(1)
        assert (
            S <= self.config.max_seq_len
        ), f"Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}"
        tok_emb = self.wte(input_ids)
        if self.alibi:
            x = tok_emb
        else:
            past_position = 0
            if past_key_values is not None:
                if len(past_key_values) != self.config.n_layers:
                    raise ValueError(
                        "past_key_values must provide a past_key_value for each attention "
                        + f"layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r})."
                    )
                # Cached length is read from dim 1 of the first cached tensor,
                # or dim 3 for the "torch" attention implementation.
                past_position = past_key_values[0][0].size(1)
                if self.attn_impl == "torch":
                    past_position = past_key_values[0][0].size(3)
            if S + past_position > self.config.max_seq_len:
                raise ValueError(
                    f"Cannot forward input with past sequence length {past_position} and current sequence length {S}, this model only supports total sequence length <= {self.config.max_seq_len}."
                )
            pos = torch.arange(
                past_position,
                S + past_position,
                dtype=torch.long,
                device=input_ids.device,
            ).unsqueeze(0)
            if attention_mask is not None:
                # Shift positions back by the number of masked (padding) slots
                # seen so far, clamping at zero.
                pos = torch.clamp(
                    pos
                    - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[
                        :, past_position:
                    ],
                    min=0,
                )
            pos_emb = self.wpe(pos)
            x = tok_emb + pos_emb
        (attn_bias, attention_mask) = self._attn_bias(
            device=x.device,
            dtype=torch.float32,
            attention_mask=attention_mask,
            prefix_mask=prefix_mask,
            sequence_id=sequence_id,
        )
        if use_cache and past_key_values is None:
            past_key_values = [() for _ in range(self.config.n_layers)]
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        for b_idx, block in enumerate(self.blocks):
            if output_hidden_states:
                assert all_hidden_states is not None
                all_hidden_states = all_hidden_states + (x,)
            past_key_value = (
                past_key_values[b_idx] if past_key_values is not None else None
            )
            (x, attn_weights, past_key_value) = block(
                x,
                past_key_value=past_key_value,
                attn_bias=attn_bias,
                attention_mask=attention_mask,
                is_causal=self.is_causal,
            )
            if past_key_values is not None:
                past_key_values[b_idx] = past_key_value
            if output_attentions:
                assert all_self_attns is not None
                all_self_attns = all_self_attns + (attn_weights,)
        x = self.norm_f(x)
        if output_hidden_states:
            assert all_hidden_states is not None
            all_hidden_states = all_hidden_states + (x,)
        return BaseModelOutputWithPast(
            last_hidden_state=x,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )


class MPTForCausalLM(MPTPreTrainedModel):
    """MPT language model: an ``MPTModel`` backbone plus a tied LM head.

    The LM head is loaded from the token-embedding weights (``<prefix>.wte``),
    so configurations without ``tie_word_embeddings`` are rejected.
    """

    def __init__(self, prefix: str, config, weights):
        super().__init__(config)

        prefix = f"{prefix}.transformer" if prefix else "transformer"

        if not config.tie_word_embeddings:
            raise ValueError("MPTForCausalLM only supports tied word embeddings")
        self.transformer = MPTModel(prefix, config, weights)
        self.lm_head = SpeculativeHead.load(
            config, prefix=f"{prefix}.wte", weights=weights
        )
        self.logit_scale = self._resolve_logit_scale(config)

    @staticmethod
    def _resolve_logit_scale(config):
        """Convert ``config.logit_scale`` into a multiplier, or ``None`` if unset.

        The string ``"inv_sqrt_d_model"`` becomes ``1 / sqrt(config.d_model)``.
        Any other string raises ``ValueError``; non-strings are returned as-is.
        """
        scale = config.logit_scale
        if scale is None or not isinstance(scale, str):
            return scale
        if scale != "inv_sqrt_d_model":
            raise ValueError(
                f"logit_scale={scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'."
            )
        return 1 / math.sqrt(config.d_model)

    def forward(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
        attention_mask: Optional[torch.ByteTensor] = None,
        prefix_mask: Optional[torch.ByteTensor] = None,
        sequence_id: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        use_cache: Optional[bool] = None,
    ):
        """Run the backbone and LM head.

        Returns:
            A 2-tuple ``(CausalLMOutputWithPast, speculative_logits)``. When
            ``labels`` is given, ``loss`` is the cross-entropy of each
            position's logits against the following position's label; the
            last position is excluded by setting its target to ``-100``.
        """
        if return_dict is None:
            return_dict = self.config.return_dict
        if use_cache is None:
            use_cache = self.config.use_cache
        outputs = self.transformer(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            prefix_mask=prefix_mask,
            sequence_id=sequence_id,
            return_dict=return_dict,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            use_cache=use_cache,
        )
        logits, speculative_logits = self.lm_head(outputs.last_hidden_state)
        scale = self.logit_scale
        if scale is not None:
            if scale == 0:
                warnings.warn(
                    f"Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs."
                )
            logits *= scale
        loss = None
        if labels is not None:
            # Shift targets left by one so position t is scored against token
            # t + 1; the wrapped-around last column is set to -100, which is
            # F.cross_entropy's default ignore_index.
            targets = torch.roll(labels, shifts=-1)
            targets[:, -1] = -100
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)), targets.to(logits.device).view(-1)
            )
        lm_output = CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
        return lm_output, speculative_logits

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
    ):
        """Assemble the ``forward`` keyword arguments for one generation step.

        Requires ``kwargs["attention_mask"]`` and rejects right padding. When a
        cache is present, only the last token of ``input_ids`` is fed.
        """
        if inputs_embeds is not None:
            raise NotImplementedError("inputs_embeds is not implemented for MPT yet")
        attention_mask = kwargs["attention_mask"].bool()
        last_column_all_valid = (
            attention_mask[:, -1].sum() == attention_mask.shape[0]
        )
        if not last_column_all_valid:
            raise NotImplementedError(
                "MPT does not support generation with right padding."
            )
        wants_sequence_id = self.transformer.attn_uses_sequence_id and self.training
        sequence_id = torch.zeros_like(input_ids[:1]) if wants_sequence_id else None
        if past_key_values is not None:
            input_ids = input_ids[:, -1].unsqueeze(-1)
        prefix_mask = None
        if self.transformer.prefix_lm:
            prefix_mask = torch.ones_like(attention_mask)
            if kwargs.get("use_cache") is False:
                raise NotImplementedError(
                    "MPT with prefix_lm=True does not support use_cache=False."
                )
        return dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            prefix_mask=prefix_mask,
            sequence_id=sequence_id,
            past_key_values=past_key_values,
            use_cache=kwargs.get("use_cache", True),
        )

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Used by HuggingFace generate when using beam search with kv-caching.

        Selects, for every cached tensor of every layer, the batch rows given
        by ``beam_idx`` (along dim 0). Returns a list of per-layer tuples.

        See https://github.com/huggingface/transformers/blob/3ec7a47664ebe40c40f4b722f6bb1cd30c3821ec/src/transformers/models/gpt2/modeling_gpt2.py#L1122-L1133
        for an example in transformers.
        """
        return [
            tuple(state.index_select(0, beam_idx) for state in layer_past)
            for layer_past in past_key_values
        ]