Fix Phi-2 with tp>1

We were using the wrong parallelism in the up-projection.
This commit is contained in:
Daniël de Kok 2024-06-04 08:25:33 +00:00
parent df71aafdcc
commit f9c354d120

View File

@ -238,7 +238,7 @@ class PhiMLP(nn.Module):
)
# llama weights are up_proj and down_proj and bias=False
self.up_proj = TensorParallelRowLinear.load(
self.up_proj = TensorParallelColumnLinear.load(
config,
prefix=f"{prefix}.fc1",
weights=weights,