# Origin:   https://github.com/predibase/lorax
# Path:     lorax/server/lorax_server/adapters/lora.py
# License:  Apache License Version 2.0, January 2004

from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Optional, Set, Tuple, Type, Union

from loguru import logger
import torch
from peft import LoraConfig as _LoraConfig
from torch.distributed import ProcessGroup
from text_generation_server.utils.log import log_master

from text_generation_server.adapters.config import AdapterConfig, ModuleMap

from text_generation_server.adapters.weights import (
    AdapterBatchMetadata,
    AdapterWeights,
    BatchAdapterWeights,
)
from text_generation_server.utils.sgmv import (
    BGMV_MAX_RANK,
    MAX_RANK_CUSTOM,
    get_tmp_tensors,
    orient_for_rank,
    pad_rank,
    use_cutlass_shrink,
    has_sgmv,
)


def get_start_stop_idxs_for_rank(offset, size, rank, world_size):
    """Return the ``[start, stop)`` bounds of ``rank``'s block.

    ``size`` elements starting at ``offset`` are divided into ``world_size``
    equal blocks of ``size // world_size`` elements; ``rank`` selects one.
    Because of the integer division, any trailing ``size % world_size``
    elements belong to no rank.
    """
    chunk = size // world_size
    begin = offset + chunk * rank
    return begin, begin + chunk


def shard_on_dim(
    t: torch.Tensor, dim: int, process_group: torch.distributed.ProcessGroup
):
    """Return the slice of ``t`` along ``dim`` owned by this process.

    The slice bounds come from ``get_start_stop_idxs_for_rank`` using the
    group's rank and size. Only ``dim`` 0 and 1 are supported; any other
    value raises ``NotImplementedError``.
    """
    n_shards = process_group.size()
    shard_id = process_group.rank()

    extent = t.shape[dim]
    lo, hi = get_start_stop_idxs_for_rank(0, extent, shard_id, n_shards)

    if dim == 0:
        return t[lo:hi]
    if dim == 1:
        return t[:, lo:hi]
    raise NotImplementedError("Let's make that generic when needed")


def shard_lora_weights(
    weights_a: List[torch.Tensor],
    weights_b: List[torch.Tensor],
    split_dim: int,
    process_group: ProcessGroup,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Shard per-layer LoRA A and B matrices over ``process_group``.

    Every A matrix is split along ``split_dim``; every B matrix is always
    split along dim 1. When called from ``LoraWeights.prepare_weights`` in
    this file, A is ``[in_features, r]`` and B is ``[r, out_features]``.

    Returns:
        ``(sharded_a, sharded_b)``, lists in the same order as the inputs.
    """

    def _shard_all(tensors, axis):
        # Shard each per-layer tensor independently along ``axis``.
        return [shard_on_dim(w, dim=axis, process_group=process_group) for w in tensors]

    return _shard_all(weights_a, split_dim), _shard_all(weights_b, 1)


@dataclass
class LoraConfig(AdapterConfig):
    """Configuration of a PEFT LoRA adapter.

    The fields are copied from ``peft.LoraConfig`` by :meth:`load`.
    """

    # LoRA rank; note that LoraWeights.prepare_weights may overwrite it with a padded rank.
    r: int
    target_modules: Optional[Union[List[str], str]]
    fan_in_fan_out: bool
    lora_alpha: int
    # Whether the scaling factor uses sqrt(r) (rsLoRA) instead of r.
    use_rslora: bool

    def map_weights_for_model(
        self,
        adapter_weights: Dict[int, AdapterWeights],
        weight_names: Tuple[str],
    ) -> Tuple[ModuleMap, Set[str]]:
        """Match the model's weight names to the adapter's lora_A/lora_B tensors.

        A weight name is kept only when BOTH its ``lora_A`` and ``lora_B``
        entries exist in ``adapter_weights``.

        Returns:
            ``(module_map, used_names)``. ``module_map`` maps a weight name to
            ``{"lora_A": (tensor, key), "lora_B": (tensor, key)}``.
            ``used_names`` is the set of adapter keys referenced by the map.
        """
        module_map = {}
        used_names = set()
        for weight_name in weight_names:
            a_key = f"base_model.model.{weight_name}.lora_A.weight"
            b_key = f"base_model.model.{weight_name}.lora_B.weight"
            has_pair = a_key in adapter_weights and b_key in adapter_weights
            if has_pair:
                module_map[weight_name] = {
                    "lora_A": (adapter_weights[a_key], a_key),
                    "lora_B": (adapter_weights[b_key], b_key),
                }
                used_names.update((a_key, b_key))
        return module_map, used_names

    @classmethod
    def load(cls, adapter_id: str, api_token: str) -> "LoraConfig":
        """Read the adapter's PEFT config and convert it to a LoraConfig."""
        peft_config = _LoraConfig.from_pretrained(adapter_id, token=api_token)
        return cls(
            base_model_name_or_path=peft_config.base_model_name_or_path,
            r=peft_config.r,
            target_modules=peft_config.target_modules,
            fan_in_fan_out=peft_config.fan_in_fan_out,
            lora_alpha=peft_config.lora_alpha,
            # The attribute may be absent (presumably on older peft versions).
            use_rslora=getattr(peft_config, "use_rslora", False),
        )


class LoraWeights(AdapterWeights):
    """LoRA weights for a single adapter merged across all layers."""

    def __init__(
        self,
        weights_a: List[torch.Tensor],
        weights_b: List[torch.Tensor],
        adapter_config: LoraConfig,
    ):
        """Stack one adapter's per-layer A and B matrices.

        Args:
            weights_a: one A matrix per layer. ``size(1)`` is read as the rank,
                so entries are expected to be ``[in_features, r]``.
            weights_b: one B matrix per layer. ``size(0)`` is read as the rank,
                so entries are expected to be ``[r, out_features]``.
            adapter_config: configuration of the adapter these weights belong to.
        """
        # Each rank falls back to 1 when *its own* list is empty.
        # NOTE(review): torch.stack below is expected to reject an empty list,
        # so this fallback may be unreachable in practice — confirm.
        self.lora_a_r = weights_a[0].size(1) if len(weights_a) > 0 else 1
        self.lora_b_r = weights_b[0].size(0) if len(weights_b) > 0 else 1

        self._use_cutlass_shrink = use_cutlass_shrink(self.lora_a_r)
        # Current orientation of the stored tensors; the public properties
        # below call _transpose_weights lazily to switch orientation.
        self._is_transposed = False

        # [num_layers, hidden_size, r] before orient_for_rank, which may change
        # the layout depending on the rank.
        weights_a = [orient_for_rank(w, w.size(1)).contiguous() for w in weights_a]
        self._weights_a = torch.stack(weights_a)

        # [num_layers, r, hidden_size]
        self._weights_b = torch.stack(weights_b)

        self.adapter_config = adapter_config

    @property
    def weights_a(self) -> torch.Tensor:
        """Stacked A weights in the untransposed orientation (transposes back if needed)."""
        if self._is_transposed:
            self._transpose_weights()
        return self._weights_a

    @property
    def weights_b(self) -> torch.Tensor:
        """Stacked B weights in the untransposed orientation (transposes back if needed)."""
        if self._is_transposed:
            self._transpose_weights()
        return self._weights_b

    @property
    def weights_a_t(self) -> torch.Tensor:
        """Stacked A weights in the transposed orientation (transposes if needed)."""
        if not self._is_transposed:
            self._transpose_weights()
        return self._weights_a

    @property
    def weights_b_t(self) -> torch.Tensor:
        """Stacked B weights in the transposed orientation (transposes if needed)."""
        if not self._is_transposed:
            self._transpose_weights()
        return self._weights_b

    def _transpose_weights(self):
        """Swap the last two dims of the stored weights and flip the orientation flag."""
        if self._use_cutlass_shrink:
            # If we're not using the cutlass shrink, then both SGMV and BGMV use the same orientation
            self._weights_a = self._weights_a.transpose(1, 2).contiguous()
        self._weights_b = self._weights_b.transpose(1, 2).contiguous()
        self._is_transposed = not self._is_transposed

    @classmethod
    def get_batch_types(cls) -> List[Type[BatchAdapterWeights]]:
        """Batch weight types that can be built from these weights."""
        return [BatchLoraWeights]

    @classmethod
    def prepare_weights(
        cls,
        config: LoraConfig,
        module_map: Dict[str, Dict],
        layer_type: str,
        unused_weight_names: Set[str],
        nlayers: int,
        dtype: torch.dtype,
        world_size: int,
        process_group: ProcessGroup,
        target_to_layer: Dict[str, Tuple[str, torch.Tensor]],
    ) -> Optional[AdapterWeights]:
        """Build LoraWeights for one layer type across layers ``0..nlayers-1``.

        For each layer this method:
          * looks up ``(layer_id, layer_type)`` in ``target_to_layer``;
          * fetches the matching ``lora_A``/``lora_B`` tensors from ``module_map``;
          * moves them to the base layer's device and casts them to ``dtype``;
          * folds the LoRA scaling factor into B.
        It then pads the ranks with ``pad_rank`` and shards the result over
        ``process_group``.

        Args:
            config: adapter config. ``lora_alpha`` and ``use_rslora`` drive the
                scaling factor. ``config.r`` is overwritten with the padded rank.
            module_map: ``{weight_name: {"lora_A": (tensor, name),
                "lora_B": (tensor, name)}}`` (presumably built by
                ``LoraConfig.map_weights_for_model``).
            layer_type: module name to prepare, e.g. ``"q_proj"``.
            unused_weight_names: names consumed here are discarded from this set.
            nlayers: number of layers to process.
            dtype: dtype the LoRA weights are cast to.
            world_size: forwarded to ``pad_rank``. Sharding itself uses
                ``process_group``.
            process_group: group the weights are sharded over.
            target_to_layer: maps ``(layer_id, layer_type)`` tuples to
                ``(weight_name, layer)``. ``layer.base_layer.linear.weight``
                supplies the target device.

        Returns:
            The prepared LoraWeights. Returns ``None`` when some layer id has no
            such layer in the base model, or the adapter has no weights for it.
        """
        lora_a_list = [None] * nlayers
        lora_b_list = [None] * nlayers

        for layer_id in range(nlayers):
            key = (layer_id, layer_type)
            if key not in target_to_layer:
                # There is no layer of this type in the model
                log_master(
                    logger.warning,
                    f"Key specified in lora weights but not found in base model: {key}",
                )
                return None

            weight_name, layer = target_to_layer[key]
            base_weight = layer.base_layer.linear.weight
            base_device = base_weight.device

            if weight_name not in module_map:
                # There is no LoRA weight for this layer type in the adapter
                return None

            lora_a, lora_a_name = module_map[weight_name]["lora_A"]
            lora_a = lora_a.to(base_device, dtype)

            lora_b, lora_b_name = module_map[weight_name]["lora_B"]
            lora_b = lora_b.to(base_device, dtype)

            # Scale by the rank the adapter was trained with, read from the
            # unpadded A matrix. A is [r, in_features] here: it is transposed
            # below and then padded along its rank dim.
            #
            # config.r must NOT be used. This method overwrites it with the
            # padded rank, so any later call sharing this config (e.g. another
            # layer type) would be scaled by alpha / padded_r instead of alpha / r.
            scale = get_scaling_factor(
                config.lora_alpha,
                lora_a.size(0),
                uses_rslora=config.use_rslora,
            )

            unused_weight_names.discard(lora_a_name)
            unused_weight_names.discard(lora_b_name)

            # Merge scaling factor into lora_b due to associativity of matrix multiplication:
            # (A * B) * C = A * (B * C)
            lora_a_list[layer_id] = lora_a.transpose(0, 1)
            lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale

        # pad lora ranks to be compatible with sgmv
        lora_a_list = [pad_rank(w, dim=1, world_size=world_size) for w in lora_a_list]
        lora_b_list = [pad_rank(w, dim=0, world_size=world_size) for w in lora_b_list]

        if lora_a_list:
            # update rank if it was padded
            padded_rank = lora_a_list[0].size(1)
            config.r = padded_rank

        return LoraWeights(
            *shard_lora_weights(
                weights_a=lora_a_list,
                weights_b=lora_b_list,
                # These layer types split A along dim 0 (its input features after
                # the transpose above); all others along dim 1 (the rank).
                split_dim=0 if layer_type in {"o_proj", "down_proj", "lm_head"} else 1,
                process_group=process_group,
            ),
            config,
        )


@dataclass
class RankSegments:
    """Kernel inputs for all batch segments whose adapters share one LoRA rank.

    Only one group of optional fields is populated per instance:
    * sgmv (prefill) path: the ``tmp_*`` and ``segment_*`` fields are set and
      ``indices`` is ``None``;
    * otherwise (bgmv / decode): ``indices`` is set and the sgmv fields are ``None``.
    """

    # LoRA rank shared by every segment in this group.
    rank: int

    # int64 data pointers to each segment's stacked A / B weights (0 = no adapter).
    lora_a_ptr: torch.Tensor
    lora_b_ptr: torch.Tensor

    # prefill (sgmv): scratch buffers and per-segment [start, end) bounds.
    tmp_shrink: Optional[torch.Tensor]
    tmp_expand: Optional[torch.Tensor]
    segment_starts: Optional[torch.Tensor]
    segment_ends: Optional[torch.Tensor]

    # decode (bgmv): per-token segment index, or -1 for tokens whose segment
    # has a different rank.
    indices: Optional[torch.Tensor]


@dataclass
class BatchLoraWeights(BatchAdapterWeights):
    """LoRA weights of every adapter in one batch, plus kernel dispatch metadata."""

    # adapter index -> that adapter's stacked A weights (LoraWeights.weights_a)
    lora_a: Dict[int, torch.Tensor]
    # adapter index -> that adapter's stacked B weights (LoraWeights.weights_b)
    lora_b: Dict[int, torch.Tensor]
    # adapter index -> that adapter's LoraConfig
    adapter_index_configs: Dict[int, LoraConfig]
    # LoRA rank -> pointers and segment/index data for the segments of that rank
    rank_data: Dict[int, RankSegments]
    # True when the sgmv (segmented) path was selected for this batch
    use_sgmv: bool

    def has_adapter(self, adapter_index: int) -> bool:
        """Return whether ``adapter_index`` has LoRA weights in this batch."""
        return adapter_index in self.adapter_index_configs

    def can_vectorize(self, pg: ProcessGroup) -> bool:
        """Return True when ``rank // pg.size()`` <= MAX_RANK_CUSTOM for every rank group."""
        return all(
            segments.rank // pg.size() <= MAX_RANK_CUSTOM
            for segments in self.rank_data.values()
        )

    @staticmethod
    def _weight_ptrs(
        adapter_weights: Dict[int, "LoraWeights"],
        segment_indices,
        attr: str,
        device: torch.device,
    ) -> torch.Tensor:
        """int64 tensor of ``getattr(w, attr).data_ptr()`` per segment (0 if no LoRA adapter)."""
        return torch.tensor(
            [
                (
                    getattr(adapter_weights[idx], attr).data_ptr()
                    if idx in adapter_weights
                    else 0
                )
                for idx in segment_indices
            ],
            dtype=torch.int64,
            device=device,
        )

    @staticmethod
    def _prefill_head_segments(
        prefill_head_indices: torch.Tensor, adapter_segments
    ) -> Tuple[List[int], List[int]]:
        """Re-express adapter segment bounds as positions within ``prefill_head_indices``.

        Returns ``(starts, ends)``. Entry ``i`` counts how many head indices
        fall before / up to the end of adapter segment ``i``.
        """
        j, starts, ends = 1, [0], [0]
        for head_index in prefill_head_indices:
            # j cannot go out of bounds as that would mean there are tokens without corresponding adapters
            if head_index < adapter_segments[j]:
                ends[-1] += 1
            else:
                starts.append(ends[-1])
                ends.append(ends[-1] + 1)
                j += 1
        return starts, ends

    @classmethod
    def load(
        cls,
        adapter_weights: Dict[int, AdapterWeights],
        meta: AdapterBatchMetadata,
        prefill: bool,
        prefill_head_indices: Optional[torch.Tensor],
    ) -> Optional["BatchLoraWeights"]:
        """Assemble the batch-level LoRA view for the adapters referenced by ``meta``.

        Args:
            adapter_weights: adapter index -> weights. Entries that are not a
                LoraWeights (directly or via a ``lora_weights`` attribute) are ignored.
            meta: batch metadata. ``segment_indices``, ``adapter_segments`` and
                ``adapter_indices`` are read.
            prefill: whether this is a prefill step. Prefill, or any rank above
                BGMV_MAX_RANK, selects the untransposed weight layout and, if
                ``has_sgmv()``, the sgmv path.
            prefill_head_indices: when given, sgmv segment bounds are remapped
                onto these head positions.

        Returns:
            ``None`` when no LoRA adapter is present; otherwise a BatchLoraWeights.
        """
        adapter_weights = {k: _convert_lora(v) for k, v in adapter_weights.items()}
        adapter_weights = {
            k: v for k, v in adapter_weights.items() if isinstance(v, LoraWeights)
        }
        if not adapter_weights:
            return None

        first_weights = next(iter(adapter_weights.values()))
        # Read the device from the stored tensor, not the weights_a property.
        # The property may transpose (and copy) the weights, and transposing
        # does not change the device.
        device = first_weights._weights_a.device
        segment_indices = meta.segment_indices

        # Adapter index of every segment that has LoRA weights, in segment order.
        active = [idx for idx in segment_indices if idx in adapter_weights]

        lora_a = {idx: adapter_weights[idx].weights_a for idx in active}
        lora_b = {idx: adapter_weights[idx].weights_b for idx in active}

        max_rank = max((adapter_weights[idx].lora_a_r for idx in active), default=0)

        use_sgmv = False
        if prefill or max_rank > BGMV_MAX_RANK:
            if has_sgmv():
                use_sgmv = True
            a_attr, b_attr = "weights_a", "weights_b"
        else:
            # The bgmv path reads the transposed layout.
            a_attr, b_attr = "weights_a_t", "weights_b_t"
        lora_a_ptr = cls._weight_ptrs(adapter_weights, segment_indices, a_attr, device)
        lora_b_ptr = cls._weight_ptrs(adapter_weights, segment_indices, b_attr, device)

        adapter_index_configs = {
            idx: adapter_weights[idx].adapter_config for idx in active
        }

        adapter_to_segment = {v: k for k, v in enumerate(segment_indices)}

        # LoRA rank -> indices of the segments whose adapter has that rank.
        rank_indices = defaultdict(list)
        for segment_idx, adapter_idx in enumerate(segment_indices):
            if adapter_idx not in adapter_weights:
                continue
            rank_indices[adapter_weights[adapter_idx].lora_a_r].append(segment_idx)

        if prefill_head_indices is not None:
            (
                prefill_head_segment_starts,
                prefill_head_segment_ends,
            ) = cls._prefill_head_segments(prefill_head_indices, meta.adapter_segments)

        # Segment index of every token. Only the non-sgmv path needs it, and it
        # does not depend on the rank, so it is computed once instead of per rank.
        token_segments = None
        if not use_sgmv and rank_indices:
            token_segments = [
                adapter_to_segment[idx] for idx in meta.adapter_indices.tolist()
            ]

        rank_data = {}
        for rank, indices in rank_indices.items():
            tmp_shrink = None
            tmp_expand = None
            segment_starts = None
            segment_ends = None
            batch_indices = None

            rank_a_ptr = lora_a_ptr[indices]
            rank_b_ptr = lora_b_ptr[indices]

            if use_sgmv:
                tmp_shrink, tmp_expand = get_tmp_tensors(rank_a_ptr.size(0), rank, device)
                segment_starts = meta.adapter_segments[indices]
                segment_ends = meta.adapter_segments[[i + 1 for i in indices]]
                if prefill_head_indices is not None:
                    for i, segment_index in enumerate(indices):
                        segment_starts[i] = prefill_head_segment_starts[segment_index]
                        segment_ends[i] = prefill_head_segment_ends[segment_index]
            else:
                # Tokens whose segment has a different rank are marked -1.
                segments_of_rank = set(indices)
                batch_indices = torch.tensor(
                    [seg if seg in segments_of_rank else -1 for seg in token_segments],
                    dtype=torch.int64,
                    device=device,
                )

            rank_data[rank] = RankSegments(
                rank=rank,
                tmp_shrink=tmp_shrink,
                tmp_expand=tmp_expand,
                lora_a_ptr=rank_a_ptr,
                lora_b_ptr=rank_b_ptr,
                segment_starts=segment_starts,
                segment_ends=segment_ends,
                indices=batch_indices,
            )

        return BatchLoraWeights(
            lora_a=lora_a,
            lora_b=lora_b,
            adapter_index_configs=adapter_index_configs,
            rank_data=rank_data,
            use_sgmv=use_sgmv,
        )


def get_scaling_factor(
    lora_alpha: int,
    r: int,
    uses_rslora: bool = False,
) -> float:
    """Return the LoRA scaling factor.

    Standard LoRA divides ``lora_alpha`` by the rank ``r``. With
    ``uses_rslora`` (rank-stabilized LoRA) the divisor is ``sqrt(r)`` instead.
    """
    divisor = r**0.5 if uses_rslora else r
    return lora_alpha / divisor


def _convert_lora(v: AdapterWeights) -> AdapterWeights:
    """Return ``v.lora_weights`` when ``v`` has that attribute, else ``v`` itself."""
    return getattr(v, "lora_weights", v)