From bdd811dd832bc669c3547fd7f5be020e7c9c02a8 Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Thu, 8 Jun 2023 16:29:01 +0200
Subject: [PATCH] support input.shape 3

---
 server/text_generation_server/utils/layers.py | 61 +++++++++++--------
 1 file changed, 35 insertions(+), 26 deletions(-)

diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 3ec04963..e64b6b0c 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -163,38 +163,47 @@ class TensorParallelHead(SuperLayer):
         if world_size == 1:
             return super().forward(input)
 
-        out_dim = self.linear.weight.shape[0]
+        if len(input.shape) == 2:
+            out_dim = self.linear.weight.shape[0]
 
-        if input.shape[0] == 1:
-            world_out = input.new_empty(1, out_dim * world_size)
-            local_out = world_out[:, :out_dim]
-        else:
-            world_out = input.new_empty(out_dim * world_size, input.shape[0])
-            local_out = world_out[:out_dim].T
+            if input.shape[0] == 1:
+                world_out = input.new_empty(1, out_dim * world_size)
+                local_out = world_out[:, :out_dim]
+            else:
+                world_out = input.new_empty(out_dim * world_size, input.shape[0])
+                local_out = world_out[:out_dim].T
 
-        if isinstance(self.linear, FastLinear):
-            torch.mm(input, mat2=self.linear.weight.T, out=local_out)
-        elif isinstance(self.linear, Linear8bitLt):
-            bnb.matmul(
-                input,
-                self.linear.weight,
-                bias=None,
-                state=self.linear.state,
-                out=local_out,
-            )
-        else:
-            raise NotImplementedError
+            if isinstance(self.linear, FastLinear):
+                torch.mm(input, mat2=self.linear.weight.T, out=local_out)
+            elif isinstance(self.linear, Linear8bitLt):
+                bnb.matmul(
+                    input,
+                    self.linear.weight,
+                    bias=None,
+                    state=self.linear.state,
+                    out=local_out,
+                )
+            else:
+                raise NotImplementedError
+
+            if input.shape[0] == 1:
+                torch.distributed.all_gather_into_tensor(
+                    world_out, local_out, group=self.process_group
+                )
+                return world_out
 
-        if input.shape[0] == 1:
             torch.distributed.all_gather_into_tensor(
-                world_out, local_out, group=self.process_group
+                world_out, world_out[:out_dim], group=self.process_group
             )
-            return world_out
+            return world_out.T
 
-        torch.distributed.all_gather_into_tensor(
-            world_out, world_out[:out_dim], group=self.process_group
-        )
-        return world_out.T
+        output = super().forward(input)
+        world_output = [
+            torch.empty_like(output) for _ in range(self.process_group.size())
+        ]
+        torch.distributed.all_gather(world_output, output, group=self.process_group)
+        world_output = torch.cat(world_output, dim=-1)
+        return world_output
 
 
 class TensorParallelColumnLinear(SuperLayer):
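For reviewers, below is a minimal shape-bookkeeping sketch of the two gather layouts the patched `forward` distinguishes. It is illustrative only: the sizes (`world_size`, `out_dim`, `hidden`) are made up, and the collective call is faked with `clone()` so the snippet runs in a single process without a `torch.distributed` process group.

```python
import torch

world_size = 4   # hypothetical tensor-parallel degree
out_dim = 8      # vocab shard width held by this rank
hidden = 16      # model hidden size

# 2-D fast path (one row per token): preallocate the gathered buffer once and
# have this rank write its slice in place, so all_gather_into_tensor needs no
# extra copy or concatenation afterwards.
input_2d = torch.randn(1, hidden)
world_out = input_2d.new_empty(1, out_dim * world_size)
local_out = world_out[:, :out_dim]           # the slice this rank fills
weight = torch.randn(out_dim, hidden)
torch.mm(input_2d, weight.T, out=local_out)  # writes world_out[:, :out_dim]

# 3-D fallback added by this patch: gather whole per-rank outputs into a list,
# then concatenate on the last dim. This works for any input rank, at the cost
# of the torch.cat copy.
output = torch.randn(2, 5, out_dim)          # e.g. (batch, tokens, vocab shard)
world_output = [torch.empty_like(output) for _ in range(world_size)]
# In the real layer this is:
#   torch.distributed.all_gather(world_output, output, group=self.process_group)
# Here we fake the gathered result to check the shape math.
world_output = [output.clone() for _ in range(world_size)]
gathered = torch.cat(world_output, dim=-1)
assert gathered.shape == (2, 5, out_dim * world_size)
```

The design trade-off: the 2-D path keeps the zero-copy `all_gather_into_tensor` layout for the common decoding case, while inputs of any other rank take the generic list-gather-and-cat path instead of raising.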