import sys
from typing import Optional

import typer
from loguru import logger


app = typer.Typer()


@app.command()
def serve(
    model_id: str,
    revision: Optional[str] = None,
    sharded: bool = False,
    trust_remote_code: bool = None,
    uds_path: str = "/tmp/text-generation-server",
    logger_level: str = "INFO",
    json_output: bool = False,
    otlp_endpoint: Optional[str] = None,
    otlp_service_name: str = "text-generation-inference.server",
    max_input_tokens: Optional[int] = None,
):
    """This is the main entry-point for the server CLI.

    Args:
        model_id (`str`):
            The *model_id* of a model on the HuggingFace hub or the path to a local model.
        revision (`Optional[str]`, defaults to `None`):
            The revision of the model on the HuggingFace hub.
        sharded (`bool`):
            Whether the model must be sharded or not. Kept for compatibility with the
            text-generation-launcher, but must be set to False.
        trust-remote-code (`bool`):
            Kept for compatibility with text-generation-launcher. Ignored.
        uds_path (`Union[Path, str]`):
            The local path on which the server will expose its google RPC services.
        logger_level (`str`):
            The server logger level. Defaults to *INFO*.
        json_output (`bool`):
            Use JSON format for log serialization.
        otlp_endpoint (`Optional[str]`, defaults to `None`):
            The Open Telemetry endpoint to use.
        otlp_service_name (`Optional[str]`, defaults to `None`):
            The name to use when pushing data to the Open Telemetry endpoint.
        max_input_tokens (`Optional[int]`, defaults to `None`):
            The maximum number of input tokens each request should contain.
    """
    if sharded:
        raise ValueError("Sharding is not supported.")
    # Remove default handler
    logger.remove()
    logger.add(
        sys.stdout,
        format="{message}",
        filter="text_generation_server",
        level=logger_level,
        serialize=json_output,
        backtrace=True,
        diagnose=False,
    )

    if trust_remote_code is not None:
        logger.warning(
            "'trust_remote_code' argument is not supported and will be ignored."
        )

    # Import here after the logger is added to log potential import exceptions
    from .server import serve

    serve(model_id, revision, uds_path)


@app.command()
def download_weights(
    model_id: str,
    revision: Optional[str] = None,
    logger_level: str = "INFO",
    json_output: bool = False,
    auto_convert: Optional[bool] = None,
    extension: Optional[str] = None,
    trust_remote_code: Optional[bool] = None,
    merge_lora: Optional[bool] = None,
):
    """Download the model weights.

    This command will be called by text-generation-launcher before serving the model.
    """
    # Remove default handler
    logger.remove()
    logger.add(
        sys.stdout,
        format="{message}",
        filter="text_generation_server",
        level=logger_level,
        serialize=json_output,
        backtrace=True,
        diagnose=False,
    )

    if extension is not None:
        logger.warning("'extension' argument is not supported and will be ignored.")
    if trust_remote_code is not None:
        logger.warning(
            "'trust_remote_code' argument is not supported and will be ignored."
        )
    if auto_convert is not None:
        logger.warning("'auto_convert' argument is not supported and will be ignored.")
    if merge_lora is not None:
        logger.warning("'merge_lora' argument is not supported and will be ignored.")

    # Import here after the logger is added to log potential import exceptions
    from .model import fetch_model

    fetch_model(model_id, revision)