mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 03:44:54 +00:00
support input.shape 3
This commit is contained in:
parent
b67405bd8e
commit
bdd811dd83
@ -163,6 +163,7 @@ class TensorParallelHead(SuperLayer):
|
||||
if world_size == 1:
|
||||
return super().forward(input)
|
||||
|
||||
if input.shape == 2:
|
||||
out_dim = self.linear.weight.shape[0]
|
||||
|
||||
if input.shape[0] == 1:
|
||||
@ -196,6 +197,14 @@ class TensorParallelHead(SuperLayer):
|
||||
)
|
||||
return world_out.T
|
||||
|
||||
output = super().forward(input)
|
||||
world_output = [
|
||||
torch.empty_like(output) for _ in range(self.process_group.size())
|
||||
]
|
||||
torch.distributed.all_gather(world_output, output, group=self.process_group)
|
||||
world_output = torch.cat(world_output, dim=-1)
|
||||
return world_output
|
||||
|
||||
|
||||
class TensorParallelColumnLinear(SuperLayer):
|
||||
@classmethod
|
||||
|
Loading…
Reference in New Issue
Block a user