diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index 8e36f6542..27adc775f 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -355,7 +355,8 @@ class FastLinearROCm(nn.Module): weight = self.weight bias = self.bias - if IS_ROCM_SYSTEM and inp.numel() // inp.size(-1) == 1: + # TODO: fix for TP>=2, this only works for TP=1 + if False and IS_ROCM_SYSTEM and inp.numel() // inp.size(-1) == 1: batched = False if inp.dim() == 3: