Disable the _custom_C.LLMM1 code path, as it is broken for TP>=2 (it only works for TP=1)

This commit is contained in:
fxmarty 2024-04-19 15:59:31 +00:00
parent 562cd4b06e
commit 81c27ba9c2

View File

@ -355,7 +355,8 @@ class FastLinearROCm(nn.Module):
weight = self.weight
bias = self.bias
if IS_ROCM_SYSTEM and inp.numel() // inp.size(-1) == 1:
# TODO: fix for TP>=2, this only works for TP=1
if False and IS_ROCM_SYSTEM and inp.numel() // inp.size(-1) == 1:
batched = False
if inp.dim() == 3: