Also quantize weights with a single (per-tensor) scale

This commit is contained in:
OlivierDehaene 2024-07-22 18:49:10 +02:00
parent 3d0c7b85fe
commit 473f968a01
No known key found for this signature in database
GPG Key ID: BB104D67809DA93C

View File

@ -203,7 +203,7 @@ class Fp8Linear(torch.nn.Module):
@classmethod
def from_unquant(cls, weight, bias, dtype):
    """Build an Fp8Linear layer by quantizing an unquantized weight tensor.

    Quantizes `weight` to FP8 via `fp8_quantize`; when the FBGEMM matmul
    kernels are unavailable, a scalar (single per-tensor) scale is requested
    instead of a per-row scale.

    Args:
        weight: unquantized weight tensor to convert to FP8.
        bias: optional bias tensor, passed through unchanged.
        dtype: compute dtype recorded on the resulting layer.

    Returns:
        An Fp8Linear instance holding the quantized weight and its scale
        (no upper bound on the scale).
    """
    # NOTE(review): presumably FBGEMM_MM_AVAILABLE is a module-level flag set
    # at import time — confirm it is defined before this classmethod runs.
    use_scalar_scale = not FBGEMM_MM_AVAILABLE
    quantized_weight, weight_scale = fp8_quantize(weight, scalar=use_scalar_scale)
    return cls(
        qweight=quantized_weight,
        scale=weight_scale,
        scale_upper_bound=None,
        bias=bias,
        dtype=dtype,
    )