# Origin:   https://github.com/predibase/lorax
# Path:     lorax/server/lorax_server/adapters/config.py
# License:  Apache License Version 2.0, January 2004

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple

import torch

from text_generation_server.adapters.weights import AdapterWeights

if TYPE_CHECKING:
    from text_generation_server.models.model import Model


@dataclass
class ModuleMap:
    module_name: str
    module_weights: Dict[str, Tuple[torch.Tensor, str]]


@dataclass
class AdapterConfig(ABC):
    base_model_name_or_path: str

    @abstractmethod
    def map_weights_for_model(
        self,
        adapter_weights: Dict[int, AdapterWeights],
        weight_names: Tuple[str],
    ) -> Tuple[ModuleMap, Set[str]]:
        pass

    @abstractmethod
    def load_batched_adapter_weights(
        self,
        model: "Model",
        module_map: ModuleMap,
        layer_type: str,
        unused_weight_names: Set[str],
        dynamic: bool,
    ) -> Optional[AdapterWeights]:
        pass