mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 03:44:54 +00:00
support input.shape 3
This commit is contained in:
parent
b67405bd8e
commit
bdd811dd83
@ -163,38 +163,47 @@ class TensorParallelHead(SuperLayer):
|
|||||||
if world_size == 1:
|
if world_size == 1:
|
||||||
return super().forward(input)
|
return super().forward(input)
|
||||||
|
|
||||||
out_dim = self.linear.weight.shape[0]
|
if input.shape == 2:
|
||||||
|
out_dim = self.linear.weight.shape[0]
|
||||||
|
|
||||||
if input.shape[0] == 1:
|
if input.shape[0] == 1:
|
||||||
world_out = input.new_empty(1, out_dim * world_size)
|
world_out = input.new_empty(1, out_dim * world_size)
|
||||||
local_out = world_out[:, :out_dim]
|
local_out = world_out[:, :out_dim]
|
||||||
else:
|
else:
|
||||||
world_out = input.new_empty(out_dim * world_size, input.shape[0])
|
world_out = input.new_empty(out_dim * world_size, input.shape[0])
|
||||||
local_out = world_out[:out_dim].T
|
local_out = world_out[:out_dim].T
|
||||||
|
|
||||||
if isinstance(self.linear, FastLinear):
|
if isinstance(self.linear, FastLinear):
|
||||||
torch.mm(input, mat2=self.linear.weight.T, out=local_out)
|
torch.mm(input, mat2=self.linear.weight.T, out=local_out)
|
||||||
elif isinstance(self.linear, Linear8bitLt):
|
elif isinstance(self.linear, Linear8bitLt):
|
||||||
bnb.matmul(
|
bnb.matmul(
|
||||||
input,
|
input,
|
||||||
self.linear.weight,
|
self.linear.weight,
|
||||||
bias=None,
|
bias=None,
|
||||||
state=self.linear.state,
|
state=self.linear.state,
|
||||||
out=local_out,
|
out=local_out,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
if input.shape[0] == 1:
|
||||||
|
torch.distributed.all_gather_into_tensor(
|
||||||
|
world_out, local_out, group=self.process_group
|
||||||
|
)
|
||||||
|
return world_out
|
||||||
|
|
||||||
if input.shape[0] == 1:
|
|
||||||
torch.distributed.all_gather_into_tensor(
|
torch.distributed.all_gather_into_tensor(
|
||||||
world_out, local_out, group=self.process_group
|
world_out, world_out[:out_dim], group=self.process_group
|
||||||
)
|
)
|
||||||
return world_out
|
return world_out.T
|
||||||
|
|
||||||
torch.distributed.all_gather_into_tensor(
|
output = super().forward(input)
|
||||||
world_out, world_out[:out_dim], group=self.process_group
|
world_output = [
|
||||||
)
|
torch.empty_like(output) for _ in range(self.process_group.size())
|
||||||
return world_out.T
|
]
|
||||||
|
torch.distributed.all_gather(world_output, output, group=self.process_group)
|
||||||
|
world_output = torch.cat(world_output, dim=-1)
|
||||||
|
return world_output
|
||||||
|
|
||||||
|
|
||||||
class TensorParallelColumnLinear(SuperLayer):
|
class TensorParallelColumnLinear(SuperLayer):
|
||||||
|
Loading…
Reference in New Issue
Block a user