From 3d0c7b85febc97a54b2d9d2ef661aaa50771e73e Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Mon, 22 Jul 2024 18:45:26 +0200 Subject: [PATCH] fix(l4): fix fp8 logic on l4 --- server/text_generation_server/layers/fp8.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py index bf5a0989..96d5f4a3 100644 --- a/server/text_generation_server/layers/fp8.py +++ b/server/text_generation_server/layers/fp8.py @@ -42,8 +42,10 @@ def get_fp8_linear() -> torch.nn.Module: return Fp8Linear -def fp8_quantize(weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn): - if FBGEMM_DYN_AVAILABLE: +def fp8_quantize( + weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn, scalar=False +): + if FBGEMM_DYN_AVAILABLE and not scalar: qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row( weight, bs=None, scale_ub=scale_upper_bound, output_dtype=qdtype ) @@ -232,7 +234,7 @@ class Fp8Linear(torch.nn.Module): ) return y.to(self.dtype) - qinput, scale = fp8_quantize(input) + qinput, scale = fp8_quantize(input, scalar=True) output, _ = torch._scaled_mm( qinput, self.qweight.t(),