reenable _custom_C.LLMM1 as the culprit was FA2 triton

This commit is contained in:
fxmarty 2024-04-19 16:19:47 +00:00
parent 81c27ba9c2
commit 325f9774fe

View File

@ -355,8 +355,7 @@ class FastLinearROCm(nn.Module):
weight = self.weight weight = self.weight
bias = self.bias bias = self.bias
# TODO: fix for TP>=2, this only works for TP=1 if IS_ROCM_SYSTEM and inp.numel() // inp.size(-1) == 1:
if False and IS_ROCM_SYSTEM and inp.numel() // inp.size(-1) == 1:
batched = False batched = False
if inp.dim() == 3: if inp.dim() == 3: