fix: improve fbgemm_gpu check and lints

This commit is contained in:
drbh 2024-07-24 15:23:10 +00:00
parent 382bf59f4f
commit e216e53ea8
3 changed files with 7 additions and 7 deletions

View File

@@ -20,7 +20,10 @@ FBGEMM_DYN_AVAILABLE = False
def is_fbgemm_gpu_available(): def is_fbgemm_gpu_available():
return importlib.util.find_spec("fbgemm_gpu.experimental.gen_ai") is not None try:
return importlib.util.find_spec("fbgemm_gpu.experimental.gen_ai") is not None
except ModuleNotFoundError:
return False
if is_fbgemm_gpu_available(): if is_fbgemm_gpu_available():

View File

@@ -1,6 +1,3 @@
from typing import List, Tuple
import torch
from text_generation_server.layers.marlin.fp8 import GPTQMarlinFP8Linear from text_generation_server.layers.marlin.fp8 import GPTQMarlinFP8Linear
from text_generation_server.layers.marlin.gptq import ( from text_generation_server.layers.marlin.gptq import (
GPTQMarlinLinear, GPTQMarlinLinear,

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional, Tuple, Union from typing import List, Optional, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
@@ -85,7 +85,7 @@ class MarlinWeightsLoader(WeightsLoader):
) )
except RuntimeError: except RuntimeError:
raise RuntimeError( raise RuntimeError(
f"Cannot load `marlin` weight, make sure the model is already quantized" "Cannot load `marlin` weight, make sure the model is already quantized"
) )
B_meta = torch.cat( B_meta = torch.cat(
@@ -104,7 +104,7 @@ class MarlinWeightsLoader(WeightsLoader):
) )
except RuntimeError: except RuntimeError:
raise RuntimeError( raise RuntimeError(
f"Cannot load `marlin` weight, make sure the model is already quantized" "Cannot load `marlin` weight, make sure the model is already quantized"
) )
s = torch.cat( s = torch.cat(
[weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1 [weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1