import inspect
import torch

from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, TypeVar, Type, Dict
from collections import defaultdict
from transformers import PreTrainedTokenizerBase
from loguru import logger

from text_generation_server.models.globals import (
    ATTENTION,
    PREFIX_CACHING,
    BLOCK_SIZE,
    PREFILL_CHUNKING,
)
from text_generation_server.models.types import Batch, Generation
from text_generation_server.utils.log import log_master
from text_generation_server.utils.prefill_chunking import set_support_chunking
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.pb.generate_pb2 import InfoResponse
from text_generation_server.adapters.weights import LayerAdapterWeights

# Default adapter id: it is the default of `Model.__init__`'s `adapter_id`
# parameter, which stores it as `static_adapter_id`. Presumably it denotes
# "no adapter / base weights only" — NOTE(review): confirm against the
# adapter-loading code.
BASE_MODEL_ADAPTER_ID = "__base_model__"


# Generic batch type used throughout the Model API (`batch_type`,
# `generate_token`, `warmup`); restricted to subclasses of `Batch`.
B = TypeVar("B", bound=Batch)


class Model(ABC):
    """Abstract base class for a servable text-generation model.

    Wraps a ``torch.nn.Module`` together with its tokenizer and records the
    runtime configuration (dtype, device, rank/world size, sliding window,
    speculation, prefill chunking, adapters) that is reported via :attr:`info`.
    Subclasses must implement :attr:`batch_type` and :meth:`generate_token`.
    """

    def __init__(
        self,
        model_id: str,
        model: torch.nn.Module,
        tokenizer: PreTrainedTokenizerBase,
        requires_padding: bool,
        dtype: torch.dtype,
        device: torch.device,
        rank: int = 0,
        world_size: int = 1,
        sliding_window: Optional[int] = None,
        speculate: Optional[int] = None,
        adapter_id: str = BASE_MODEL_ADAPTER_ID,
        support_chunking: bool = False,
    ):
        """Store the model configuration and validate the wrapped module.

        Args:
            model_id: Identifier of the model.
            model: The torch module to serve; it is switched to eval mode.
            tokenizer: Tokenizer used by :meth:`decode_token`.
            requires_padding: Whether batches need padding; reported in
                :attr:`info`.
            dtype: Model dtype; reported in :attr:`info`.
            device: Device the model runs on; its type is reported in
                :attr:`info`.
            rank: Rank of this process among ``world_size`` processes.
            world_size: Number of participating processes.
            sliding_window: Sliding-window size. ``-1`` is normalized to
                ``None`` (no window).
            speculate: Number of speculative tokens. ``None`` falls back to
                ``get_speculate()``.
            adapter_id: Adapter id, stored as ``static_adapter_id``.
            support_chunking: Whether the subclass can do prefill chunking.
                It is only actually enabled when ``PREFILL_CHUNKING`` is set,
                speculation is disabled and the attention backend is one of
                the supported ones.

        Raises:
            RuntimeError: if any model parameter is still on the ``meta``
                device (see :meth:`check_initialized`).
        """
        self.model_id = model_id
        self.model = model.eval()
        self.tokenizer = tokenizer

        # all_special_ids is not set correctly if the rust tokenizer is unpacked
        # TODO report this to transformers.
        # Work around it by also collecting every added token flagged special.
        other_special_ids = {
            token_id
            for token_id, token in tokenizer.added_tokens_decoder.items()
            if token.special
        }
        self.all_special_ids = set(tokenizer.all_special_ids)
        self.all_special_ids.update(other_special_ids)
        self.requires_padding = requires_padding
        self.dtype = dtype
        self.device = device
        self.rank = rank
        self.world_size = world_size
        # -1 is accepted as an alias for "no sliding window".
        self.sliding_window = sliding_window if sliding_window != -1 else None

        self.layer_to_adapter_weights: Dict[str, LayerAdapterWeights] = defaultdict(
            LayerAdapterWeights
        )
        self.loaded_adapters = set()
        self.static_adapter_id = adapter_id

        if speculate is None:
            speculate = get_speculate()
        self.speculate = speculate

        # Chunking requires both the subclass capability and the global switch.
        support_chunking = support_chunking and PREFILL_CHUNKING

        if speculate != 0 and support_chunking:
            log_master(
                logger.warning,
                "Prefill chunking does not support speculation yet. "
                "Prefill chunking will be turned off",
            )
            support_chunking = False
        if (
            ATTENTION not in ["flashinfer", "flashdecoding", "flashdecoding-ipex"]
            and support_chunking
        ):
            log_master(
                logger.warning,
                "Prefill chunking is only supported with `flashinfer` or `flashdecoding` or `flashdecoding-ipex` attention types.",
            )
            support_chunking = False

        log_master(logger.info, f"Using prefill chunking = {support_chunking}")

        self.support_chunking = support_chunking
        # Publish the final decision to the shared prefill_chunking utility.
        set_support_chunking(support_chunking)

        # Some model forwards take explicit position ids; record which kind this is.
        self.has_position_ids = (
            inspect.signature(model.forward).parameters.get("position_ids", None)
            is not None
        )

        self.check_initialized()

    @property
    def info(self) -> InfoResponse:
        """Describe this model's runtime configuration as an ``InfoResponse``.

        Raises:
            NotImplementedError: if the model requires padding and also has a
                sliding window, a combination that is not supported.
        """
        if self.requires_padding and self.sliding_window is not None:
            raise NotImplementedError("sliding_window is not implemented with padding")

        return InfoResponse(
            requires_padding=self.requires_padding,
            dtype=str(self.dtype),
            device_type=self.device.type,
            window_size=self.sliding_window,
            speculate=self.speculate,
            support_chunking=self.support_chunking,
            use_prefix_caching=PREFIX_CACHING,
            attention_impl=ATTENTION,
            block_size=BLOCK_SIZE,
        )

    @property
    @abstractmethod
    def batch_type(self) -> Type[B]:
        """The concrete ``Batch`` subclass this model consumes."""
        raise NotImplementedError

    @abstractmethod
    def generate_token(
        self, batch: B
    ) -> Tuple[List[Generation], Optional[B], Tuple[int, int]]:
        """Run one generation step on ``batch``.

        Returns:
            A tuple ``(generations, next_batch, pair)`` where ``next_batch`` may
            be ``None``. NOTE(review): the meaning of the trailing ``(int, int)``
            pair is defined by implementations (presumably timings) — confirm.
        """
        raise NotImplementedError

    def warmup(
        self, batch: B, max_input_tokens: Optional[int], max_total_tokens: Optional[int]
    ) -> Tuple[Optional[int], int, int]:
        """Run one generation step on ``batch`` and resolve the token limits.

        Args:
            batch: Warmup batch, passed once to :meth:`generate_token`.
            max_input_tokens: Maximum input tokens, or ``None`` to derive it
                as ``max_total_tokens - 1``.
            max_total_tokens: Maximum total tokens, or ``None`` to use the
                total number of input ids in ``batch``.

        Returns:
            ``(None, max_input_tokens, max_total_tokens)``. This base
            implementation always returns ``None`` as the first element;
            subclasses overriding ``warmup`` may report a value there.
        """
        self.generate_token(batch)
        if max_total_tokens is None:
            # Only measure the batch when the size is actually needed as default.
            max_total_tokens = sum(len(ids) for ids in batch.input_ids)

        if max_input_tokens is None:
            max_input_tokens = max_total_tokens - 1
        return None, max_input_tokens, max_total_tokens

    def decode_token(
        self,
        all_input_ids: List[int],
        prefix_offset: int = 0,
        read_offset: int = 0,
        skip_special_tokens: bool = False,
    ) -> Tuple[str, int, int]:
        """Incrementally decode ``all_input_ids`` for streaming output.

        Hack to hopefully support generate_stream for the maximum number of
        tokenizers: the ids in ``[prefix_offset, read_offset)`` are decoded
        as context, the ids from ``prefix_offset`` to the end are decoded
        too, and only the text past the context is returned.

        Args:
            all_input_ids: Every token id generated so far.
            prefix_offset: Start of the context window.
            read_offset: End of the already-emitted region.
            skip_special_tokens: Forwarded to ``tokenizer.decode``.

        Returns:
            ``(new_text, prefix_offset, read_offset)`` for the next call. When
            new text is available the offsets advance to
            ``(read_offset, len(all_input_ids))``; otherwise ``""`` is returned
            with the offsets unchanged so the caller retries with more tokens.
        """

        # The prefix text is necessary only to defeat cleanup algorithms in the decode
        # which decide to add a space or not depending on the surrounding ids.
        prefix_text = self.tokenizer.decode(
            all_input_ids[prefix_offset:read_offset],
            skip_special_tokens=skip_special_tokens,
        )
        new_text = self.tokenizer.decode(
            all_input_ids[prefix_offset:], skip_special_tokens=skip_special_tokens
        )

        if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
            # A U+FFFD replacement char at the end means it's a potential
            # unfinished byte sequence from byte fallback tokenization.
            # If it's in the middle, it's probably a real invalid id generated
            # by the model
            new_text = new_text[len(prefix_text) :]
            return new_text, read_offset, len(all_input_ids)
        else:
            return "", prefix_offset, read_offset

    def check_initialized(self):
        """Ensure every parameter of ``self.model`` has been materialized.

        Raises:
            RuntimeError: naming every parameter whose data is still on the
                ``meta`` device (i.e. was never loaded with real weights).
        """
        meta_device = torch.device("meta")
        uninitialized_parameters = [
            name
            for name, param in self.model.named_parameters()
            if param.data.device == meta_device
        ]
        if uninitialized_parameters:
            raise RuntimeError(
                f"found uninitialized parameters in model {self.__class__.__name__}: {uninitialized_parameters}"
            )