support input.shape 3

This commit is contained in:
OlivierDehaene 2023-06-08 16:29:01 +02:00
parent b67405bd8e
commit bdd811dd83

View File

@@ -163,38 +163,47 @@ class TensorParallelHead(SuperLayer):
         if world_size == 1:
             return super().forward(input)
 
-        out_dim = self.linear.weight.shape[0]
+        if input.shape == 2:
+            out_dim = self.linear.weight.shape[0]
 
-        if input.shape[0] == 1:
-            world_out = input.new_empty(1, out_dim * world_size)
-            local_out = world_out[:, :out_dim]
-        else:
-            world_out = input.new_empty(out_dim * world_size, input.shape[0])
-            local_out = world_out[:out_dim].T
+            if input.shape[0] == 1:
+                world_out = input.new_empty(1, out_dim * world_size)
+                local_out = world_out[:, :out_dim]
+            else:
+                world_out = input.new_empty(out_dim * world_size, input.shape[0])
+                local_out = world_out[:out_dim].T
 
-        if isinstance(self.linear, FastLinear):
-            torch.mm(input, mat2=self.linear.weight.T, out=local_out)
-        elif isinstance(self.linear, Linear8bitLt):
-            bnb.matmul(
-                input,
-                self.linear.weight,
-                bias=None,
-                state=self.linear.state,
-                out=local_out,
-            )
-        else:
-            raise NotImplementedError
+            if isinstance(self.linear, FastLinear):
+                torch.mm(input, mat2=self.linear.weight.T, out=local_out)
+            elif isinstance(self.linear, Linear8bitLt):
+                bnb.matmul(
+                    input,
+                    self.linear.weight,
+                    bias=None,
+                    state=self.linear.state,
+                    out=local_out,
+                )
+            else:
+                raise NotImplementedError
 
-        if input.shape[0] == 1:
-            torch.distributed.all_gather_into_tensor(
-                world_out, local_out, group=self.process_group
-            )
-            return world_out
+            if input.shape[0] == 1:
+                torch.distributed.all_gather_into_tensor(
+                    world_out, local_out, group=self.process_group
+                )
+                return world_out
 
-        torch.distributed.all_gather_into_tensor(
-            world_out, world_out[:out_dim], group=self.process_group
-        )
-        return world_out.T
+            torch.distributed.all_gather_into_tensor(
+                world_out, world_out[:out_dim], group=self.process_group
+            )
+            return world_out.T
+
+        output = super().forward(input)
+        world_output = [
+            torch.empty_like(output) for _ in range(self.process_group.size())
+        ]
+        torch.distributed.all_gather(world_output, output, group=self.process_group)
+        world_output = torch.cat(world_output, dim=-1)
+        return world_output
 
 class TensorParallelColumnLinear(SuperLayer):