from text_generation_server.generator import NeuronGenerator
from text_generation_server.pb.generate_pb2 import (
    Batch,
    NextTokenChooserParameters,
    Request,
    StoppingCriteriaParameters,
)


def create_request(
    id: int,
    inputs: str,
    truncate: int = 0,
    max_new_tokens: int = 20,
    do_sample: bool = False,
    top_k: int = 50,
    top_p: float = 0.9,
    temperature: float = 1.0,
    seed: int = 42,
    repetition_penalty: float = 1.0,
):
    parameters = NextTokenChooserParameters(
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=do_sample,
        seed=seed,
        repetition_penalty=repetition_penalty,
    )
    stopping_parameters = StoppingCriteriaParameters(max_new_tokens=max_new_tokens)
    return Request(
        id=id,
        inputs=inputs,
        truncate=truncate,
        parameters=parameters,
        stopping_parameters=stopping_parameters,
    )
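
# Example usage (a minimal sketch; the prompt and token budget are arbitrary):
#   request = create_request(id=0, inputs="Once upon a time", max_new_tokens=10)
# With do_sample=False (the default) decoding is expected to be greedy, so the
# sampling parameters mainly matter when do_sample=True.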


def check_prefill(
    input_text,
    expected_token_id,
    expected_token_text,
    do_sample,
    batch_size,
    model_path,
):
    """Verify that a prefill for a single request generates the expected output."""
    generator = NeuronGenerator.from_pretrained(model_path)
    assert generator.model.batch_size >= batch_size
    requests = []
    max_new_tokens = 20
    for i in range(batch_size):
        requests.append(
            create_request(
                id=i,
                inputs=input_text,
                do_sample=do_sample,
                max_new_tokens=max_new_tokens,
            )
        )
    # Let's be pessimistic when estimating max_tokens
    max_tokens = batch_size * (len(input_text) + max_new_tokens)
    max_length = generator.model.max_length
    batch = Batch(id=0, requests=requests, size=batch_size, max_tokens=max_tokens)
    generations, next_batch = generator.prefill(batch)
    assert next_batch.size == batch_size
    # Whatever was passed as max_tokens, the server will correct it
    # because of static batching
    assert next_batch.max_tokens == batch_size * max_length
    assert len(generations) == batch_size
    for g in generations:
        tokens = g.tokens
        assert tokens.ids == [expected_token_id]
        assert tokens.texts == [expected_token_text]


def check_decode_single(
    input_text, max_new_tokens, generated_text, do_sample, model_path
):
    """Verify that a decoding for a single request generates the expected output."""
    generator = NeuronGenerator.from_pretrained(model_path)
    request = create_request(
        id=0, inputs=input_text, max_new_tokens=max_new_tokens, do_sample=do_sample
    )
    max_length = generator.model.max_length
    batch = Batch(id=0, requests=[request], size=1, max_tokens=max_length)
    generations, next_batch = generator.prefill(batch)
    # We already generated one token: call decode max_new_tokens - 1 times
    for _ in range(max_new_tokens - 1):
        assert next_batch.size == 1
        assert next_batch.max_tokens == max_length
        assert len(generations) == 1
        assert len(generations[0].tokens.ids) == 1
        generations, next_batch = generator.decode([next_batch])
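    # The request is now complete, so the generator returns no next batch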
    assert next_batch is None
    assert len(generations) == 1
    output = generations[0].generated_text
    assert output.generated_tokens == max_new_tokens
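    # finish_reason 0 is expected to correspond to FINISH_REASON_LENGTH, i.e. the
    # request stopped because it reached max_new_tokens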
    assert output.finish_reason == 0
    assert output.text == generated_text


def check_decode_multiple(model_path):
    """Verify that two requests added to the batch at different generation steps
    generate the same outputs (continuous batching).
    """
    generator = NeuronGenerator.from_pretrained(model_path)
    assert generator.model.batch_size > 1
    input_text = "Once upon a time"
    max_new_tokens = 20
    # Prefill a single request, remembering the generated token
    tokens = {0: [], 1: []}
    request = create_request(id=0, inputs=input_text, max_new_tokens=max_new_tokens)
    max_length = generator.model.max_length
    batch = Batch(id=0, requests=[request], size=1, max_tokens=max_length)
    generations, next_batch = generator.prefill(batch)
    assert next_batch.size == 1
    assert len(generations) == 1
    g = generations[0]
    tokens[g.request_id].append(g.tokens.ids[0])
    assert len(tokens[0]) == 1
    # Decode a few tokens
    gen_tokens = 4
    for _ in range(gen_tokens - 1):
        generations, next_batch = generator.decode([next_batch])
        assert len(generations) == 1
        g = generations[0]
        tokens[g.request_id].append(g.tokens.ids[0])
    assert len(tokens[0]) == gen_tokens
    assert next_batch.size == 1
    # Add a second request
    request = create_request(id=1, inputs=input_text, max_new_tokens=max_new_tokens)
    batch = Batch(id=1, requests=[request], size=1, max_tokens=max_length)
    generations, next_batch_1 = generator.prefill(batch)
    assert next_batch_1.size == 1
    # We should have generated only a single token
    assert len(generations) == 1
    g = generations[0]
    tokens[g.request_id].append(g.tokens.ids[0])
    assert len(tokens[0]) == gen_tokens
    assert len(tokens[1]) == 1
    # Decode more tokens until we reach the maximum for the first request
    batches = [next_batch, next_batch_1]
    for _ in range(max_new_tokens - gen_tokens):
        generations, next_batch = generator.decode(batches)
        for g in generations:
            tokens[g.request_id].append(g.tokens.ids[0])
        batches = [next_batch]
    # Verify we now only have one pending request
    assert next_batch.size == 1
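    # The first request has generated all its tokens; the second joined the batch
    # gen_tokens - 1 decode steps later, hence its shorter token list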
    assert len(tokens[0]) == max_new_tokens
    assert len(tokens[1]) == max_new_tokens - gen_tokens + 1
    # Verify we have the output for the first request
    for g in generations:
        if g.request_id == 0:
            output = g.generated_text
            assert output.text != ""
            assert output.generated_tokens == max_new_tokens
            generated_text = output.text
    # Continue decoding until the end of the second request
    for _ in range(gen_tokens - 1):
        generations, next_batch = generator.decode([next_batch])
        assert len(generations) == 1
        g = generations[0]
        tokens[g.request_id].append(g.tokens.ids[0])
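    # The second request is now complete as well, so no next batch is returned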
    assert next_batch is None
    output = generations[0].generated_text
    assert output.generated_tokens == max_new_tokens
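    # Both requests decoded the same prompt greedily, so continuous batching should
    # not have changed the generated sequence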
    assert tokens[0] == tokens[1]
    assert output.text == generated_text