import torch
from loguru import logger


def is_xpu_available():
    try:
        import intel_extension_for_pytorch
    except ImportError:
        return False

    return hasattr(torch, "xpu") and torch.xpu.is_available()


def get_cuda_free_memory(device, memory_fraction):
    total_free_memory, _ = torch.cuda.mem_get_info(device)
    total_gpu_memory = torch.cuda.get_device_properties(device).total_memory
    free_memory = max(0, total_free_memory - (1 - memory_fraction) * total_gpu_memory)
    return free_memory


def get_xpu_free_memory(device, memory_fraction):
    total_gpu_memory = torch.xpu.get_device_properties(device).total_memory
    free_memory = int(total_gpu_memory * 0.5)
    return free_memory


SYSTEM = None
if torch.version.hip is not None:
    SYSTEM = "rocm"
    empty_cache = torch.cuda.empty_cache
    synchronize = torch.cuda.synchronize
    get_free_memory = get_cuda_free_memory
elif torch.version.cuda is not None and torch.cuda.is_available():
    SYSTEM = "cuda"
    empty_cache = torch.cuda.empty_cache
    synchronize = torch.cuda.synchronize
    get_free_memory = get_cuda_free_memory
elif is_xpu_available():
    SYSTEM = "xpu"
    empty_cache = torch.xpu.empty_cache
    synchronize = torch.xpu.synchronize
    get_free_memory = get_xpu_free_memory
else:
    SYSTEM = "cpu"

    def noop(*args, **kwargs):
        pass

    empty_cache = noop
    synchronize = noop
    get_free_memory = noop
logger.info(f"Detected system {SYSTEM}")