from dataclasses import dataclass
from text_generation_server.models.globals import FLASH_DECODING
import torch
from typing import Optional


if FLASH_DECODING:

    @dataclass
    class Seqlen:
        """Sequence lengths plus cumulative offsets (flash-decoding variant).

        Only ``input_lengths`` is supplied by the caller; both cumulative
        tensors are derived from it inside ``__init__``.
        """

        # Lengths as handed in by the caller; stored untouched.
        input_lengths: torch.Tensor
        # int32 values 0, 1, ..., input_lengths.shape[0].
        cu_seqlen_q: Optional[torch.Tensor]
        # int32 tensor: a leading 0 followed by the running sum of input_lengths.
        cu_seqlen_k: Optional[torch.Tensor]

        def __init__(self, input_lengths):
            """Store ``input_lengths`` and build both cumulative offset tensors.

            Args:
                input_lengths: tensor of sequence lengths. Assumed to be 1-D
                    (the query offsets are sized from ``shape[0]`` and the key
                    offsets from ``shape[-1]``) -- TODO confirm with callers.
            """
            self.input_lengths = input_lengths
            lengths = self.input_lengths
            target_device = lengths.device
            dims = lengths.shape

            # Query offsets: 0..shape[0] inclusive, created on the same device
            # as the lengths.
            query_offsets = torch.arange(
                dims[0] + 1, device=target_device, dtype=torch.int32
            )

            # Key offsets: slot 0 keeps the zero written by torch.zeros, and the
            # running sum of the lengths is written directly into the remaining
            # slots through the ``out=`` view.
            # NOTE(review): the original author's notes say CUDA graphs "don't
            # like" an explicit ``cu_seqlen_k[0] = 0`` assignment, that this is
            # needed for clamping within mistral, and that FA2 might not want
            # the clamping. The explicit assignment is therefore left out;
            # confirm before reintroducing it.
            key_offsets = torch.zeros(
                dims[-1] + 1, device=target_device, dtype=torch.int32
            )
            torch.cumsum(lengths, -1, out=key_offsets[1:])

            self.cu_seqlen_q = query_offsets
            self.cu_seqlen_k = key_offsets

        def clamp(self, max):
            """Return this same instance; ``max`` is ignored.

            Flash decoding does not need clamping, so this is a no-op.
            """
            return self

else:

    @dataclass
    class Seqlen:
        """Sequence lengths wrapper (variant used when flash decoding is off)."""

        # Lengths tensor; the only state carried by this variant.
        input_lengths: torch.Tensor

        def clamp(self, max):
            """Return a new ``Seqlen`` whose lengths are capped at ``max``.

            The current instance is not modified; ``torch.clamp`` produces a
            fresh tensor that is wrapped in a new ``Seqlen``.
            """
            capped_lengths = torch.clamp(self.input_lengths, max=max)
            return Seqlen(capped_lengths)