text-generation-inference/server/text_generation_server/models/custom_modeling/linear.py

23 lines
610 B
Python
Raw Normal View History

2023-03-31 11:17:45 +00:00
import torch
from torch import nn
class FastLinear(nn.Linear):
    """Linear layer whose forward pass multiplies by a pre-transposed weight.

    ``nn.Linear`` stores ``weight`` as ``(out_features, in_features)``.
    This subclass expects :meth:`transpose_weight` to be called once after
    the weights are loaded, so that ``weight`` is laid out as
    ``(in_features, out_features)`` and :meth:`forward` can compute
    ``input @ weight`` directly.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        """Build the layer exactly like ``nn.Linear`` (weight not yet transposed)."""
        super().__init__(
            in_features, out_features, bias=bias, device=device, dtype=dtype
        )

    def transpose_weight(self) -> None:
        """Replace ``weight`` with a new parameter wrapping its transpose.

        NOTE: this is not idempotent -- calling it a second time transposes
        the weight back to its original layout. Call it exactly once, after
        the weights have been loaded.
        """
        self.weight = nn.Parameter(self.weight.T)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return ``input @ weight`` plus ``bias`` when the layer has one.

        Args:
            input: tensor whose last dimension matches the first dimension
                of ``weight`` (``in_features`` once the weight has been
                transposed). May be 1-D, 2-D or batched N-D.

        Returns:
            The projected tensor; its last dimension is the second
            dimension of ``weight``.
        """
        if self.bias is None:
            return torch.matmul(input, self.weight)
        if input.dim() == 2:
            # Fused bias-add + matmul; addmm only accepts a 2-D ``mat1``.
            return torch.addmm(self.bias, input, self.weight)
        # Any other rank: matmul broadcasts over leading dims, then the bias
        # broadcasts over the result, matching the bias-free branch's
        # accepted input shapes.
        return torch.matmul(input, self.weight) + self.bias