diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index afb75cab..6e63ae66 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -162,40 +162,29 @@ class TensorParallelHead(SuperLayer):
         if world_size == 1:
             return super().forward(input)
 
-        if input.shape == 2:
+        if len(input.shape) == 2 and isinstance(self.linear, FastLinear):
             out_dim = self.linear.weight.shape[0]
 
             if input.shape[0] == 1:
                 world_out = input.new_empty(1, out_dim * world_size)
-                local_out = world_out[:, :out_dim]
+                local_out = input.new_empty(1, out_dim)
+                gather_input = local_out
             else:
                 world_out = input.new_empty(out_dim * world_size, input.shape[0])
-                local_out = world_out[:out_dim].T
+                gather_input = input.new_empty(out_dim, input.shape[0])
+                local_out = gather_input.T
 
-            if isinstance(self.linear, FastLinear):
-                torch.mm(input, mat2=self.linear.weight.T, out=local_out)
-            elif isinstance(self.linear, Linear8bitLt):
-                bnb.matmul(
-                    input,
-                    self.linear.weight,
-                    bias=None,
-                    state=self.linear.state,
-                    out=local_out,
-                )
-            else:
-                raise NotImplementedError
-
-            if input.shape[0] == 1:
-                torch.distributed.all_gather_into_tensor(
-                    world_out, local_out, group=self.process_group
-                )
-                return world_out
+            torch.mm(input, self.linear.weight.T, out=local_out)
 
             torch.distributed.all_gather_into_tensor(
-                world_out, world_out[:out_dim], group=self.process_group
+                world_out, gather_input, group=self.process_group
             )
+
+            if input.shape[0] == 1:
+                return world_out
             return world_out.T
 
+        output = super().forward(input)
         world_output = [
             torch.empty_like(output) for _ in range(self.process_group.size())