mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-06 17:32:09 +00:00
add fix
This commit is contained in:
parent
ea7f4082c4
commit
1fa9ca2f16
@ -395,6 +395,13 @@ class Fp8Linear(torch.nn.Module):
|
|||||||
qinput, self.qweight.t(), scale, self.scale, input.dtype, self.bias
|
qinput, self.qweight.t(), scale, self.scale, input.dtype, self.bias
|
||||||
)
|
)
|
||||||
|
|
||||||
|
batched = False
|
||||||
|
input_shape = input.shape
|
||||||
|
|
||||||
|
if input.dim() == 3:
|
||||||
|
input = input.view(-1, input_shape[-1])
|
||||||
|
batched = True
|
||||||
|
|
||||||
qinput, scale = fp8_quantize(
|
qinput, scale = fp8_quantize(
|
||||||
input,
|
input,
|
||||||
self.input_scale,
|
self.input_scale,
|
||||||
@ -438,6 +445,9 @@ class Fp8Linear(torch.nn.Module):
|
|||||||
|
|
||||||
output = output.to(dtype=self.dtype)
|
output = output.to(dtype=self.dtype)
|
||||||
|
|
||||||
|
if batched:
|
||||||
|
output = output.view(*input_shape[:-1], output.shape[-1])
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user