From 2b25e9a94e8ca2c39d7eba45506d52e823f531da Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Fri, 25 Oct 2024 11:33:53 +0200 Subject: [PATCH] disable triton on rocm --- .../models/metadata_kernels.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/server/text_generation_server/models/metadata_kernels.py b/server/text_generation_server/models/metadata_kernels.py index 2357564e..17312efb 100644 --- a/server/text_generation_server/models/metadata_kernels.py +++ b/server/text_generation_server/models/metadata_kernels.py @@ -4,7 +4,19 @@ import triton import triton.language as tl from typing import List -from torch.utils._triton import has_triton +from torch.utils._triton import has_triton as has_triton_torch + +from text_generation_server.utils.import_utils import ( + SYSTEM, +) + + +def has_triton(): + # FIXME: it seems that has_triton_torch is bugged on RocM + # For now, only accept cuda + if SYSTEM == "cuda": + return has_triton_torch() + return False def block_tables_to_padded(