From 025f80dfd454b994966c230b0bce2bf7b04e22c6 Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Tue, 23 Jul 2024 10:35:32 +0200
Subject: [PATCH] use marlin even on 89

---
 server/text_generation_server/layers/fp8.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py
index d2c46c58..54550ee3 100644
--- a/server/text_generation_server/layers/fp8.py
+++ b/server/text_generation_server/layers/fp8.py
@@ -32,8 +32,8 @@ def get_fp8_linear() -> torch.nn.Module:
     """
 
     if SYSTEM == "cuda":
-        major, minor = torch.cuda.get_device_capability()
-        if major == 8 and minor < 9:
+        major, _ = torch.cuda.get_device_capability()
+        if major == 8:
             from text_generation_server.layers.marlin import GPTQMarlinFP8Linear
 
             return GPTQMarlinFP8Linear
@@ -188,6 +188,9 @@ class Fp8Linear(torch.nn.Module):
         dtype,
     ) -> None:
         super().__init__()
+        if FBGEMM_MM_AVAILABLE:
+            log_once(logger.info, "Using FBGEMM fp8 optimized kernels")
+
         self.dtype = dtype
         self.qweight = qweight
         self.scale = scale
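
Note (reviewer sketch, not part of the patch): the gate in the first hunk now keys
only on the CUDA compute capability major version. torch.cuda.get_device_capability()
returns a (major, minor) tuple: (8, 0) on A100, (8, 9) on Ada Lovelace parts such
as the L4 and RTX 4090, (9, 0) on Hopper. The old condition
"major == 8 and minor < 9" excluded SM 8.9, so Ada GPUs fell through to the
Fp8Linear path; dropping the minor check routes every SM 8.x GPU to
GPTQMarlinFP8Linear. A minimal Python sketch of the selection logic, with
selects_marlin as a hypothetical helper name:

    def selects_marlin(capability: tuple) -> bool:
        # After this patch, Marlin FP8 kernels are selected for every SM 8.x
        # device; the minor version no longer matters.
        major, _ = capability
        return major == 8

    # In get_fp8_linear this would be fed torch.cuda.get_device_capability().
    assert selects_marlin((8, 0))       # A100 (SM 8.0): Marlin, as before
    assert selects_marlin((8, 9))       # Ada / L4 / RTX 4090 (SM 8.9): Marlin now
    assert not selects_marlin((9, 0))   # Hopper (SM 9.0): falls through to
                                        # Fp8Linear (FBGEMM kernels when available,
                                        # per the second hunk's log message)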