# text-generation-inference/server/text_generation_server/layers/eetq.py
from dataclasses import dataclass

import torch

from text_generation_server.utils.kernels import load_kernel
from text_generation_server.utils.weights import UnquantizedWeight

# Load the EETQ quantization kernels from the Hugging Face Hub.
quantization_eetq = load_kernel(
    module="quantization_eetq", repo_id="kernels-community/quantization-eetq"
)
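
# The loaded module exposes the two entry points used in this file. The
# signatures sketched here are inferred from the call sites below, not from
# kernel documentation:
#   quant_weights(weight, dtype, use_cuda) -> (int8 weight, scales)
#   w8_a16_gemm(input, weight, scale)      -> fp16 output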


@dataclass
class EETQWeight(UnquantizedWeight):
    """Unquantized weight that is quantized on the fly when a linear layer is built."""

    weight: torch.Tensor

    def get_linear(self, bias: torch.Tensor):
        try:
            from text_generation_server.layers.eetq import EETQLinear

            return EETQLinear(self.weight, bias)
        except ImportError:
            raise ImportError(
                "Please install EETQ from https://github.com/NetEase-FuXi/EETQ"
            )


class EETQLinear(torch.nn.Module):
    def __init__(
        self,
        weight,
        bias,
    ) -> None:
        super().__init__()
        device = weight.device
        # The quantization kernel expects float16 weights.
        if weight.dtype != torch.float16:
            weight = weight.to(dtype=torch.float16)
        # Quantize the transposed weight on CPU, then move the int8 weight
        # and its scales back to the original device.
        weight = torch.t(weight).contiguous().cpu()
        weight, scale = quantization_eetq.quant_weights(weight, torch.int8, False)
        self.weight = weight.cuda(device)
        self.scale = scale.cuda(device)
        self.bias = bias.cuda(device) if bias is not None else None

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # int8-weight / fp16-activation GEMM, with an optional bias add.
        output = quantization_eetq.w8_a16_gemm(input, self.weight, self.scale)
        output = output + self.bias if self.bias is not None else output
        return output
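

# A minimal usage sketch, not part of the original module. It assumes a CUDA
# device is available and that the `kernels-community/quantization-eetq`
# kernel loaded successfully; the shapes are illustrative only.
if __name__ == "__main__":
    out_features, in_features = 4096, 4096
    weight = torch.randn(out_features, in_features, dtype=torch.float16, device="cuda")
    bias = torch.zeros(out_features, dtype=torch.float16, device="cuda")

    # Route through EETQWeight, as the surrounding weight-loading code would.
    linear = EETQWeight(weight).get_linear(bias)

    x = torch.randn(2, in_features, dtype=torch.float16, device="cuda")
    y = linear(x)
    print(y.shape)  # torch.Size([2, 4096])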