use marlin even on 89

This commit is contained in:
OlivierDehaene 2024-07-23 10:35:32 +02:00
parent 473f968a01
commit 025f80dfd4
No known key found for this signature in database
GPG Key ID: BB104D67809DA93C

View File

@@ -32,8 +32,8 @@ def get_fp8_linear() -> torch.nn.Module:
     """
     if SYSTEM == "cuda":
-        major, minor = torch.cuda.get_device_capability()
-        if major == 8 and minor < 9:
+        major, _ = torch.cuda.get_device_capability()
+        if major == 8:
             from text_generation_server.layers.marlin import GPTQMarlinFP8Linear

             return GPTQMarlinFP8Linear
@@ -188,6 +188,9 @@ class Fp8Linear(torch.nn.Module):
         dtype,
     ) -> None:
         super().__init__()
+        if FBGEMM_MM_AVAILABLE:
+            log_once(logger.info, "Using FBGEMM fp8 optimized kernels")
+
         self.dtype = dtype
         self.qweight = qweight
         self.scale = scale