import os
import tempfile

import pytest

import huggingface_hub.constants

import text_generation_server.utils.hub
from text_generation_server.utils.hub import (
    weight_hub_files,
    download_weights,
    weight_files,
    EntryNotFoundError,
    LocalEntryNotFoundError,
    RevisionNotFoundError,
)


@pytest.fixture()
def offline():
    """Switch the hub utilities to offline mode for the duration of a test.

    Sets ``HF_HUB_OFFLINE`` on ``text_generation_server.utils.hub`` to
    ``True``, yields the string ``"offline"``, and puts the previous value
    back once the test is done.
    """
    hub_module = text_generation_server.utils.hub
    previous = hub_module.HF_HUB_OFFLINE
    hub_module.HF_HUB_OFFLINE = True
    yield "offline"
    # Teardown: undo the override so later tests see the original setting.
    hub_module.HF_HUB_OFFLINE = previous


@pytest.fixture()
def fresh_cache():
    """Redirect the Hugging Face cache to an empty temporary directory.

    Three places are overridden for the duration of the test:

    * ``huggingface_hub.constants.HUGGINGFACE_HUB_CACHE``
    * ``text_generation_server.utils.hub.HUGGINGFACE_HUB_CACHE``
    * the ``HUGGINGFACE_HUB_CACHE`` environment variable

    Each one is snapshotted and restored independently, so an environment
    variable that was unset before the test is unset again afterwards instead
    of being left pointing at the library default.
    """
    hub_module = text_generation_server.utils.hub
    env_key = "HUGGINGFACE_HUB_CACHE"

    previous_constant = huggingface_hub.constants.HUGGINGFACE_HUB_CACHE
    # Fall back to the library constant when the attribute is absent; that is
    # the value this fixture has always written back on teardown.
    previous_hub_value = getattr(
        hub_module, "HUGGINGFACE_HUB_CACHE", previous_constant
    )
    # ``None`` means "was not set", which must be restored by deletion.
    previous_env = os.environ.get(env_key)

    with tempfile.TemporaryDirectory() as d:
        huggingface_hub.constants.HUGGINGFACE_HUB_CACHE = d
        hub_module.HUGGINGFACE_HUB_CACHE = d
        os.environ[env_key] = d
        try:
            yield
        finally:
            # Restore before the temporary directory is removed.
            huggingface_hub.constants.HUGGINGFACE_HUB_CACHE = previous_constant
            hub_module.HUGGINGFACE_HUB_CACHE = previous_hub_value
            if previous_env is None:
                os.environ.pop(env_key, None)
            else:
                os.environ[env_key] = previous_env


@pytest.fixture()
def prefetched():
    """Download the safetensors files of ``bert-base-uncased`` and yield its id.

    Calls ``huggingface_hub.snapshot_download`` for the ``main`` revision of
    the model repo, restricted to ``*.safetensors`` files, so that tests
    depending on this fixture can find the weights locally.
    """
    repo = "bert-base-uncased"
    huggingface_hub.snapshot_download(
        repo_id=repo,
        repo_type="model",
        revision="main",
        allow_patterns=["*.safetensors"],
        local_files_only=False,
    )
    yield repo


def test_weight_hub_files_offline_error(offline, fresh_cache):
    """Offline mode with an empty cache must raise ``EntryNotFoundError``.

    The model has not been prefetched, so the lookup is expected to fail.
    """
    pytest.raises(EntryNotFoundError, weight_hub_files, "gpt2")


def test_weight_hub_files_offline_ok(prefetched, offline):
    """A prefetched model resolves to its cached weight file while offline.

    Expects exactly one file, named ``model.safetensors``, with every
    returned path sharing the same parent directory.
    """
    filenames = weight_hub_files(prefetched)
    assert len(filenames) == 1

    # Split each path once into (directory, basename).
    split_paths = [os.path.split(path) for path in filenames]

    # All files must live under one common root directory.
    distinct_roots = {directory for directory, _ in split_paths}
    assert len(distinct_roots) == 1

    for _, basename in split_paths:
        assert basename == "model.safetensors"


def test_weight_hub_files():
    """``bigscience/bloom-560m`` lists a single ``model.safetensors`` file."""
    expected = ["model.safetensors"]
    assert weight_hub_files("bigscience/bloom-560m") == expected


def test_weight_hub_files_llm():
    """``bigscience/bloom`` lists its 72 numbered safetensors shards in order."""
    expected = []
    for i in range(1, 73):
        expected.append(f"model_{i:05d}-of-00072.safetensors")

    assert weight_hub_files("bigscience/bloom") == expected


def test_weight_hub_files_empty():
    """Asking for an extension with no matching files raises ``EntryNotFoundError``."""
    pytest.raises(
        EntryNotFoundError,
        weight_hub_files,
        "bigscience/bloom",
        extension=".errors",
    )


def test_download_weights():
    """Downloaded weight paths match what ``weight_files`` then finds locally."""
    model_id = "bigscience/bloom-560m"

    # Resolve the remote file names, then fetch them.
    downloaded = download_weights(weight_hub_files(model_id), model_id)

    # The local lookup must report the same files.
    assert downloaded == weight_files(model_id)


def test_weight_files_revision_error():
    """The revision name ``"error"`` makes ``weight_files`` raise ``RevisionNotFoundError``."""
    pytest.raises(
        RevisionNotFoundError,
        weight_files,
        "bigscience/bloom-560m",
        revision="error",
    )


def test_weight_files_not_cached_error(fresh_cache):
    """With an empty cache, ``weight_files`` raises ``LocalEntryNotFoundError``."""
    pytest.raises(LocalEntryNotFoundError, weight_files, "bert-base-uncased")