text-generation-inference/backends/gaudi/server/text_generation_server/layers/conv.py
Baptiste Colle 683ff53fa3
Add Gaudi Backend (#3055)
* wip(gaudi): import server and dockerfile from tgi-gaudi fork

* feat(gaudi): new gaudi backend working

* fix: fix style

* fix prehooks issues

* fix(gaudi): refactor server and implement requested changes
2025-02-28 12:14:58 +01:00

42 lines
1.1 KiB
Python

from accelerate import init_empty_weights
import torch
@classmethod
def load_conv2d(cls, prefix, weights, in_channels, out_channels, kernel_size, stride):
    """Build a ``cls`` convolution layer from stored weight and bias tensors.

    Declared as a module-level ``classmethod`` because it is attached to
    ``torch.nn.Conv2d`` as ``load`` at the bottom of this module.

    Args:
        prefix: Tensor name prefix; the tensors are looked up as
            ``"{prefix}.weight"`` and ``"{prefix}.bias"``.
        weights: Object exposing ``get_tensor(name)``.
        in_channels: Forwarded to ``cls`` as ``in_channels``.
        out_channels: Forwarded to ``cls`` as ``out_channels``.
        kernel_size: Forwarded to ``cls`` as ``kernel_size``.
        stride: Forwarded to ``cls`` as ``stride``.

    Returns:
        The new layer, with ``weight`` and ``bias`` replaced by
        ``torch.nn.Parameter`` wrappers around the fetched tensors.
    """
    # Fetch both tensors first so a missing entry fails before the layer
    # is constructed.
    kernel_tensor = weights.get_tensor(f"{prefix}.weight")
    bias_tensor = weights.get_tensor(f"{prefix}.bias")

    ctor_kwargs = {
        "in_channels": in_channels,
        "out_channels": out_channels,
        "kernel_size": kernel_size,
        "stride": stride,
    }
    # NOTE(review): presumably init_empty_weights() avoids allocating real
    # storage for the initial parameters, which are overwritten right below
    # -- confirm against the accelerate documentation.
    with init_empty_weights():
        layer = cls(**ctor_kwargs)

    layer.weight = torch.nn.Parameter(kernel_tensor)
    layer.bias = torch.nn.Parameter(bias_tensor)
    return layer
@classmethod
def load_conv2d_no_bias(
    cls, prefix, weights, in_channels, out_channels, kernel_size, stride
):
    """Build a ``cls`` convolution layer from a stored weight tensor, no bias.

    Declared as a module-level ``classmethod`` because it is attached to
    ``torch.nn.Conv2d`` as ``load_no_bias`` at the bottom of this module.

    Args:
        prefix: Tensor name prefix; the tensor is looked up as
            ``"{prefix}.weight"``.
        weights: Object exposing ``get_tensor(name)``.
        in_channels: Forwarded to ``cls`` as ``in_channels``.
        out_channels: Forwarded to ``cls`` as ``out_channels``.
        kernel_size: Forwarded to ``cls`` as ``kernel_size``.
        stride: Forwarded to ``cls`` as ``stride``.

    Returns:
        The new layer, with ``weight`` replaced by a ``torch.nn.Parameter``
        around the fetched tensor and ``bias`` set to ``None``.
    """
    kernel_tensor = weights.get_tensor(f"{prefix}.weight")

    ctor_kwargs = {
        "in_channels": in_channels,
        "out_channels": out_channels,
        "kernel_size": kernel_size,
        "stride": stride,
    }
    # NOTE(review): presumably init_empty_weights() avoids allocating real
    # storage for the initial parameters, which are overwritten right below
    # -- confirm against the accelerate documentation.
    with init_empty_weights():
        layer = cls(**ctor_kwargs)

    layer.weight = torch.nn.Parameter(kernel_tensor)
    # The layer is constructed with cls's default bias setting; the bias is
    # dropped afterwards instead of being disabled in the constructor.
    layer.bias = None
    return layer
# Attach the loaders defined above to torch.nn.Conv2d itself, so they can be
# called as Conv2d.load(...) and Conv2d.load_no_bias(...). This runs as an
# import-time side effect of this module.
setattr(torch.nn.Conv2d, "load", load_conv2d)
setattr(torch.nn.Conv2d, "load_no_bias", load_conv2d_no_bias)