import os
import sys
import typer

from pathlib import Path
from loguru import logger

from text_generation import server, utils

app = typer.Typer()


@app.command()
def serve(
    model_name: str,
    sharded: bool = False,
    quantize: bool = False,
    uds_path: Path = "/tmp/text-generation",
    logger_level: str = "INFO",
    json_output: bool = False,
):
    # Remove default handler
    logger.remove()
    logger.add(
        sys.stdout,
        format="{message}",
        filter="text_generation",
        level=logger_level,
        serialize=json_output,
        backtrace=True,
        diagnose=False,
    )
    if sharded:
        assert (
            os.getenv("RANK", None) is not None
        ), "RANK must be set when sharded is True"
        assert (
            os.getenv("WORLD_SIZE", None) is not None
        ), "WORLD_SIZE must be set when sharded is True"
        assert (
            os.getenv("MASTER_ADDR", None) is not None
        ), "MASTER_ADDR must be set when sharded is True"
        assert (
            os.getenv("MASTER_PORT", None) is not None
        ), "MASTER_PORT must be set when sharded is True"

    server.serve(model_name, sharded, quantize, uds_path)


@app.command()
def download_weights(
    model_name: str,
    extension: str = ".safetensors",
):
    utils.download_weights(model_name, extension)


if __name__ == "__main__":
    app()