diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py index d2c46c58..54550ee3 100644 --- a/server/text_generation_server/layers/fp8.py +++ b/server/text_generation_server/layers/fp8.py @@ -32,8 +32,8 @@ def get_fp8_linear() -> torch.nn.Module: """ if SYSTEM == "cuda": - major, minor = torch.cuda.get_device_capability() - if major == 8 and minor < 9: + major, _ = torch.cuda.get_device_capability() + if major == 8: from text_generation_server.layers.marlin import GPTQMarlinFP8Linear return GPTQMarlinFP8Linear @@ -188,6 +188,9 @@ class Fp8Linear(torch.nn.Module): dtype, ) -> None: super().__init__() + if FBGEMM_MM_AVAILABLE: + log_once(logger.info, "Using FBGEMM fp8 optimized kernels") + self.dtype = dtype self.qweight = qweight self.scale = scale