disable triton on rocm

This commit is contained in:
OlivierDehaene 2024-10-25 11:33:53 +02:00
parent a7465ba67d
commit 2b25e9a94e
No known key found for this signature in database
GPG Key ID: BB104D67809DA93C

View File

@ -4,7 +4,19 @@ import triton
import triton.language as tl import triton.language as tl
from typing import List from typing import List
from torch.utils._triton import has_triton from torch.utils._triton import has_triton as has_triton_torch
from text_generation_server.utils.import_utils import (
SYSTEM,
)
def has_triton():
# FIXME: it seems that has_triton_torch is bugged on RocM
# For now, only accept cuda
if SYSTEM == "cuda":
return has_triton_torch()
return False
def block_tables_to_padded( def block_tables_to_padded(