From a352563ee03ba9383992940aac8db5e13ab4f681 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Thu, 11 Apr 2024 11:34:25 +0000
Subject: [PATCH] Style.

---
 server/text_generation_server/utils/layers.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 312a4482..cace1084 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -181,6 +181,7 @@ class EETQLinear(nn.Module):
         output = output + self.bias if self.bias is not None else output
         return output
 
+
 def fp8_quantize(weight, qdtype=torch.float8_e4m3fn):
     device = weight.device
     # weight, scale = quant_weights(weight, torch.int8, False)
@@ -197,6 +198,7 @@ def fp8_quantize(weight, qdtype=torch.float8_e4m3fn):
     scale = scale.float().reciprocal()
     return qweight, scale
 
+
 class Fp8Linear(nn.Module):
     def __init__(
         self,
@@ -206,14 +208,22 @@ class Fp8Linear(nn.Module):
         super().__init__()
         self.dtype = weight.dtype
         self.qweight, self.scale = fp8_quantize(weight)
+
         self.bias = bias.cuda(device) if bias is not None else None
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         qinput, scale = fp8_quantize(input)
-        output, _ = torch._scaled_mm(qinput, self.qweight.t(), out_dtype=self.dtype,
-                                     scale_a=scale , scale_b=self.scale, bias=self.bias)
+        output, _ = torch._scaled_mm(
+            qinput,
+            self.qweight.t(),
+            out_dtype=self.dtype,
+            scale_a=scale,
+            scale_b=self.scale,
+            bias=self.bias,
+        )
        return output
 
+
 class Linear8bitLt(nn.Module):
     def __init__(
         self,
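
Note on the FP8 path this patch touches: fp8_quantize produces a torch.float8_e4m3fn tensor plus the inverse per-tensor scale, and Fp8Linear.forward feeds both weight and activation scales into torch._scaled_mm. The standalone sketch below only illustrates that flow; it is not the repository code. The quantization body (absmax scaling and clamping) is an assumption, since the hunks show only the head and tail of fp8_quantize, and fp8_quantize_sketch is a hypothetical name. It needs CUDA hardware with FP8 support, and in the PyTorch releases current at the patch date torch._scaled_mm returned an (output, amax) tuple, which is why the diff unpacks "output, _"; newer releases return the output tensor alone, so the sketch handles both.

import torch

def fp8_quantize_sketch(t: torch.Tensor, qdtype=torch.float8_e4m3fn):
    # Assumed per-tensor absmax scaling: map the largest magnitude to the
    # FP8 max, clamp to the representable range, then cast to float8.
    finfo = torch.finfo(qdtype)
    scale = finfo.max / t.abs().max().clamp(min=1e-12)
    qt = (t * scale).clamp(min=finfo.min, max=finfo.max).to(qdtype)
    # torch._scaled_mm wants the *inverse* scale as a float32 tensor.
    return qt, scale.float().reciprocal()

if __name__ == "__main__":
    # Requires a CUDA GPU with FP8 support (e.g. compute capability >= 8.9).
    device, dtype = "cuda", torch.bfloat16
    x = torch.randn(32, 64, device=device, dtype=dtype)   # activations
    w = torch.randn(128, 64, device=device, dtype=dtype)  # weight (out, in)

    qw, w_scale = fp8_quantize_sketch(w)
    qx, x_scale = fp8_quantize_sketch(x)

    # mat2 must be column-major, hence .t() on the row-major weight,
    # mirroring self.qweight.t() in the patch. Bias is omitted here.
    out = torch._scaled_mm(
        qx,
        qw.t(),
        out_dtype=dtype,
        scale_a=x_scale,
        scale_b=w_scale,
    )
    # Older PyTorch (2.2/2.3) returns (output, amax); unpack if so.
    if isinstance(out, tuple):
        out, _ = out
    print(out.shape)  # torch.Size([32, 128])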