import copy
import logging
import time
from abc import ABC
from enum import Enum
from typing import List, Optional, Tuple

import torch
from loguru import logger
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerBase
from transformers.generation import GenerationConfig

from optimum.neuron import NeuronModelForCausalLM
from optimum.neuron.generation import TokenSelector

from .model import get_export_kwargs_from_env
from .pb.generate_pb2 import (
    Batch,
    CachedBatch,
    FinishReason,
    GeneratedText,
    Generation,
    InfoResponse,
    Request,
    Tokens,
)


# Raise the threshold of the "optimum.neuron" logger to CRITICAL to silence its
# warnings, as they seem to block the server after a while.
optimum_logger = logging.getLogger("optimum.neuron")
optimum_logger.setLevel("CRITICAL")


class Generator(ABC):
    """An abstract class to represent the workhorse behind TextGenerationService.

    Ideally, it should not rely on protobuf constructs, but in a first step it does.
    Implementations would typically need a model and a tokenizer to implement the Generator methods.

    Note that no method is decorated with `abstractmethod`: every method simply raises
    `NotImplementedError` until it is overridden by an implementation.
    """

    @property
    def info(self) -> InfoResponse:
        """This should simply return the expected InfoResponse"""
        raise NotImplementedError

    def warmup(self, batch: Batch) -> int:
        """Verify if the hardware can support the target load.

        Args:
            batch (`Batch`):
                A batch corresponding to the maximum number of concurrent requests.

        Return:
            The maximum number of tokens the model supports.
        """
        raise NotImplementedError

    def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
        """Prefill is called whenever new requests need to be added.

        When this method returns successfully, a decode method will follow
        with both the current and newly prefilled batch(es).

        Args:
            batch (`Batch`):
                A batch containing the new requests.

        Return:
            A list of `Generation` for each request and a `CachedBatch` containing all pending requests.
        """
        raise NotImplementedError

    def decode(
        self, batches: List[CachedBatch]
    ) -> Tuple[List[Generation], CachedBatch]:
        """Decode after a prefill or another decode.

        Args:
            batches (`List[CachedBatch]`):
                The cached batches returned by the previous prefill/decode calls.

        Return:
            A list of `Generation` for each request and a `CachedBatch` containing all pending requests.
        """
        raise NotImplementedError

    def filter(self, batch_id: int, request_ids: List[int]) -> CachedBatch:
        """Remove requests that are not listed from the specified batch

        Args:
            batch_id (`int`):
                The id of a cached batch.
            request_ids (`List[int]`):
                The ids of the requests that must be kept.

        Return:
            A `CachedBatch` containing the pending requests.
        """
        raise NotImplementedError

    def clear(self, batch_id: Optional[int] = None):
        """Remove a subset or all requests from the generator

        Args:
            batch_id (`Optional[int]`, defaults to `None`):
                When specified, only the requests of that batch are removed.
        """
        raise NotImplementedError

    @classmethod
    def from_pretrained(cls, model_id: str, revision: Optional[str] = None):
        """Factory method "a la transformers"

        Args:
            model_id (`str`):
                The identifier of the model to load.
            revision (`Optional[str]`, defaults to `None`):
                The revision of the model to load.
        """
        raise NotImplementedError


class Slot:
    """Represents a slot in a static batch.

    A slot stores the state of a single request: its inputs, its generation
    configuration, the tokens and attention mask seen so far, the token selector
    and the text that has been decoded from the generated tokens.
    """

    class State(Enum):
        EMPTY = 0
        PAUSE = 1
        READY = 2

    def __init__(self, id: int, tokenizer: PreTrainedTokenizerBase):
        self._id = id
        self._tokenizer = tokenizer
        self.clear()

    def clear(self):
        """Clear the slot and mark it as available.

        Every per-request attribute is reset, including those that are otherwise
        only set by `assign` (`seed` and the target `max_new_tokens`), so that a
        recycled slot never exposes values from the previous request.
        """
        self._state = Slot.State.EMPTY
        self._batch_id = None
        self._request_id = None
        self._inputs = ""
        self._truncate = 0
        self._generation_config = None
        self._max_new_tokens = 0
        self.seed = 0
        # Tokens and mask are always handled as int64 tensors (see `reset` and `append`)
        self._tokens = torch.tensor([], dtype=torch.int64)
        self._mask = torch.tensor([], dtype=torch.int64)
        self._selector = None
        self._generated_tokens = 0
        self._next_text_token_start = 0
        self._next_text_token_end = 0
        self._generated_text = ""
        self._next_text = ""

    @property
    def id(self) -> int:
        return self._id

    @property
    def state(self) -> "Slot.State":
        return self._state

    @property
    def batch_id(self) -> Optional[int]:
        # None until a request has been assigned
        return self._batch_id

    @property
    def request_id(self) -> Optional[int]:
        # None until a request has been assigned
        return self._request_id

    @property
    def cached_text(self) -> str:
        # The text of the last generated token (self._next_text) is deliberately excluded
        return self._inputs + self._generated_text

    @property
    def generation_config(self) -> GenerationConfig:
        return self._generation_config

    @property
    def generated_tokens(self) -> int:
        return self._generated_tokens

    def assign(
        self, batch_id: int, request: Request, generation_config: GenerationConfig
    ):
        """Assign a request to a slot.

        The slot is marked as ready for generation.

        Args:
            batch_id (`int`):
                The id of the batch the request belongs to.
            request (`Request`):
                The request to be assigned. Contains the inputs and tokens selection parameters.
            generation_config (`transformers.GenerationConfig`):
                The base generation config (might be modified by the request generation parameters).
                It is deep-copied, so the object passed by the caller is left untouched.
        """
        self._state = Slot.State.READY
        self._batch_id = batch_id
        self._request_id = request.id
        self._inputs = request.inputs
        if request.truncate:
            self._truncate = request.truncate
        self._generation_config = copy.deepcopy(generation_config)
        # Update generation config with request parameters
        self._generation_config.do_sample = request.parameters.do_sample
        if self._generation_config.do_sample:
            # A zero value means that the parameter is not set in the request
            if request.parameters.temperature != 0:
                self._generation_config.temperature = request.parameters.temperature
            if request.parameters.top_k != 0:
                self._generation_config.top_k = request.parameters.top_k
            if request.parameters.top_p != 0:
                self._generation_config.top_p = request.parameters.top_p
            if request.parameters.typical_p != 0:
                self._generation_config.typical_p = request.parameters.typical_p
        if request.parameters.repetition_penalty != 0:
            self._generation_config.repetition_penalty = (
                request.parameters.repetition_penalty
            )
        self.seed = request.parameters.seed
        self._generation_config.max_new_tokens = (
            request.stopping_parameters.max_new_tokens
        )
        # Remember the requested target, as the config value is modified by `pause`
        self._max_new_tokens = self._generation_config.max_new_tokens
        stop_strings = request.stopping_parameters.stop_sequences
        if stop_strings:
            self._generation_config.stop_strings = stop_strings

    def reset(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.LongTensor,
        selector: TokenSelector,
    ):
        """Reset the slot for the next generation.

        Args:
            input_ids: (`torch.LongTensor`):
                The new input_ids to use to generate the next token.
            attention_mask: (`torch.LongTensor`):
                The new attention_mask to use to generate the next token.
            selector: (`optimum.neuron.generation.TokenSelector`):
                An object implementing the updated token selection logic.
        """
        self._tokens = input_ids.clone()
        self._next_text_token_start = 0
        self._next_text_token_end = torch.numel(self._tokens)
        self._next_text = ""
        self._mask = attention_mask.clone()
        self._selector = selector

    def pause(self, reset_on_pause: bool):
        """Mark the current slot as paused for generation.

        Note that the KV cache for this slot will still be filled.

        Args:
            reset_on_pause (`bool`):
                Whether the slot is going to be reset while paused. In that case the
                last generated token is dropped from the count and `max_new_tokens`
                is reduced by the number of tokens already generated.
        """
        if reset_on_pause:
            # Drop the last token as it will be added back when resuming the slot
            self._generated_tokens -= 1
            # Since generated tokens are now part of the prefill, we need to reevaluate
            # max_new_tokens for the next generation
            self._generation_config.max_new_tokens = (
                self._max_new_tokens - self._generated_tokens
            )
        self._state = Slot.State.PAUSE

    def resume(self):
        """Mark the slot as ready for generation."""
        self._state = Slot.State.READY

    def _decode_next_tokens(
        self,
    ) -> str:
        """Hack to hopefully support generate_stream for the maximum number of tokenizers

        Return:
            The text produced by the tokens that have not been decoded yet, or an
            empty string if they do not produce any complete text.
        """
        # We need to include the tokens that produced the last text to defeat cleanup algorithms in the decode
        # which decide to add a space or not depending on the surrounding ids.
        new_text = self._tokenizer.decode(
            self._tokens[self._next_text_token_start :], skip_special_tokens=False
        )
        if new_text.endswith("�"):
            # utf-8 char at the end means it's a potential unfinished byte sequence
            # from byte fallback tokenization.
            return ""

        # Compare the generated text with the one using only the tokens producing the last one
        last_text = self._tokenizer.decode(
            self._tokens[self._next_text_token_start : self._next_text_token_end],
            skip_special_tokens=False,
        )
        if len(new_text) == len(last_text):
            # Nothing new was actually generated
            return ""
        # Return the decoded text and store its token offsets
        self._next_text_token_start = self._next_text_token_end
        self._next_text_token_end = torch.numel(self._tokens)
        return new_text[len(last_text) :]

    def append(self, next_token: int) -> str:
        """Append a new generated token to this slot

        The new token is added to the list of generated tokens, which impacts
        directly the generated_text and stopped property.

        The new token is however not added immediately to the slot inputs: it will
        be added later on when it has effectively been used to produce the next token.

        Args:
            next_token (`int`):
                The newly generated token.

        Return:
            The corresponding decoded text (if any).
        """
        self._tokens = torch.cat([self._tokens, torch.LongTensor([next_token])])
        self._mask = torch.cat([self._mask, torch.LongTensor([1])])
        self._generated_tokens += 1
        next_text = self._decode_next_tokens()
        # Now that a new token has been generated, we can append the previous one to the generated text
        self._generated_text += self._next_text
        self._next_text = next_text
        return next_text

    def select(
        self, input_ids: torch.LongTensor, logits: torch.Tensor
    ) -> torch.LongTensor:
        """Select the next token from the candidate logits.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                The sequence used as a prompt for the generation (not used in all generation modes).
            logits (`torch.Tensor`):
                The candidate logits for the next token.

        Return:
            `torch.LongTensor`: The first element of the selection returned by the slot selector.
        """
        return self._selector.select(input_ids, logits)[0]

    @property
    def stopped(self) -> bool:
        """The result of the selector stopping criteria evaluated on the slot tokens."""
        # Transformers stopping criteria expects a batch of input ids
        input_ids = torch.unsqueeze(self._tokens, dim=0)
        return self._selector.stopping_criteria(input_ids, None)

    @property
    def generated_text(self) -> str:
        # Unlike cached_text, this includes the text of the last generated token
        return self._generated_text + self._next_text

    @property
    def next_token(self) -> Optional[int]:
        # The last token of the slot (an element of the tokens tensor), or None if there is none
        return None if len(self._tokens) == 0 else self._tokens[-1]

    @property
    def attention_mask(self) -> torch.LongTensor:
        return self._mask

    @property
    def max_token(self) -> int:
        return self._generation_config.max_length

    @property
    def max_new_tokens(self) -> int:
        # The current value of max_new_tokens: might be different of the target max_new_tokens
        # if the slot has been paused and resumed.
        return self._generation_config.max_new_tokens

    @property
    def truncate(self) -> int:
        return self._truncate


class NeuronGenerator(Generator):
    """A Generator for Neuron models.

    Requests are assigned to a fixed list of `Slot` (one per item of the model
    static batch size). When the model does not support continuous batching, the
    cache is rebuilt for all slots every time new requests are prefilled.
    """

    def __init__(
        self,
        model: NeuronModelForCausalLM,
        tokenizer: PreTrainedTokenizerBase,
    ):
        self.model = model
        self.rebuild_cache_on_prefill = not self.model.continuous_batching
        # Specify padding and truncation options for decoder-only architecture
        tokenizer.pad_token_id = tokenizer.eos_token_id
        tokenizer.padding_side = "left"
        tokenizer.truncation_side = "left"
        self.tokenizer = tokenizer
        self.special_tokens = self.tokenizer.all_special_ids
        self.slots = [Slot(i, tokenizer) for i in range(self.model.batch_size)]
        # The id assigned to the requests of the next prefilled batch
        self.batch_id = 0

    @property
    def info(self) -> InfoResponse:
        """Returns the expected InfoResponse."""
        dtype = getattr(self.model.config, "torch_dtype", "float32")
        return InfoResponse(
            requires_padding=True,
            dtype=str(dtype),
            device_type="xla",
        )

    def warmup(self, batch: Batch) -> int:
        """Verify if the hardware can support the target load.

        Args:
            batch (`Batch`):
                A batch corresponding to the maximum number of concurrent requests.

        Return:
            The maximum number of tokens the model supports.

        Raises:
            ValueError: if the batch contains more requests than the model batch size.
        """
        # Just check that the warmup request parameters match the model capacity
        batch_size = self.model.batch_size
        if len(batch.requests) > batch_size:
            raise ValueError(
                f"Inconsistent batch_size configuration: Please make sure the batch_size in the compiled model (currently {batch_size}) matches the batch_size passed to TGI.  The compiled model batch_size is usually in the neuron section of the model config.json file. You may also have passed it into optimum-cli during the compilation process.  The batch size for TGI is usually set in the environment as MAX_BATCH_SIZE."
            )
        self.prefill(batch)
        self.clear()
        return self.model.batch_size * self.model.max_length

    def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
        """Prefill new requests.

        Args:
            batch (`Batch`):
                A batch containing the new requests.

        Return:
            A list of `Generation` for each request and a `CachedBatch` containing all pending requests.

        Raises:
            ValueError: if there are fewer empty slots than new requests.
        """
        slots = {state: [] for state in Slot.State}
        for slot in self.slots:
            slots[slot.state].append(slot)
        active_slots = slots[Slot.State.READY]
        empty_slots = slots[Slot.State.EMPTY]
        if len(empty_slots) < len(batch.requests):
            raise ValueError(
                f"Cannot prefill {len(batch.requests)} new request(s) with only {len(empty_slots)} empty slots."
                f" Please align max_batch_size with the static batch size: {self.model.batch_size}."
            )
        # Assign each request to an empty slot
        logger.debug(
            f"Prefilling {len(batch.requests)} new request(s) with {len(empty_slots)} empty slot(s)"
        )
        new_slots = []
        for request in batch.requests:
            slot = empty_slots.pop()
            slot.assign(self.batch_id, request, self.model.generation_config)
            new_slots.append(slot)
            logger.debug(
                f"Request {slot.request_id} assigned to slot {slot.id} with and max_new_tokens {slot.max_new_tokens}"
            )
        if self.rebuild_cache_on_prefill:
            # We will clear pending slots and prefill all slots
            prefill_slots = self.slots
            seq_ids = None
        else:
            # We only need to pass inputs for the new requests
            prefill_slots = new_slots
            seq_ids = torch.tensor([slot.id for slot in prefill_slots])
        # Reconstruct the full inputs (without padding) as seen by the model.
        # This comprises:
        # - the inputs for new requests,
        # - only when rebuilding the cache, the inputs and the generated text that has already
        # been cached (i.e. excluding the last generated token) for unfinished requests.
        inputs = []
        max_length = 0
        model_max_length = self.model.max_length
        for slot in prefill_slots:
            inputs.append(slot.cached_text)
            # Apply truncation, making sure we fit into static dimensions: the common
            # length is the largest per-slot limit, each limit being clamped to the
            # model maximum length (a zero truncate value means no per-slot limit).
            if slot.truncate == 0:
                slot_max_length = model_max_length
            else:
                slot_max_length = min(slot.truncate, model_max_length)
            max_length = max(max_length, slot_max_length)
        # Tokenize with padding and truncation
        padded_inputs = self.tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
        )
        input_ids = padded_inputs.input_ids
        attention_mask = padded_inputs.attention_mask
        # Pause previously active slots during generation
        next_tokens = []
        for slot in active_slots:
            slot.pause(reset_on_pause=self.rebuild_cache_on_prefill)
            if self.rebuild_cache_on_prefill:
                # The slot will be reset, so we need to store its next token
                next_tokens.append(slot.next_token)
        # Each slot must be reset with the padded inputs and masks
        for i, slot in enumerate(prefill_slots):
            if slot.state != Slot.State.EMPTY:
                if slot.truncate > 0 and slot.truncate < input_ids.shape[-1]:
                    # Apply per-request truncation
                    input_ids[i, : -slot.truncate] = self.tokenizer.pad_token_id
                    attention_mask[i, : -slot.truncate] = 0
                slot_input_ids = input_ids[i : i + 1, :]
                # Padded input ids are also required to set logits processors and stopping criterias
                selector = TokenSelector.create(
                    slot_input_ids,
                    slot.generation_config,
                    self.model,
                    self.model.max_length,
                    tokenizer=self.tokenizer,
                    seed=slot.seed,
                )
                slot_input_ids = slot_input_ids.squeeze(dim=0).type(torch.int64)
                slot_attention_mask = attention_mask[i]
                slot.reset(slot_input_ids, slot_attention_mask, selector)
        # Note: when rebuilding cache on prefill, the new tokens on paused slots will be ignored,
        # as they have already been generated and sent back in the last decode.
        model_inputs = self.model.prepare_inputs_for_prefill(
            input_ids, attention_mask, seq_ids
        )
        logits = self.model(**model_inputs)[0]
        generation, next_batch = self._generate_token(
            prefill_slots, self.batch_id, logits, input_ids
        )
        self.batch_id += 1
        # Reactivate previously active slots for the next decode
        for i, slot in enumerate(active_slots):
            slot.resume()
            if self.rebuild_cache_on_prefill:
                # Append back the next token
                slot.append(next_tokens[i])
        logger.debug("Model ready for decoding")
        if next_batch is not None:
            logger.debug(
                f"Next batch is {next_batch.id} with requests: {next_batch.request_ids}"
            )
        return generation, next_batch

    def decode(
        self, batches: List[CachedBatch]
    ) -> Tuple[List[Generation], CachedBatch]:
        """Decode the specified prefilled requests.

        Args:
            batches (`List[CachedBatch]`):
                A list of previous batches containing the prefilled requests.

        Return:
            A list of `Generation` for each request and a `CachedBatch` containing all pending requests.

        Raises:
            ValueError: if there are fewer ready slots than requested ids.
        """
        # batches contains a list composed of:
        # - the batch id returned by the last decode,
        # - the batch id(s) returned by the last prefill(s)
        # Batches are always concatenated during prefill, so we can
        # just carry on with decoding. We adopt the id of the first
        # batch in the list as our next batch id.
        next_batch_id = batches[0].id
        request_ids = []
        for batch in batches:
            request_ids += batch.request_ids
        # Ready slots whose request is not part of the batches are released
        cleared_request_ids = []
        for slot in self.slots:
            if slot.state == Slot.State.READY and slot.request_id not in request_ids:
                cleared_request_ids.append(slot.request_id)
                slot.clear()
        if len(cleared_request_ids) > 0:
            logger.info(
                f"Clearing slot for requests {cleared_request_ids} as they are not requested."
            )
        active_slots = [slot for slot in self.slots if slot.state == Slot.State.READY]
        if len(active_slots) < len(request_ids):
            raise ValueError(
                "Unable to decode tokens for non-prefilled batches (probably due to a previous failure)"
            )
        if self.model.continuous_batching:
            decode_slots = active_slots
            seq_ids = torch.tensor([slot.id for slot in decode_slots])
        else:
            decode_slots = self.slots
            seq_ids = None
        # Reconstruct input_ids and attention_mask from decode slots
        n_slots = len(decode_slots)
        input_ids = torch.full(
            [n_slots, 1], fill_value=self.tokenizer.eos_token_id, dtype=torch.int64
        )
        max_length = 0
        for slot in decode_slots:
            max_length = max(max_length, slot.attention_mask.size(-1))
        attention_mask = torch.zeros([n_slots, max_length], dtype=torch.int64)
        for i, slot in enumerate(decode_slots):
            if slot.state != Slot.State.EMPTY:
                # input_ids are simply the tokens generated by the last decode or prefill requests (other tokens are cached)
                input_ids[i, 0] = slot.next_token
                attention_mask[i, : slot.attention_mask.size(-1)] = slot.attention_mask
        model_inputs = self.model.prepare_inputs_for_decode(
            input_ids, attention_mask, seq_ids
        )
        logits = self.model(**model_inputs)[0]
        return self._generate_token(decode_slots, next_batch_id, logits, input_ids)

    def _generate_token(
        self,
        slots: List[Slot],
        next_batch_id: int,
        logits: torch.Tensor,
        input_ids: torch.LongTensor,
    ) -> Tuple[List[Generation], CachedBatch]:
        """Select and append the next token of each ready slot.

        Args:
            slots (`List[Slot]`):
                The slots corresponding to each row of `logits` and `input_ids`.
            next_batch_id (`int`):
                The id of the returned `CachedBatch`.
            logits (`torch.Tensor`):
                The model outputs: only the logits at the last position are used.
            input_ids (`torch.LongTensor`):
                The input ids that were passed to the model.

        Return:
            A `Generation` for each ready slot, and a `CachedBatch` containing the
            requests of all slots that are still ready (`None` if no slot of `slots` is still generating).
        """
        generations = []
        active_slots = False
        for i, slot in enumerate(slots):
            if slot.state != Slot.State.READY:
                continue
            request_id = slot.request_id
            next_token_logits = logits[i : i + 1, -1, :]
            slot_input_ids = input_ids[i : i + 1, :]
            next_token = slot.select(slot_input_ids, next_token_logits)
            next_token_text = slot.append(next_token)
            generated_text = None
            finish_reason = None
            if next_token == self.tokenizer.eos_token_id:
                finish_reason = FinishReason.FINISH_REASON_EOS_TOKEN
            elif slot.stopped:
                # slot.generated_tokens counts all the tokens generated for the request,
                # so it must be compared to the target requested at assignment and not to
                # slot.max_new_tokens, which Slot.pause replaces by the remaining budget.
                # NOTE(review): Slot exposes no public accessor for the original target,
                # hence the access to the private attribute set by Slot.assign.
                if slot.generated_tokens == slot._max_new_tokens:
                    finish_reason = FinishReason.FINISH_REASON_LENGTH
                else:
                    finish_reason = FinishReason.FINISH_REASON_STOP_SEQUENCE
            if finish_reason is not None:
                # We must include the generated text for each finished sequence in the response
                generated_text = GeneratedText(
                    text=slot.generated_text,
                    generated_tokens=slot.generated_tokens,
                    finish_reason=finish_reason,
                )
                logger.debug(
                    f"Decode complete for request {request_id} with {slot.generated_tokens} tokens"
                )
                # mark the slot as available
                slot.clear()
            else:
                active_slots = True
            generations.append(
                Generation(
                    request_id=request_id,
                    prefill_tokens=None,
                    tokens=Tokens(
                        ids=[next_token],
                        logprobs=[0],
                        texts=[next_token_text],
                        is_special=[next_token in self.special_tokens],
                    ),
                    generated_text=generated_text,
                )
            )
        batch = None
        if active_slots:
            # Whatever initial batch these requests came from, we always return all pending requests in a single batch
            request_ids = [
                slot.request_id for slot in self.slots if slot.state == Slot.State.READY
            ]
            batch = self._cached_batch(next_batch_id, request_ids)
        else:
            logger.debug("No more pending requests")
        return generations, batch

    def _cached_batch(self, batch_id: int, request_ids: List):
        """Build a `CachedBatch` for the specified requests.

        The batch `max_tokens` is the number of requests multiplied by the model maximum length.
        """
        size = len(request_ids)
        max_tokens = size * self.model.max_length
        return CachedBatch(
            id=batch_id, request_ids=request_ids, size=size, max_tokens=max_tokens
        )

    def filter(self, batch_id: int, keep_request_ids: List[int]) -> CachedBatch:
        """Remove requests that are not listed from the specified batch

        Args:
            batch_id (`int`):
                The id of a cached batch.
            keep_request_ids (`List[int]`):
                The list of requests that must be kept.

        Return:
            A `CachedBatch` containing the pending requests.
        """
        keep_slot_ids = [
            slot.id for slot in self.slots if slot.request_id in keep_request_ids
        ]
        self._clear(keep_slot_ids)
        return self._cached_batch(batch_id, keep_request_ids)

    def clear(self, batch_id: Optional[int] = None):
        """Remove a subset or all requests from the generator

        Args:
            batch_id (`Optional[int]`, defaults to `None`):
                When specified, only the slots assigned to that batch are cleared.
        """
        keep_ids = []
        if batch_id is not None:
            keep_ids = [slot.id for slot in self.slots if slot.batch_id != batch_id]
        return self._clear(keep_ids)

    def _clear(self, keep_slot_ids: List):
        """Clear all non-empty slots whose id is not listed in `keep_slot_ids`."""
        for slot in self.slots:
            if slot.state != Slot.State.EMPTY and slot.id not in keep_slot_ids:
                logger.debug(f"Removing slot {slot.id} with request {slot.request_id}")
                slot.clear()

    @classmethod
    def from_pretrained(cls, model_id: str, revision: Optional[str] = None):
        """Instantiate a NeuronGenerator.

        If the model configuration has no `neuron` section, the model is exported
        using the parameters returned by `get_export_kwargs_from_env`.

        Args:
            model_id (`str`):
                A hub model id or the path to a local model. This path must also contain a Tokenizer.
            revision (`Optional[str]`, defaults to `None`):
                The revision of the model on the HuggingFace hub.

        Returns:
            A NeuronGenerator.
        """
        # The configuration must come from the same revision as the model and the tokenizer
        config = AutoConfig.from_pretrained(model_id, revision=revision)
        neuron_config = getattr(config, "neuron", None)
        start = time.time()
        if neuron_config is None:
            export_kwargs = get_export_kwargs_from_env()
            logger.info(f"Exporting model to neuron with config: {export_kwargs}.")
            model = NeuronModelForCausalLM.from_pretrained(
                model_id,
                revision=revision,
                low_cpu_mem_usage=True,
                export=True,
                **export_kwargs,
            )
        else:
            logger.info(
                "Loading model on neuron devices (this can take a few minutes)."
            )
            model = NeuronModelForCausalLM.from_pretrained(
                model_id, low_cpu_mem_usage=True, revision=revision
            )
        end = time.time()
        logger.info(f"Model successfully loaded in {end - start:.2f} s.")
        tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)
        return cls(model, tokenizer)