from typing import TYPE_CHECKING, Optional, List

import torch
import torch.distributed
from torch import nn
from torch.distributed import ProcessGroup

from text_generation_server.utils.sgmv import (
    add_lora_a_bgmv,
    add_lora_b_bgmv,
    has_sgmv,
    lora_a_sgmv_cutlass,
    lora_b_sgmv_cutlass,
    orient_for_rank,
)

if TYPE_CHECKING:
    from text_generation_server.adapters import AdapterBatchData
    from text_generation_server.adapters.lora import BatchLoraWeights


class LoraLinear(nn.Module):
    """Base wrapper that adds LoRA adapter deltas to a linear layer's output.

    The class stores the wrapped ``base_layer``, the index of the layer inside
    the stacked adapter weights (``layer_id``) and the tensor-parallel
    ``process_group``. Subclasses supply :meth:`collect_lora_a`, which decides
    how the intermediate ``X @ A`` activations are combined across ranks.
    """

    def __init__(
        self, base_layer: nn.Module, layer_id: int, process_group: ProcessGroup
    ):
        super().__init__()
        self.base_layer = base_layer
        self.layer_id = layer_id
        self.process_group = process_group

    def _gather_if_sharded(self, activations: torch.Tensor) -> torch.Tensor:
        # Cross-rank collection is only needed when more than one rank takes part.
        if self.process_group.size() > 1:
            return self.collect_lora_a(activations)
        return activations

    def forward_layer_type(
        self,
        result: torch.Tensor,
        input: torch.Tensor,
        adapter_data: "AdapterBatchData",
        layer_type: str,
        start_idx: int,
        end_idx: int,
    ) -> torch.Tensor:
        """Add the LoRA contribution for ``layer_type`` into ``result``.

        Args:
            result: Output of the base layer; columns ``start_idx:end_idx`` are
                updated in place.
            input: Activations that were fed to the base layer.
            adapter_data: Batch adapter state. When ``None`` nothing is applied.
            layer_type: Key used to look up the weights in ``adapter_data.data``.
            start_idx: First column of ``result`` owned by this layer type.
            end_idx: One past the last column owned by this layer type.

        Returns:
            The same ``result`` tensor that was passed in.

        Raises:
            ValueError: If a rank segment has no ``lora_a_ptr`` or ``lora_b_ptr``.
        """
        if adapter_data is None:
            return result

        weights: Optional["BatchLoraWeights"] = adapter_data.data.get(layer_type)
        vectorized = (
            has_sgmv()
            and weights is not None
            and weights.can_vectorize(self.process_group)
        )

        if not vectorized:
            # Fallback path: one masked matmul pair per adapter in the batch.
            for adapter_index in adapter_data.meta.adapter_set:
                if weights is None or not weights.has_adapter(adapter_index):
                    continue
                # Column vector that is 1 for rows using this adapter, 0 otherwise.
                row_mask = (
                    (adapter_data.meta.adapter_indices == adapter_index)
                    .to(input.dtype)
                    .view(-1, 1)
                )
                delta = self.forward_lora(input, weights, adapter_index, row_mask)
                result[:, start_idx:end_idx] += delta
            return result

        # Kernel path. In tensor-parallel setups each GPU handles one segment of
        # the output, delimited by start_idx/end_idx. 'result' may be wider than
        # that segment, for example when:
        #   1. the adapter's size doesn't evenly divide across GPUs,
        #   2. the last segment is smaller than the others,
        #   3. fused projections (q, k, v) have different sizes.
        # In those cases the kernels write into a zero-initialised scratch
        # tensor of the segment's width, which is added to the slice afterwards.
        # Otherwise they write straight into 'result'. The kernels' return
        # values are unused, so they are expected to accumulate into 'proj' in
        # place.
        needs_scratch = (end_idx - start_idx) != result.shape[1]
        proj = (
            torch.zeros_like(result[:, start_idx:end_idx]) if needs_scratch else result
        )

        for rank, segments in weights.rank_data.items():
            a_ptr = segments.lora_a_ptr
            b_ptr = segments.lora_b_ptr
            if a_ptr is None or b_ptr is None:
                raise ValueError("LoRA data is missing")

            if weights.use_sgmv:
                # Use SGMV for prefill
                shrunk = lora_a_sgmv_cutlass(
                    input,
                    segments.tmp_shrink,
                    a_ptr,
                    segments.segment_starts,
                    segments.segment_ends,
                    self.layer_id,
                    rank,
                )
                shrunk = self._gather_if_sharded(shrunk)
                lora_b_sgmv_cutlass(
                    proj,
                    shrunk,
                    segments.tmp_expand,
                    b_ptr,
                    segments.segment_starts,
                    segments.segment_ends,
                    self.layer_id,
                )
                continue

            # Use BGMV for decode
            shrunk = torch.zeros(
                (input.size(0), rank), dtype=input.dtype, device=input.device
            )
            # TODO: error with [-1, 0], but not [0, -1]
            add_lora_a_bgmv(shrunk, input, a_ptr, segments.indices, self.layer_id)
            shrunk = self._gather_if_sharded(shrunk)
            add_lora_b_bgmv(proj, shrunk, b_ptr, segments.indices, self.layer_id)

        if needs_scratch:
            result[:, start_idx:end_idx] += proj
        return result

    def forward_lora(
        self,
        input: torch.Tensor,
        data: "BatchLoraWeights",
        adapter_index: int,
        adapter_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the masked LoRA delta ``((input @ A) @ B) * adapter_mask``.

        ``A`` and ``B`` are this layer's slices of the weights stored for
        ``adapter_index``. ``A`` is first passed through ``orient_for_rank``
        together with ``B.size(0)`` (presumably so its layout matches the
        adapter rank -- confirm against the helper).
        """
        layer = self.layer_id
        a_weight = data.lora_a[adapter_index][layer, :, :]
        b_weight = data.lora_b[adapter_index][layer, :, :]
        a_weight = orient_for_rank(a_weight, b_weight.size(0))

        shrunk = self._gather_if_sharded(input @ a_weight)
        return (shrunk @ b_weight) * adapter_mask

    def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
        """Combine ``X @ A`` activations across ranks; provided by subclasses."""
        raise NotImplementedError("Implemented in subclasses")


class TensorParallelMultiAdapterLinear(LoraLinear):
    """LoRA wrapper for a base layer whose output holds several adapter layer types.

    ``layer_names`` lists the adapter layer types applied to the base layer's
    output, in column order. ``sizes`` optionally gives the width of each one
    before division by the process-group size; it is used to derive the column
    range each layer type owns on this rank.
    """

    def __init__(
        self,
        base_layer: nn.Module,
        layer_id: int,
        layer_names: Optional[List[str]],
        sizes: Optional[List[int]],
        process_group: ProcessGroup,
    ):
        super().__init__(base_layer, layer_id, process_group)
        self.layer_names = layer_names
        self.sizes = sizes

    @classmethod
    def load(
        cls,
        base_layer: nn.Module,
        layer_id: int,
        layer_names: Optional[List[str]],
        sizes: Optional[List[int]],
        process_group: ProcessGroup,
    ):
        """Alternate constructor; builds an instance of ``cls``."""
        # Use `cls` rather than the concrete class name so subclasses that
        # inherit `load` get an instance of their own type.
        return cls(base_layer, layer_id, layer_names, sizes, process_group)

    def forward(
        self, input: torch.Tensor, adapter_data: "AdapterBatchData"
    ) -> torch.Tensor:
        """Run the base layer, then add the LoRA delta of every configured layer type.

        Args:
            input: Activations, either 2-D or with extra leading dimensions.
            adapter_data: Batch adapter state forwarded to ``forward_layer_type``.

        Returns:
            The base layer output with adapter contributions added, in the same
            shape the base layer produced.
        """
        result = self.base_layer(input)

        # noop if no layer names are provided (e.g. for models without adapters)
        if self.layer_names is None:
            return result

        # handle models like Bloom that have inputs of shape
        # (batch_size, sequence_length, hidden_size)
        # we need to reshape them to (batch_size * sequence_length, hidden_size)
        # for the LoRA computation, then reshape back
        prev_shape = result.shape
        is_3d = input.dim() >= 3
        if is_3d:
            input = input.reshape(-1, input.shape[-1])
            result = result.reshape(-1, result.shape[-1])

        # Loop-invariant: the group size does not change during a forward pass.
        world_size = self.process_group.size()

        offset = 0
        for i, layer_name in enumerate(self.layer_names):
            start_idx = offset // world_size
            # The 'sizes' parameter is essential in tensor-parallel setups for handling multiple
            # projection layers (q_proj, k_proj, v_proj) by defining their output dimensions. It
            # ensures correct slicing of the result tensor, accommodating variations like
            # grouped-query attention where k_proj and v_proj differ from q_proj.
            if self.sizes is not None:
                offset += self.sizes[i]
                end_idx = offset // world_size
            else:
                # Without explicit sizes the layer type spans every column.
                end_idx = result.shape[1]

            result = self.forward_layer_type(
                result, input, adapter_data, layer_name, start_idx, end_idx
            )

        if is_3d:
            result = result.reshape(prev_shape)

        return result

    def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
        """All-gather ``a_out`` from every rank and concatenate along dim 1."""
        # Tensor parallel implementation of X @ A@B, where A and B are sharded column-wise.
        # We use an all-gather between X@A and (X@A)@B to ensure alignment across ranks.
        #
        # TODO(travis): this is not very efficient as we do an all-gather for every adapter,
        #   instead we could pre-allocate a (B, a, r) tensor for all adapters with the same
        #   rank, compute `a_out` on each, and then slice them into the buffer as shown here:
        #   https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609
        gathered_tensors = [
            torch.empty_like(a_out) for _ in range(self.process_group.size())
        ]
        # The output list is sized from self.process_group, so the collective
        # must run on that same group rather than the default one.
        torch.distributed.all_gather(
            gathered_tensors, a_out, group=self.process_group
        )
        return torch.cat(gathered_tensors, dim=1)


class TensorParallelAdapterRowLinear(LoraLinear):
    """LoRA wrapper for a base layer that carries a single adapter layer type.

    Each rank adds the adapter contribution to its own column stripe of the
    base layer's output; the stripe is derived from the rank and size of
    ``process_group``.
    """

    def __init__(
        self,
        base_layer: nn.Module,
        layer_id: int,
        layer_name: Optional[str],
        process_group: ProcessGroup,
    ):
        super().__init__(base_layer, layer_id, process_group)
        self.layer_name = layer_name

    @classmethod
    def load(
        cls,
        base_layer: nn.Module,
        layer_id: int,
        layer_name: Optional[str],
        process_group: ProcessGroup,
    ):
        """Alternate constructor; builds an instance of ``cls``."""
        return cls(base_layer, layer_id, layer_name, process_group)

    def forward(
        self, input: torch.Tensor, adapter_data: "AdapterBatchData"
    ) -> torch.Tensor:
        """Run the base layer, then add this rank's LoRA delta to its column stripe.

        Args:
            input: Activations, either 2-D or with extra leading dimensions.
            adapter_data: Batch adapter state forwarded to ``forward_layer_type``.

        Returns:
            The base layer output with the adapter contribution added, in the
            same shape the base layer produced.
        """
        result = self.base_layer(input)

        # noop if no layer name is provided
        if self.layer_name is None:
            return result

        # forward_layer_type slices columns with `result[:, start:end]`, which
        # is only the feature axis for 2-D tensors. Flatten any leading
        # dimensions first and restore the shape afterwards. 2-D inputs take
        # the same path as before.
        prev_shape = result.shape
        is_3d = input.dim() >= 3
        if is_3d:
            input = input.reshape(-1, input.shape[-1])
            result = result.reshape(-1, result.shape[-1])

        # Fused all-gather + all-reduce from S-LoRA paper: https://arxiv.org/abs/2311.03285
        stride = result.shape[-1] // self.process_group.size()
        start_idx = self.process_group.rank() * stride
        end_idx = start_idx + stride

        result = self.forward_layer_type(
            result, input, adapter_data, self.layer_name, start_idx, end_idx
        )

        if is_3d:
            result = result.reshape(prev_shape)

        return result

    def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
        """All-reduce ``a_out`` in place across the process group and return it."""
        # Tensor parallel implementation of X @ A@B, where A and B are sharded row-wise.
        # We use an all-reduce between X@A and (X@A)@B to ensure alignment across ranks.
        #
        # TODO(travis): this is not very efficient as we do an all-reduce for every adapter,
        #   instead we could pre-allocate a (B, a, r) tensor for all adapters with the same
        #   rank, compute `a_out` on each, and then slice them into the buffer as shown here:
        #   https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609
        torch.distributed.all_reduce(a_out, group=self.process_group)
        return a_out