From b05ec96b0e18d12efebae6c67b38033ab4527d58 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Thu, 8 Jun 2023 19:29:22 +0200 Subject: [PATCH] remove quant --- server/text_generation_server/utils/layers.py | 33 +++++++------------ 1 file changed, 11 insertions(+), 22 deletions(-) diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index afb75cab..6e63ae66 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -162,40 +162,29 @@ class TensorParallelHead(SuperLayer): if world_size == 1: return super().forward(input) - if input.shape == 2: + if len(input.shape) == 2 and isinstance(self.linear, FastLinear): out_dim = self.linear.weight.shape[0] if input.shape[0] == 1: world_out = input.new_empty(1, out_dim * world_size) - local_out = world_out[:, :out_dim] + local_out = input.new_empty(1, out_dim) + gather_input = local_out else: world_out = input.new_empty(out_dim * world_size, input.shape[0]) - local_out = world_out[:out_dim].T + gather_input = input.new_empty(out_dim, input.shape[0]) + local_out = gather_input.T - if isinstance(self.linear, FastLinear): - torch.mm(input, mat2=self.linear.weight.T, out=local_out) - elif isinstance(self.linear, Linear8bitLt): - bnb.matmul( - input, - self.linear.weight, - bias=None, - state=self.linear.state, - out=local_out, - ) - else: - raise NotImplementedError - - if input.shape[0] == 1: - torch.distributed.all_gather_into_tensor( - world_out, local_out, group=self.process_group - ) - return world_out + torch.mm(input, self.linear.weight.T, out=local_out) torch.distributed.all_gather_into_tensor( - world_out, world_out[:out_dim], group=self.process_group + world_out, gather_input, group=self.process_group ) + + if input.shape[0] == 1: + return world_out return world_out.T + output = super().forward(input) world_output = [ torch.empty_like(output) for _ in range(self.process_group.size())