remove quant

OlivierDehaene 2023-06-08 19:29:22 +02:00
parent 219be4f488
commit b05ec96b0e


@@ -162,40 +162,29 @@ class TensorParallelHead(SuperLayer):
         if world_size == 1:
             return super().forward(input)
 
-        if input.shape == 2:
+        if len(input.shape) == 2 and isinstance(self.linear, FastLinear):
             out_dim = self.linear.weight.shape[0]
 
             if input.shape[0] == 1:
                 world_out = input.new_empty(1, out_dim * world_size)
-                local_out = world_out[:, :out_dim]
+                local_out = input.new_empty(1, out_dim)
+                gather_input = local_out
             else:
                 world_out = input.new_empty(out_dim * world_size, input.shape[0])
-                local_out = world_out[:out_dim].T
+                gather_input = input.new_empty(out_dim, input.shape[0])
+                local_out = gather_input.T
 
-            if isinstance(self.linear, FastLinear):
-                torch.mm(input, mat2=self.linear.weight.T, out=local_out)
-            elif isinstance(self.linear, Linear8bitLt):
-                bnb.matmul(
-                    input,
-                    self.linear.weight,
-                    bias=None,
-                    state=self.linear.state,
-                    out=local_out,
-                )
-            else:
-                raise NotImplementedError
-
-            if input.shape[0] == 1:
-                torch.distributed.all_gather_into_tensor(
-                    world_out, local_out, group=self.process_group
-                )
-                return world_out
+            torch.mm(input, self.linear.weight.T, out=local_out)
 
             torch.distributed.all_gather_into_tensor(
-                world_out, world_out[:out_dim], group=self.process_group
+                world_out, gather_input, group=self.process_group
             )
+
+            if input.shape[0] == 1:
+                return world_out
             return world_out.T
 
         output = super().forward(input)
         world_output = [
             torch.empty_like(output) for _ in range(self.process_group.size())
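
The change drops the quantized (Linear8bitLt / bnb.matmul) branch from the fused path, so the in-place matmul plus all_gather shortcut now only applies to FastLinear; every other linear type falls through to the generic super().forward path below. For context, here is a minimal standalone sketch of the gather pattern the new code uses: each rank computes its shard of the output head with a plain matmul, and the full output is rebuilt with all_gather_into_tensor. The function name and the bare weight / process_group arguments are illustrative, not the exact TGI API.

```python
import torch
import torch.distributed as dist


def tp_head_forward(input, weight, process_group):
    # Sketch only: `weight` is this rank's [out_dim, hidden] shard of the lm_head,
    # `process_group` is the tensor-parallel group. Assumed names, not TGI's API.
    world_size = process_group.size()
    out_dim = weight.shape[0]

    if input.shape[0] == 1:
        # Single row: gather directly along the feature dim into [1, out_dim * world_size].
        world_out = input.new_empty(1, out_dim * world_size)
        local_out = input.new_empty(1, out_dim)
        gather_input = local_out
    else:
        # Batched rows: all_gather_into_tensor concatenates along dim 0, so the local
        # shard is laid out as [out_dim, batch] and the result is transposed at the end.
        world_out = input.new_empty(out_dim * world_size, input.shape[0])
        gather_input = input.new_empty(out_dim, input.shape[0])
        local_out = gather_input.T

    # Local matmul writes straight into the buffer that gets gathered.
    torch.mm(input, weight.T, out=local_out)
    dist.all_gather_into_tensor(world_out, gather_input, group=process_group)

    if input.shape[0] == 1:
        return world_out
    return world_out.T
```

Writing the matmul result directly into the gather buffer avoids an extra copy before the collective, which is the point of keeping this specialized path for the plain (unquantized) FastLinear case.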