Support 3-dimensional input (len(input.shape) == 3) in TensorParallelHead

This commit is contained in:
OlivierDehaene 2023-06-08 16:29:01 +02:00
parent b67405bd8e
commit bdd811dd83

View File

@ -163,6 +163,7 @@ class TensorParallelHead(SuperLayer):
if world_size == 1:
return super().forward(input)
if input.shape == 2:
out_dim = self.linear.weight.shape[0]
if input.shape[0] == 1:
@ -196,6 +197,14 @@ class TensorParallelHead(SuperLayer):
)
return world_out.T
output = super().forward(input)
world_output = [
torch.empty_like(output) for _ in range(self.process_group.size())
]
torch.distributed.all_gather(world_output, output, group=self.process_group)
world_output = torch.cat(world_output, dim=-1)
return world_output
class TensorParallelColumnLinear(SuperLayer):
@classmethod