import torch
from torch import nn
from typing import Tuple, Optional
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.layers.linear import FastLinear
from text_generation_server.layers.tensor_parallel import (
    TensorParallelHead,
    TensorParallelColumnLinear,
)


class ResBlock(torch.nn.Module):
    """Residual block: ``x + SiLU(linear(x))``.

    The linear layer (with bias) is loaded from ``weights`` under the
    ``"<prefix>.linear"`` name.
    """

    def __init__(self, config, prefix, weights):
        super().__init__()
        linear_prefix = f"{prefix}.linear"
        self.linear = FastLinear.load(
            config, prefix=linear_prefix, weights=weights, bias=True
        )
        self.act = torch.nn.SiLU()

    def forward(self, x):
        """Return ``x`` plus the activated linear projection of ``x``."""
        projected = self.linear(x)
        activated = self.act(projected)
        return x + activated


class MedusaModel(torch.nn.Module):
    """Collection of ``get_speculate()`` Medusa heads evaluated on one input.

    Head ``i`` is built with the weight prefix ``"<i>"``.
    """

    def __init__(self, config, medusa_config, weights):
        super().__init__()
        heads = [
            MedusaHead(config, medusa_config, prefix=f"{head_idx}", weights=weights)
            for head_idx in range(get_speculate())
        ]
        self.heads = torch.nn.ModuleList(heads)

    def forward(self, x):
        """Run every head on ``x`` and stack the results along dim 1.

        Returns ``None`` when the model has no heads.
        """
        if len(self.heads) == 0:
            return None
        per_head_logits = [head(x) for head in self.heads]
        return torch.stack(per_head_logits, dim=1)


class MedusaHead(torch.nn.Module):
    """One Medusa head: a stack of ``ResBlock`` layers and a final projection.

    ``medusa_config["medusa_num_layers"]`` residual blocks are loaded from
    ``"<prefix>.0"``, ``"<prefix>.1"``, ...; the bias-free output projection
    is loaded from the next index, ``"<prefix>.<number of blocks>"``.
    """

    def __init__(self, config, medusa_config, prefix, weights):
        super().__init__()
        residual_blocks = [
            ResBlock(config, prefix=f"{prefix}.{layer_idx}", weights=weights)
            for layer_idx in range(medusa_config["medusa_num_layers"])
        ]
        self.blocks = torch.nn.ModuleList(residual_blocks)
        # The output projection's index follows the last residual block.
        out_index = len(self.blocks)
        self.out = FastLinear.load(
            config, prefix=f"{prefix}.{out_index}", weights=weights, bias=False
        )

    def forward(self, x):
        """Apply the residual blocks in order, then the output projection."""
        hidden = x
        for residual_block in self.blocks:
            hidden = residual_block(hidden)
        return self.out(hidden)


class MedusaHeadV1(nn.Module):
    """LM head paired with a ``MedusaModel`` that produces speculative logits.

    ``forward`` always computes the regular logits with ``lm_head`` and, for
    small enough inputs, also the speculative logits with ``medusa``.
    """

    def __init__(self, lm_head, medusa):
        super().__init__()
        self.lm_head = lm_head
        self.medusa = medusa

    @staticmethod
    def load(config, prefix: str, weights):
        """Build a ``MedusaHeadV1`` from ``config.speculator``.

        Reads ``config.json`` from ``config.speculator["path"]``, registers
        every tensor name found in the files listed in
        ``config.speculator["model_paths"]`` in ``weights.routing``, then
        loads the Medusa heads and the LM head (the latter under ``prefix``).

        Raises:
            RuntimeError: if a tensor name is already routed to a different
                file.
        """
        from pathlib import Path
        from safetensors import safe_open
        import json

        speculator = config.speculator

        path = speculator["path"]
        config_path = str(Path(path) / "config.json")

        # Parse the config exactly once, before iterating over the weight
        # files. Doing it inside the loop (and reusing one variable for both
        # the path and the parsed dict) made the second iteration call
        # open() on a dict, so speculators split across several files could
        # not be loaded; with no files at all the raw path string was handed
        # to MedusaModel instead of a dict.
        with open(config_path, "r") as f:
            medusa_config = json.load(f)

        routing = weights.routing
        for fname in speculator["model_paths"]:
            filename = str(Path(path) / fname)

            with safe_open(filename, framework="pytorch") as f:
                for k in f.keys():
                    if k in routing and routing[k] != filename:
                        raise RuntimeError(
                            f"Key {k} was found in multiple files: {filename} and {routing[k]}"
                        )
                    routing[k] = filename

        medusa = MedusaModel(config, medusa_config, weights)
        lm_head = TensorParallelHead.load(config, prefix, weights)
        return MedusaHeadV1(lm_head, medusa)

    def forward(
        self, input: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Return ``(logits, speculative_logits)`` for ``input``.

        ``speculative_logits`` is ``None`` when ``input.shape[0]`` exceeds
        128; it is otherwise whatever ``self.medusa(input)`` returns.
        """
        logits = self.lm_head(input)
        # If we have too many tokens, we skip speculative logits
        if input.shape[0] > 128:
            return logits, None

        speculative_logits = self.medusa(input)
        return logits, speculative_logits


class MedusaHeadV2(nn.Module):
    """Medusa head that evaluates all single-layer heads with one fused linear.

    The per-head ``"<i>.0.linear"`` weights are loaded as a single
    column-parallel linear layer; the residual outputs are gathered across
    the process group and pushed through the LM head together with the
    original hidden states.

    Only Medusa checkpoints with ``medusa_num_layers == 1`` are supported.
    """

    def __init__(self, config, prefix, weights):
        """Load the fused Medusa linear layer and the LM head.

        Reads ``config.json`` and ``medusa_lm_head.safetensors`` from
        ``config.speculator["path"]`` and registers the tensor names of the
        latter in ``weights.routing`` before loading.

        Raises:
            RuntimeError: if a tensor name is already routed to a different
                file.
            AssertionError: if ``medusa_num_layers`` is not 1.
        """
        super().__init__()
        from pathlib import Path
        from safetensors import safe_open
        import json

        speculator_path = config.speculator["path"]

        config_path = str(Path(speculator_path) / "config.json")
        filename = str(Path(speculator_path) / "medusa_lm_head.safetensors")

        with open(config_path, "r") as f:
            medusa_config = json.load(f)
        routing = weights.routing
        with safe_open(filename, framework="pytorch") as f:
            for k in f.keys():
                if k in routing and routing[k] != filename:
                    raise RuntimeError(
                        f"Key {k} was found in multiple files: {filename} and {routing[k]}"
                    )
                routing[k] = filename

        self.n_medusa_heads = get_speculate()

        # Raised explicitly rather than via `assert` so the check survives
        # `python -O`; AssertionError is kept so callers see the same
        # exception type as before. Without the check, the fused load below
        # would silently use only layer 0 of a multi-layer checkpoint.
        num_layers = medusa_config["medusa_num_layers"]
        if num_layers != 1:
            raise AssertionError(
                f"MedusaHeadV2 requires medusa_num_layers == 1, got {num_layers}"
            )
        self.linear = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{i}.0.linear" for i in range(self.n_medusa_heads)],
            dim=0,
            weights=weights,
            bias=True,
        )
        self.process_group = weights.process_group
        self.world_size = self.process_group.size()
        self.rank = self.process_group.rank()

        self.act = torch.nn.SiLU()

        self.lm_head = TensorParallelHead.load(config, prefix, weights)

    def forward(self, x):
        """Return ``(logits, speculative_logits)`` for ``x``.

        ``speculative_logits`` is ``None`` when ``x.shape[0]`` exceeds 128.
        Otherwise the LM head is applied once to ``x`` stacked with the
        ``n_medusa_heads`` residual outputs, and the result is split back
        into the regular logits and the speculative logits.
        """
        # If we have too many tokens, we skip speculative logits
        if x.shape[0] > 128:
            logits = self.lm_head(x)
            return logits, None

        # Slice of the last dimension owned by this rank (ceil division).
        size = x.shape[-1]
        block_size = (size + self.world_size - 1) // self.world_size
        start = self.rank * block_size
        stop = (self.rank + 1) * block_size

        x_block = x[:, start:stop]

        # Compute all medusa heads at the same time, then reshape and move the n_medusa_heads dim to dim 1
        medusa_res = self.act(self.linear(x)).reshape(
            *x_block.shape[:-1], self.n_medusa_heads, x_block.shape[-1]
        )

        # Apply all residual medusa heads
        output = x_block.unsqueeze(-2) + medusa_res

        # Gather medusa heads
        world_output = [torch.empty_like(output) for _ in range(self.world_size)]
        torch.distributed.all_gather(world_output, output, group=self.process_group)
        world_output = torch.cat(world_output, dim=-1)

        # Stack x and medusa residual x
        stacked_x = torch.cat([x.unsqueeze(-2), world_output], dim=-2)

        # Compute lm head on x + medusa residual x
        logits = self.lm_head(stacked_x)

        # Finally, split logits from speculative logits
        logits, speculative_logits = torch.split(
            logits, [1, self.n_medusa_heads], dim=-2
        )
        # Squeeze added dimension
        logits = logits.squeeze(-2)

        return logits, speculative_logits