This commit is contained in:
Mohit Sharma 2024-12-13 16:10:00 +00:00
parent ea7f4082c4
commit 1fa9ca2f16

View File

@ -395,6 +395,13 @@ class Fp8Linear(torch.nn.Module):
qinput, self.qweight.t(), scale, self.scale, input.dtype, self.bias
)
batched = False
input_shape = input.shape
if input.dim() == 3:
input = input.view(-1, input_shape[-1])
batched = True
qinput, scale = fp8_quantize(
input,
self.input_scale,
@ -438,6 +445,9 @@ class Fp8Linear(torch.nn.Module):
output = output.to(dtype=self.dtype)
if batched:
output = output.view(*input_shape[:-1], output.shape[-1])
return output