def has_triton():
    """Report whether Triton may be used on the current system.

    Returns:
        Whatever ``has_triton_torch()`` returns when ``SYSTEM`` is
        ``"cuda"``; ``False`` for every other value of ``SYSTEM``.
    """
    # FIXME: has_triton_torch seems to be bugged on ROCm, so for now its
    # answer is only trusted on CUDA; every other system is rejected.
    if SYSTEM != "cuda":
        return False
    return has_triton_torch()