feat(server): opt dist ops

OlivierDehaene 2023-06-08 16:25:24 +02:00
parent abd58ff82c
commit b67405bd8e
4 changed files with 49 additions and 14 deletions

View File

@@ -265,6 +265,7 @@ class FlashNeoXLayer(nn.Module):
             mlp_output = self.mlp(ln2_hidden_states)
             intermediate = mlp_output + attn_output
-            torch.distributed.all_reduce(intermediate, group=self.process_group)
+            if self.process_group.size() > 1:
+                torch.distributed.all_reduce(intermediate, group=self.process_group)
             return intermediate + hidden_states, None
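
Every model-side hunk in this commit applies the same guard: only launch the collective when the layer is actually sharded. A minimal sketch of the pattern, assuming an already-initialized torch.distributed process group (`maybe_all_reduce` is a hypothetical helper name, not part of the commit):

```python
import torch
import torch.distributed


def maybe_all_reduce(tensor: torch.Tensor, process_group) -> torch.Tensor:
    # Mirrors the guard added above: with a single shard the all_reduce
    # would be a no-op anyway, so skipping it avoids launching a collective
    # (and paying its host-side overhead) in every layer.
    if process_group.size() > 1:
        torch.distributed.all_reduce(tensor, group=process_group)
    return tensor
```

The same guard is repeated in the remaining hunks below, including the row-parallel linear and embedding layers.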

View File

@@ -440,6 +440,7 @@ class FlashRWLayer(nn.Module):
             mlp_output = self.mlp(ln_hidden_states)
             intermediate = mlp_output + attn_output
-            torch.distributed.all_reduce(intermediate, group=self.process_group)
+            if self.process_group.size() > 1:
+                torch.distributed.all_reduce(intermediate, group=self.process_group)
             return intermediate, residual
@@ -524,6 +525,7 @@ class FlashRWLargeLayer(nn.Module):
         intermediate = attn_output + mlp_output
-        torch.distributed.all_reduce(intermediate, group=self.process_group)
+        if self.process_group.size() > 1:
+            torch.distributed.all_reduce(intermediate, group=self.process_group)
         return intermediate, residual

View File

@@ -346,6 +346,8 @@ class FlashSantacoderModel(nn.Module):
         pre_allocate_past_size: Optional[int] = None,
     ):
         hidden_states = self.wte(input_ids) + self.wpe(position_ids)
-        torch.distributed.all_reduce(hidden_states, group=self.process_group)
+
+        if self.process_group.size() > 1:
+            torch.distributed.all_reduce(hidden_states, group=self.process_group)
 
         # Prefill
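
One detail worth noting in this hunk: the token and position embeddings are summed before the single all_reduce. Assuming `wte` and `wpe` are both sharded embeddings that return rank-local partial results (the existing all_reduce on their sum suggests this, but this hunk does not show their definitions), reducing once after the addition is equivalent to reducing each output separately, because all_reduce is an elementwise sum across ranks. A hedged sketch with hypothetical helper names:

```python
import torch
import torch.distributed


def embed_then_reduce_twice(wte_partial, wpe_partial, group):
    # Naive variant: one collective per sharded embedding.
    torch.distributed.all_reduce(wte_partial, group=group)
    torch.distributed.all_reduce(wpe_partial, group=group)
    return wte_partial + wpe_partial


def embed_then_reduce_once(wte_partial, wpe_partial, group):
    # Variant matching the code above: sum the partials first, then reduce
    # once. The result is identical because all_reduce sums elementwise,
    # but only a single collective is issued (none at all on one shard).
    hidden_states = wte_partial + wpe_partial
    if group.size() > 1:
        torch.distributed.all_reduce(hidden_states, group=group)
    return hidden_states
```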

View File

@@ -1,3 +1,4 @@
+import loguru
 import torch
 import torch.distributed
@@ -158,14 +159,42 @@ class TensorParallelHead(SuperLayer):
         )
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        output = super().forward(input)
-        # Logits are sharded, so we need to gather them
-        world_output = [
-            torch.empty_like(output) for _ in range(self.process_group.size())
-        ]
-        torch.distributed.all_gather(world_output, output, group=self.process_group)
-        world_output = torch.cat(world_output, dim=-1)
-        return world_output
+        world_size = self.process_group.size()
+        if world_size == 1:
+            return super().forward(input)
+
+        out_dim = self.linear.weight.shape[0]
+
+        if input.shape[0] == 1:
+            world_out = input.new_empty(1, out_dim * world_size)
+            local_out = world_out[:, :out_dim]
+        else:
+            world_out = input.new_empty(out_dim * world_size, input.shape[0])
+            local_out = world_out[:out_dim].T
+
+        if isinstance(self.linear, FastLinear):
+            torch.mm(input, mat2=self.linear.weight.T, out=local_out)
+        elif isinstance(self.linear, Linear8bitLt):
+            bnb.matmul(
+                input,
+                self.linear.weight,
+                bias=None,
+                state=self.linear.state,
+                out=local_out,
+            )
+        else:
+            raise NotImplementedError
+
+        if input.shape[0] == 1:
+            torch.distributed.all_gather_into_tensor(
+                world_out, local_out, group=self.process_group
+            )
+            return world_out
+
+        torch.distributed.all_gather_into_tensor(
+            world_out, world_out[:out_dim], group=self.process_group
+        )
+        return world_out.T


 class TensorParallelColumnLinear(SuperLayer):
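
The rewritten `TensorParallelHead.forward` above drops the list-based `all_gather` + `torch.cat` in favor of one preallocated buffer and `all_gather_into_tensor`, and writes the local matmul directly into that buffer. The shape juggling exists because `all_gather_into_tensor` places each rank's chunk contiguously along dim 0 of the output: for a single decode token (`[1, out_dim]`) that flat layout already concatenates the shards along the hidden dimension, while for a batch the local logits are produced transposed so the gathered `[world_size * out_dim, batch]` buffer only needs one final transpose. A single-process sketch of that layout argument (made-up shapes, no real process group involved):

```python
import torch

# Simulate the batch > 1 case: each "rank" holds logits for its own slice
# of the output dimension.
world_size, batch, out_dim = 4, 3, 5
per_rank = [torch.randn(batch, out_dim) for _ in range(world_size)]

# Reference result: shard logits concatenated along the hidden dimension,
# which is what the old torch.cat(world_output, dim=-1) produced.
reference = torch.cat(per_rank, dim=-1)               # [batch, world_size * out_dim]

# What the gather produces when every rank contributes its transposed
# block along dim 0, as all_gather_into_tensor does.
gathered = torch.cat([p.T for p in per_rank], dim=0)  # [world_size * out_dim, batch]

assert torch.equal(gathered.T, reference)
```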
@@ -211,6 +240,7 @@ class TensorParallelRowLinear(SuperLayer):
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         out = super().forward(input)
-        torch.distributed.all_reduce(out, group=self.process_group)
+        if self.process_group.size() > 1:
+            torch.distributed.all_reduce(out, group=self.process_group)
         return out
@@ -245,7 +275,7 @@ class TensorParallelEmbedding(nn.Module):
             input - self.min_id,
         )
         out = torch.nn.functional.embedding(input, self.weight)
-        if self.reduce:
+        if self.reduce and self.process_group.size() > 1:
            torch.distributed.all_reduce(out, group=self.process_group)
         return out