diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py
index bf5a0989..96d5f4a3 100644
--- a/server/text_generation_server/layers/fp8.py
+++ b/server/text_generation_server/layers/fp8.py
@@ -42,8 +42,10 @@ def get_fp8_linear() -> torch.nn.Module:
     return Fp8Linear
 
 
-def fp8_quantize(weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn):
-    if FBGEMM_DYN_AVAILABLE:
+def fp8_quantize(
+    weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn, scalar=False
+):
+    if FBGEMM_DYN_AVAILABLE and not scalar:
         qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row(
             weight, bs=None, scale_ub=scale_upper_bound, output_dtype=qdtype
         )
@@ -232,7 +234,7 @@ class Fp8Linear(torch.nn.Module):
         )
         return y.to(self.dtype)
 
-        qinput, scale = fp8_quantize(input)
+        qinput, scale = fp8_quantize(input, scalar=True)
         output, _ = torch._scaled_mm(
             qinput,
             self.qweight.t(),
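
The new `scalar` flag appears to exist because `torch.ops.fbgemm.quantize_fp8_per_row` returns one scale per row, whereas the `torch._scaled_mm` branch of `Fp8Linear.forward` consumes a single per-tensor scale for each operand. The sketch below illustrates, under that assumption, what a per-tensor (scalar-scale) FP8 quantization fallback can look like; the function name `per_tensor_fp8_quantize` and the epsilon clamp are illustrative, not the file's exact implementation.

```python
import torch


def per_tensor_fp8_quantize(weight: torch.Tensor, qdtype=torch.float8_e4m3fn):
    """Quantize a tensor to FP8 with a single (scalar) scale.

    Sketch only: one scale is shared by the whole tensor instead of one
    scale per row, which matches what a per-tensor scaled matmul expects.
    """
    finfo = torch.finfo(qdtype)
    # Map the largest absolute value onto the FP8 maximum; the clamp avoids
    # dividing by zero for an all-zero tensor (epsilon chosen arbitrarily here).
    scale = finfo.max / weight.abs().max().clamp(min=1e-12)
    qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max).to(qdtype)
    # Return the reciprocal so it can be passed directly as a dequantization
    # scale (e.g. scale_a / scale_b in torch._scaled_mm).
    return qweight, scale.float().reciprocal()
```

With both operands quantized this way, the scales handed to `torch._scaled_mm` in the second hunk are 0-dim float32 tensors rather than per-row vectors.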