diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index 9cf5c80f..f3e51672 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -212,6 +212,8 @@ class Fp8Linear(nn.Module): self.bias = bias if bias is not None else None def forward(self, input: torch.Tensor) -> torch.Tensor: + if (bsz := input.shape[0]) & 15: + input = F.pad(input,(0, 0, 0, 16 - (bsz & 15))) qinput, scale = fp8_quantize(input) output, _ = torch._scaled_mm( qinput, @@ -221,7 +223,7 @@ class Fp8Linear(nn.Module): scale_b=self.scale, bias=self.bias, ) - return output + return output[:bsz] class Linear8bitLt(nn.Module):