# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.

import torch

from typing import Optional, Type

from transformers import PreTrainedTokenizerBase

from text_generation_server.models import CausalLM
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.pb import generate_pb2


class BloomCausalLMBatch(CausalLMBatch):
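    """CausalLMBatch for BLOOM, which stores past keys with head_dim before the sequence dimension."""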
    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "CausalLMBatch":
        batch = super().from_pb(
            pb=pb,
            tokenizer=tokenizer,
            dtype=dtype,
            device=device,
        )
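        # BLOOM's attention cache lays keys out as [batch * num_heads, head_dim, seq_len],
        # so head_dim is not the last dimension; flag this for the shared CausalLM logic.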
        batch.keys_head_dim_last = False
        return batch


class BLOOM(CausalLM):
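    """BLOOM served through the generic CausalLM path; only the batch type is overridden."""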
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        super().__init__(
            model_id=model_id,
            revision=revision,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        return BloomCausalLMBatch