import torch
from dataclasses import dataclass
from typing import List, Optional, Type

from text_generation_server.models import CausalLM
from text_generation_server.models.causal_lm import CausalLMBatch


@dataclass
class StarCoderCausalLMBatch(CausalLMBatch):
    past_key_values: Optional[List[torch.Tensor]]

    def detach_kv_cache(self):
        past_keys = []
        past_values = []
        last_dim = int(self.past_key_values[0].size(dim=-1) / 2)
        for key_value in self.past_key_values:
            past_keys.append(key_value.split((last_dim, last_dim), dim=-1)[0])
            past_values.append(key_value.split((last_dim, last_dim), dim=-1)[1])
        del self.past_key_values

        return past_keys, past_values

    def attach_kv_cache(self, past_keys, past_values):
        self.past_key_values = [
            torch.cat((key, value), dim=-1)
            for key, value in zip(past_keys, past_values)
        ]


class StarCoder(CausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
    ):

        super(StarCoder, self).__init__(
            model_id=model_id,
            revision=revision,
            dtype=dtype,
        )

    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        return StarCoderCausalLMBatch