diff --git a/server/text_generation_server/layers/gptq/__init__.py b/server/text_generation_server/layers/gptq/__init__.py index 63131dee..c6db32d3 100644 --- a/server/text_generation_server/layers/gptq/__init__.py +++ b/server/text_generation_server/layers/gptq/__init__.py @@ -10,7 +10,7 @@ from text_generation_server.utils.weights import Weight, Weights, WeightsLoader if SYSTEM == "ipex": from .ipex import QuantLinear -elif SYSTEM == "cuda": +elif SYSTEM in {"cuda", "rocm"}: from .cuda import QuantLinear