from accelerate import init_empty_weights
import torch


@classmethod
def load_conv2d(cls, prefix, weights, in_channels, out_channels, kernel_size, stride):
    """Build a convolution layer whose weight and bias come from ``weights``.

    Args:
        cls: Layer class to instantiate; it is called with the keyword
            arguments ``in_channels``, ``out_channels``, ``kernel_size``
            and ``stride``.
        prefix: Name prefix of the layer. The tensors ``"{prefix}.weight"``
            and ``"{prefix}.bias"`` are requested from ``weights``.
        weights: Object exposing ``get_tensor(name)``.
        in_channels: Forwarded unchanged to ``cls``.
        out_channels: Forwarded unchanged to ``cls``.
        kernel_size: Forwarded unchanged to ``cls``.
        stride: Forwarded unchanged to ``cls``.

    Returns:
        The new ``cls`` instance, with ``weight`` and ``bias`` each replaced
        by a ``torch.nn.Parameter`` wrapping the fetched tensor.
    """
    # Fetch both tensors up front, before any module is constructed.
    kernel = weights.get_tensor(f"{prefix}.weight")
    offset = weights.get_tensor(f"{prefix}.bias")

    layer_kwargs = {
        "in_channels": in_channels,
        "out_channels": out_channels,
        "kernel_size": kernel_size,
        "stride": stride,
    }
    # Instantiate inside accelerate's ``init_empty_weights()`` context; the
    # parameters the constructor creates are replaced right below anyway.
    with init_empty_weights():
        layer = cls(**layer_kwargs)

    layer.weight = torch.nn.Parameter(kernel)
    layer.bias = torch.nn.Parameter(offset)
    return layer


@classmethod
def load_conv2d_no_bias(
    cls, prefix, weights, in_channels, out_channels, kernel_size, stride
):
    """Build a bias-free convolution layer whose weight comes from ``weights``.

    Args:
        cls: Layer class to instantiate; it is called with the keyword
            arguments ``in_channels``, ``out_channels``, ``kernel_size``,
            ``stride`` and ``bias=False``.
        prefix: Name prefix of the layer. Only the tensor
            ``"{prefix}.weight"`` is requested from ``weights``.
        weights: Object exposing ``get_tensor(name)``.
        in_channels: Forwarded unchanged to ``cls``.
        out_channels: Forwarded unchanged to ``cls``.
        kernel_size: Forwarded unchanged to ``cls``.
        stride: Forwarded unchanged to ``cls``.

    Returns:
        The new ``cls`` instance, with ``weight`` replaced by a
        ``torch.nn.Parameter`` wrapping the fetched tensor and ``bias``
        equal to ``None``.
    """
    weight = weights.get_tensor(f"{prefix}.weight")
    with init_empty_weights():
        # bias=False makes the constructor register ``bias`` as None directly,
        # instead of creating a bias parameter that is discarded afterwards.
        conv2d = cls(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            bias=False,
        )

    conv2d.weight = torch.nn.Parameter(weight)
    return conv2d


# Attach the loaders to torch.nn.Conv2d as alternate constructors. Both are
# classmethod objects, so ``torch.nn.Conv2d.load(...)`` receives the class as ``cls``.
setattr(torch.nn.Conv2d, "load", load_conv2d)
setattr(torch.nn.Conv2d, "load_no_bias", load_conv2d_no_bias)