diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py index 1e5c8b3d..db305fdc 100644 --- a/server/text_generation_server/layers/fp8.py +++ b/server/text_generation_server/layers/fp8.py @@ -395,6 +395,13 @@ class Fp8Linear(torch.nn.Module): qinput, self.qweight.t(), scale, self.scale, input.dtype, self.bias ) + batched = False + input_shape = input.shape + + if input.dim() == 3: + input = input.view(-1, input_shape[-1]) + batched = True + qinput, scale = fp8_quantize( input, self.input_scale, @@ -438,6 +445,9 @@ class Fp8Linear(torch.nn.Module): output = output.to(dtype=self.dtype) + if batched: + output = output.view(*input_shape[:-1], output.shape[-1]) + return output