From f9c354d12059bea26a6bb1bb26427bf918c239e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?= Date: Tue, 4 Jun 2024 08:25:33 +0000 Subject: [PATCH] Fix Phi-2 with `tp>1` We were using the wrong parallelism in the up-projection. --- .../models/custom_modeling/flash_phi_modeling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index af3206dd..53d3ea42 100644 --- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -238,7 +238,7 @@ class PhiMLP(nn.Module): ) # llama weights are up_proj and down_proj and bias=False - self.up_proj = TensorParallelRowLinear.load( + self.up_proj = TensorParallelColumnLinear.load( config, prefix=f"{prefix}.fc1", weights=weights,