From 325f9774fe7994e2611be3c331c8aea03aad8aa0 Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Fri, 19 Apr 2024 16:19:47 +0000 Subject: [PATCH] reenable _custom_C.LLMM1 as the culprit was FA2 triton --- server/text_generation_server/utils/layers.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index 27adc775..8e36f654 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -355,8 +355,7 @@ class FastLinearROCm(nn.Module): weight = self.weight bias = self.bias - # TODO: fix for TP>=2, this only works for TP=1 - if False and IS_ROCM_SYSTEM and inp.numel() // inp.size(-1) == 1: + if IS_ROCM_SYSTEM and inp.numel() // inp.size(-1) == 1: batched = False if inp.dim() == 3: