diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py index 9c745647..59b08b55 100644 --- a/server/text_generation_server/layers/fp8.py +++ b/server/text_generation_server/layers/fp8.py @@ -20,7 +20,10 @@ FBGEMM_DYN_AVAILABLE = False def is_fbgemm_gpu_available(): - return importlib.util.find_spec("fbgemm_gpu.experimental.gen_ai") is not None + try: + return importlib.util.find_spec("fbgemm_gpu.experimental.gen_ai") is not None + except ModuleNotFoundError: + return False if is_fbgemm_gpu_available(): diff --git a/server/text_generation_server/layers/marlin/__init__.py b/server/text_generation_server/layers/marlin/__init__.py index 40147d59..8a143f96 100644 --- a/server/text_generation_server/layers/marlin/__init__.py +++ b/server/text_generation_server/layers/marlin/__init__.py @@ -1,6 +1,3 @@ -from typing import List, Tuple - -import torch from text_generation_server.layers.marlin.fp8 import GPTQMarlinFP8Linear from text_generation_server.layers.marlin.gptq import ( GPTQMarlinLinear, diff --git a/server/text_generation_server/layers/marlin/marlin.py b/server/text_generation_server/layers/marlin/marlin.py index db3ce2d7..89ebaca6 100644 --- a/server/text_generation_server/layers/marlin/marlin.py +++ b/server/text_generation_server/layers/marlin/marlin.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Union import torch import torch.nn as nn @@ -85,7 +85,7 @@ class MarlinWeightsLoader(WeightsLoader): ) except RuntimeError: raise RuntimeError( - f"Cannot load `marlin` weight, make sure the model is already quantized" + "Cannot load `marlin` weight, make sure the model is already quantized" ) B_meta = torch.cat( @@ -104,7 +104,7 @@ class MarlinWeightsLoader(WeightsLoader): ) except RuntimeError: raise RuntimeError( - f"Cannot load `marlin` weight, make sure the model is already quantized" + "Cannot load `marlin` weight, make sure the model is already quantized" ) s = torch.cat( [weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1