mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
use marlin even on 89
This commit is contained in:
parent
473f968a01
commit
025f80dfd4
@@ -32,8 +32,8 @@ def get_fp8_linear() -> torch.nn.Module:
     """

     if SYSTEM == "cuda":
-        major, minor = torch.cuda.get_device_capability()
-        if major == 8 and minor < 9:
+        major, _ = torch.cuda.get_device_capability()
+        if major == 8:
             from text_generation_server.layers.marlin import GPTQMarlinFP8Linear

             return GPTQMarlinFP8Linear
@@ -188,6 +188,9 @@ class Fp8Linear(torch.nn.Module):
         dtype,
     ) -> None:
         super().__init__()
+        if FBGEMM_MM_AVAILABLE:
+            log_once(logger.info, "Using FBGEMM fp8 optimized kernels")
+
         self.dtype = dtype
         self.qweight = qweight
         self.scale = scale
Loading…
Reference in New Issue
Block a user