diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index ce3bddf9..8ce5dcba 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -181,6 +181,22 @@ class EETQLinear(nn.Module):
         output = output + self.bias if self.bias is not None else output
         return output
 
+def fp8_quantize(weight, qdtype=torch.float8_e4m3fn):
+    device = weight.device
+    # weight, scale = quant_weights(weight, torch.int8, False)
+    finfo = torch.finfo(qdtype)
+    # Calculate the scale as dtype max divided by absmax
+    scale = finfo.max / weight.abs().max().clamp(min=1e-12)
+    # scale and clamp the tensor to bring it to
+    # the representative range of float8 data type
+    # (as default cast is unsaturated)
+    qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max)
+    # Return both float8 data and the inverse scale (as float),
+    # as both required as inputs to torch._scaled_mm
+    qweight = qweight.to(qdtype)
+    scale = scale.float().reciprocal()
+    return qweight, scale
+
 class Fp8Linear(nn.Module):
     def __init__(
         self,
@@ -188,34 +204,21 @@ class Fp8Linear(nn.Module):
         bias,
     ) -> None:
         super().__init__()
-        device = weight.device
-        # weight, scale = quant_weights(weight, torch.int8, False)
-        finfo = torch.finfo(weight.dtype)
-        qdtype = torch.float8_e4m3fn
-        # Calculate the scale as dtype max divided by absmax
-        scale = finfo.max / weight.abs().max().clamp(min=1e-12)
-        # scale and clamp the tensor to bring it to
-        # the representative range of float8 data type
-        # (as default cast is unsaturated)
-        x_scl_sat = (weight * scale).clamp(min=finfo.min, max=finfo.max)
-        # Return both float8 data and the inverse scale (as float),
-        # as both required as inputs to torch._scaled_mm
         self.dtype = weight.dtype
-        self.qweight = x_scl_sat.to(qdtype).to(device=device)
-        self.scale = scale.float().reciprocal().to(device=device)
+        self.qweight, self.scale = fp8_quantize(weight)
         self.bias = bias.cuda(device) if bias is not None else None
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        finfo = torch.finfo(input.dtype)
-        scale = finfo.max / input.abs().max().clamp(min=1e-12)
-        qinput = (input * scale).clamp(min=finfo.min, max=finfo.max)
-
-        output, _ = torch._scaled_mm(qinput, self.qweight, out_dtype=torch.float16,
-                                     scale_a=scale, scale_b=self.scale)
-        output = output + self.bias if self.bias is not None else output
+        qinput, scale = fp8_quantize(input)
+        seqlen = qinput.shape[0]
+        if seqlen % 16 != 0:
+            missing = 16 - seqlen % 16
+            qinput = F.pad(qinput, (0, 0, 0, missing), "constant", value=0)
+        output, _ = torch._scaled_mm(qinput, self.qweight.t(), out_dtype=self.dtype,
+                                     scale_a=scale, scale_b=self.scale, bias=self.bias)
+        output = output[:seqlen]
         return output
 
-
 class Linear8bitLt(nn.Module):
     def __init__(
         self,
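
For context, here is a minimal standalone sketch (not part of the patch) of the symmetric per-tensor scheme that fp8_quantize implements: the scale is the float8 dtype's max over the tensor's absmax, the tensor is clamped before the cast because the default float8 cast does not saturate, and the reciprocal scale is what later dequantizes the result inside torch._scaled_mm (which, on current cuBLASLt kernels, also wants the row count padded to a multiple of 16, hence the F.pad in forward). The shapes and the round-trip check are illustrative assumptions; it only needs a PyTorch build that exposes torch.float8_e4m3fn (2.1 or newer).

import torch

weight = torch.randn(128, 64, dtype=torch.float16)
finfo = torch.finfo(torch.float8_e4m3fn)
# scale = dtype max / absmax, as in fp8_quantize
scale = finfo.max / weight.abs().max().clamp(min=1e-12)
# saturating clamp, then cast to float8
qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max).to(torch.float8_e4m3fn)
# the inverse scale is what gets passed to torch._scaled_mm as scale_a / scale_b
inv_scale = scale.float().reciprocal()

# Dequantize to see the round-trip error introduced by the e4m3 cast.
dequant = qweight.to(torch.float16) * inv_scale.to(torch.float16)
print((weight - dequant).abs().max())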