Nicolas Patry 2024-04-11 11:34:25 +00:00
parent 66195d832c
commit a352563ee0

@@ -181,6 +181,7 @@ class EETQLinear(nn.Module):
        output = output + self.bias if self.bias is not None else output
        return output
def fp8_quantize(weight, qdtype=torch.float8_e4m3fn):
    device = weight.device
    # weight, scale = quant_weights(weight, torch.int8, False)
@@ -197,6 +198,7 @@ def fp8_quantize(weight, qdtype=torch.float8_e4m3fn):
    scale = scale.float().reciprocal()
    return qweight, scale
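The hunk above elides the middle of fp8_quantize. For orientation only, here is a minimal sketch of a complete per-tensor fp8 quantizer consistent with the visible lines; the finfo-based scaling and clamping are assumptions about the elided body, not lines from this commit.

import torch

def fp8_quantize_sketch(weight, qdtype=torch.float8_e4m3fn):
    # Assumed body: derive a per-tensor scale from the fp8 dtype's range.
    finfo = torch.finfo(qdtype)
    # Scale so the largest weight maps near finfo.max; guard against zero.
    scale = finfo.max / weight.abs().max().clamp(min=1e-12)
    # Clamp before casting: the default cast to float8 is unsaturated.
    qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max).to(qdtype)
    # Return the inverse scale as float32, the form torch._scaled_mm expects.
    scale = scale.float().reciprocal()
    return qweight, scale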
class Fp8Linear(nn.Module):
    def __init__(
        self,
@@ -206,14 +208,22 @@ class Fp8Linear(nn.Module):
        super().__init__()
        self.dtype = weight.dtype
        self.qweight, self.scale = fp8_quantize(weight)
        self.bias = bias.cuda(device) if bias is not None else None
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        qinput, scale = fp8_quantize(input)
-        output, _ = torch._scaled_mm(qinput, self.qweight.t(), out_dtype=self.dtype,
-                                     scale_a=scale , scale_b=self.scale, bias=self.bias)
+        output, _ = torch._scaled_mm(
+            qinput,
+            self.qweight.t(),
+            out_dtype=self.dtype,
+            scale_a=scale,
+            scale_b=self.scale,
+            bias=self.bias,
+        )
        return output
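For illustration, a hedged usage sketch of Fp8Linear follows. The constructor's parameter list is partially elided by the hunk, so this assumes it takes (weight, bias); bias is passed as None here because the visible bias line references a `device` name that may be defined in elided lines. torch._scaled_mm requires fp8-capable hardware (compute capability 8.9+, e.g. H100), and the `output, _ = ...` tuple unpacking matches PyTorch releases contemporary with this commit.

import torch
import torch.nn as nn

# Hypothetical usage, assuming Fp8Linear(weight, bias) and an fp8-capable GPU.
lin = nn.Linear(4096, 4096, bias=False, dtype=torch.bfloat16, device="cuda")
fp8_lin = Fp8Linear(lin.weight.data, None)

x = torch.randn(16, 4096, dtype=torch.bfloat16, device="cuda")
y = fp8_lin(x)  # fp8 matmul via torch._scaled_mm, de-scaled to bfloat16
print(y.shape, y.dtype)  # torch.Size([16, 4096]) torch.bfloat16

Note the asymmetry in the forward pass: the weight is quantized once at construction, while the activation is re-quantized on every call (dynamic per-tensor scaling), so only the weight's scale can be precomputed.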
class Linear8bitLt(nn.Module):
    def __init__(
        self,