import torch
import json
from typing import Tuple, Optional
from text_generation_server.layers.tensor_parallel import TensorParallelHead
from text_generation_server.layers.medusa import MedusaHeadV1, MedusaHeadV2
from text_generation_server.layers.mlp import MLPSpeculatorHead


class SpeculativeHead(torch.nn.Module):
    """Output head that optionally delegates to a speculator module.

    When ``speculator`` is set, ``forward`` returns whatever the speculator
    returns. Otherwise the plain ``head`` computes the logits and ``forward``
    returns ``(logits, None)``.
    """

    def __init__(self, lm_head, speculator):
        """Store the two (possibly ``None``) sub-modules.

        Args:
            lm_head: Callable producing logits from the input, or ``None``
                when the speculator is responsible for the output.
            speculator: Callable used instead of ``lm_head`` when not
                ``None``.
        """
        super().__init__()
        self.head = lm_head
        self.speculator = speculator

    @staticmethod
    def load(config, prefix: str, weights):
        """Build a ``SpeculativeHead`` from a model config and weights.

        If ``config.speculator`` is falsy, only the plain
        ``TensorParallelHead`` is loaded. Otherwise ``config.json`` is read
        from ``config.speculator["path"]``, stored on
        ``config.speculator_config``, and used to pick the speculator:

        * ``architectures[0] == "MLPSpeculatorPreTrainedModel"`` loads
          ``MLPSpeculatorHead``;
        * a config without an ``"architectures"`` key is loaded as
          ``MedusaHeadV1``, falling back to ``MedusaHeadV2`` if that fails;
        * any other architecture yields no speculator, and the plain
          ``TensorParallelHead`` is loaded instead so the returned module
          remains usable.

        Args:
            config: Model configuration object; ``config.speculator`` is
                read and ``config.speculator_config`` is written.
            prefix: Weight-name prefix forwarded to the underlying loaders.
            weights: Weights container forwarded to the underlying loaders.

        Returns:
            A ``SpeculativeHead`` with either ``head`` or ``speculator`` set.
        """
        if not config.speculator:
            lm_head = TensorParallelHead.load(config, prefix, weights)
            return SpeculativeHead(lm_head, None)

        # ``path`` must support the ``/`` operator (e.g. a pathlib.Path).
        config_path = str(config.speculator["path"] / "config.json")
        with open(config_path, "r", encoding="utf-8") as f:
            speculator_config = json.load(f)

        # Expose the parsed speculator config to the loaders called below.
        config.speculator_config = speculator_config

        # Only the key lookup is guarded: a KeyError raised inside a loader
        # must propagate rather than be mistaken for a missing
        # "architectures" entry.
        try:
            architecture = speculator_config["architectures"][0]
        except KeyError:
            # No "architectures" key: treat the checkpoint as Medusa. V1 is
            # tried first; any failure falls back to V2 (best effort).
            try:
                speculator = MedusaHeadV1.load(config, prefix, weights)
            except Exception:
                speculator = MedusaHeadV2(config, prefix, weights)
        else:
            if architecture == "MLPSpeculatorPreTrainedModel":
                speculator = MLPSpeculatorHead.load(config, prefix, weights)
            else:
                speculator = None

        lm_head = None
        if speculator is None:
            # Unrecognized architecture: without this fallback both members
            # would be None and the first forward() call would fail.
            lm_head = TensorParallelHead.load(config, prefix, weights)
        return SpeculativeHead(lm_head, speculator)

    def forward(
        self, input: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run the speculator if present, otherwise the plain head.

        Args:
            input: Tensor passed unchanged to the speculator or the head.

        Returns:
            The speculator's return value when a speculator is set
            (expected to match the annotated 2-tuple); otherwise
            ``(logits, None)`` with ``logits`` computed by ``head``.
        """
        if self.speculator is not None:
            return self.speculator(input)

        # Internal invariant: without a speculator a head must exist.
        assert self.head is not None
        logits = self.head(input)
        return logits, None