mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
fix: improve fbgemm_gpu check and lints
This commit is contained in:
parent
382bf59f4f
commit
e216e53ea8
@ -20,7 +20,10 @@ FBGEMM_DYN_AVAILABLE = False
|
|||||||
|
|
||||||
|
|
||||||
def is_fbgemm_gpu_available():
|
def is_fbgemm_gpu_available():
|
||||||
return importlib.util.find_spec("fbgemm_gpu.experimental.gen_ai") is not None
|
try:
|
||||||
|
return importlib.util.find_spec("fbgemm_gpu.experimental.gen_ai") is not None
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
if is_fbgemm_gpu_available():
|
if is_fbgemm_gpu_available():
|
||||||
|
@ -1,6 +1,3 @@
|
|||||||
from typing import List, Tuple
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from text_generation_server.layers.marlin.fp8 import GPTQMarlinFP8Linear
|
from text_generation_server.layers.marlin.fp8 import GPTQMarlinFP8Linear
|
||||||
from text_generation_server.layers.marlin.gptq import (
|
from text_generation_server.layers.marlin.gptq import (
|
||||||
GPTQMarlinLinear,
|
GPTQMarlinLinear,
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -85,7 +85,7 @@ class MarlinWeightsLoader(WeightsLoader):
|
|||||||
)
|
)
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Cannot load `marlin` weight, make sure the model is already quantized"
|
"Cannot load `marlin` weight, make sure the model is already quantized"
|
||||||
)
|
)
|
||||||
|
|
||||||
B_meta = torch.cat(
|
B_meta = torch.cat(
|
||||||
@ -104,7 +104,7 @@ class MarlinWeightsLoader(WeightsLoader):
|
|||||||
)
|
)
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Cannot load `marlin` weight, make sure the model is already quantized"
|
"Cannot load `marlin` weight, make sure the model is already quantized"
|
||||||
)
|
)
|
||||||
s = torch.cat(
|
s = torch.cat(
|
||||||
[weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1
|
[weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1
|
||||||
|
Loading…
Reference in New Issue
Block a user