fix(l4): fix fp8 logic on l4

Author: OlivierDehaene, 2024-07-22 18:45:26 +02:00
Parent: 6aeb669072
Commit: 3d0c7b85fe


@@ -42,8 +42,10 @@ def get_fp8_linear() -> torch.nn.Module:
     return Fp8Linear
 
 
-def fp8_quantize(weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn):
-    if FBGEMM_DYN_AVAILABLE:
+def fp8_quantize(
+    weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn, scalar=False
+):
+    if FBGEMM_DYN_AVAILABLE and not scalar:
         qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row(
             weight, bs=None, scale_ub=scale_upper_bound, output_dtype=qdtype
         )
@@ -232,7 +234,7 @@ class Fp8Linear(torch.nn.Module):
             )
             return y.to(self.dtype)
 
-        qinput, scale = fp8_quantize(input)
+        qinput, scale = fp8_quantize(input, scalar=True)
         output, _ = torch._scaled_mm(
             qinput,
             self.qweight.t(),
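
With scalar=True the FBGEMM per-row kernel is skipped, so the activation fed to torch._scaled_mm is quantized with a single per-tensor scale instead of one scale per row. As a rough illustration of what such a per-tensor ("scalar") fp8 quantization looks like, here is a minimal sketch; the helper name and the round-trip check are illustrative assumptions, not code from this diff:

import torch

def per_tensor_fp8_quantize(weight, qdtype=torch.float8_e4m3fn):
    # Hypothetical helper (not the repository's code): per-tensor ("scalar")
    # fp8 quantization, producing one scale for the whole tensor.
    finfo = torch.finfo(qdtype)
    # Map the largest absolute value onto the fp8 dtype's max representable value.
    scale = finfo.max / weight.abs().max().clamp(min=1e-12)
    # Saturate before casting, since the cast itself does not clamp.
    qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max).to(qdtype)
    # Return the inverse scale so callers can dequantize with a multiply.
    return qweight, scale.float().reciprocal()

# Round-trip check: dequantizing with the returned inverse scale should
# approximately recover the original tensor.
w = torch.randn(128, 64)
qw, inv_scale = per_tensor_fp8_quantize(w)
print(torch.allclose(qw.float() * inv_scale, w, atol=0.1, rtol=0.1))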