mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-06 17:32:09 +00:00
add fix
This commit is contained in:
parent
ea7f4082c4
commit
1fa9ca2f16
@@ -395,6 +395,13 @@ class Fp8Linear(torch.nn.Module):
|
||||
qinput, self.qweight.t(), scale, self.scale, input.dtype, self.bias
|
||||
)
|
||||
|
||||
batched = False
|
||||
input_shape = input.shape
|
||||
|
||||
if input.dim() == 3:
|
||||
input = input.view(-1, input_shape[-1])
|
||||
batched = True
|
||||
|
||||
qinput, scale = fp8_quantize(
|
||||
input,
|
||||
self.input_scale,
|
||||
@@ -438,6 +445,9 @@ class Fp8Linear(torch.nn.Module):
|
||||
|
||||
output = output.to(dtype=self.dtype)
|
||||
|
||||
if batched:
|
||||
output = output.view(*input_shape[:-1], output.shape[-1])
|
||||
|
||||
return output
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user