From b67405bd8eaf2a778cadc0414e8040ef2137676f Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Thu, 8 Jun 2023 16:25:24 +0200
Subject: [PATCH] feat(server): opt dist ops

---
 .../custom_modeling/flash_neox_modeling.py     |  3 +-
 .../custom_modeling/flash_rw_modeling.py       |  6 ++-
 .../flash_santacoder_modeling.py               |  4 +-
 server/text_generation_server/utils/layers.py  | 50 +++++++++++++++----
 4 files changed, 49 insertions(+), 14 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
index 0fe43bcb..c045f16e 100644
--- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
@@ -265,7 +265,8 @@ class FlashNeoXLayer(nn.Module):
             mlp_output = self.mlp(ln2_hidden_states)
             intermediate = mlp_output + attn_output
 
-            torch.distributed.all_reduce(intermediate, group=self.process_group)
+            if self.process_group.size() > 1:
+                torch.distributed.all_reduce(intermediate, group=self.process_group)
 
             return intermediate + hidden_states, None
         else:
diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
index 55195162..af9fa548 100644
--- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -440,7 +440,8 @@ class FlashRWLayer(nn.Module):
             mlp_output = self.mlp(ln_hidden_states)
             intermediate = mlp_output + attn_output
 
-            torch.distributed.all_reduce(intermediate, group=self.process_group)
+            if self.process_group.size() > 1:
+                torch.distributed.all_reduce(intermediate, group=self.process_group)
 
             return intermediate, residual
         else:
@@ -524,7 +525,8 @@ class FlashRWLargeLayer(nn.Module):
 
         intermediate = attn_output + mlp_output
 
-        torch.distributed.all_reduce(intermediate, group=self.process_group)
+        if self.process_group.size() > 1:
+            torch.distributed.all_reduce(intermediate, group=self.process_group)
 
         return intermediate, residual
 
diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
index 888a6066..fcf6be68 100644
--- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -346,7 +346,9 @@ class FlashSantacoderModel(nn.Module):
         pre_allocate_past_size: Optional[int] = None,
     ):
         hidden_states = self.wte(input_ids) + self.wpe(position_ids)
-        torch.distributed.all_reduce(hidden_states, group=self.process_group)
+
+        if self.process_group.size() > 1:
+            torch.distributed.all_reduce(hidden_states, group=self.process_group)
 
         # Prefill
         if past_key_values is None:
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index ee32a0dc..3ec04963 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -1,3 +1,4 @@
+import loguru
 import torch
 import torch.distributed
 
@@ -158,14 +159,42 @@ class TensorParallelHead(SuperLayer):
         )
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        output = super().forward(input)
-        # Logits are sharded, so we need to gather them
-        world_output = [
-            torch.empty_like(output) for _ in range(self.process_group.size())
-        ]
-        torch.distributed.all_gather(world_output, output, group=self.process_group)
-        world_output = torch.cat(world_output, dim=-1)
-        return world_output
+        world_size = self.process_group.size()
+        if world_size == 1:
+            return super().forward(input)
+
+        out_dim = self.linear.weight.shape[0]
+
+        if input.shape[0] == 1:
+            world_out = input.new_empty(1, out_dim * world_size)
+            local_out = world_out[:, :out_dim]
+        else:
+            world_out = input.new_empty(out_dim * world_size, input.shape[0])
+            local_out = world_out[:out_dim].T
+
+        if isinstance(self.linear, FastLinear):
+            torch.mm(input, mat2=self.linear.weight.T, out=local_out)
+        elif isinstance(self.linear, Linear8bitLt):
+            bnb.matmul(
+                input,
+                self.linear.weight,
+                bias=None,
+                state=self.linear.state,
+                out=local_out,
+            )
+        else:
+            raise NotImplementedError
+
+        if input.shape[0] == 1:
+            torch.distributed.all_gather_into_tensor(
+                world_out, local_out, group=self.process_group
+            )
+            return world_out
+
+        torch.distributed.all_gather_into_tensor(
+            world_out, world_out[:out_dim], group=self.process_group
+        )
+        return world_out.T
 
 
 class TensorParallelColumnLinear(SuperLayer):
@@ -211,7 +240,8 @@ class TensorParallelRowLinear(SuperLayer):
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         out = super().forward(input)
-        torch.distributed.all_reduce(out, group=self.process_group)
+        if self.process_group.size() > 1:
+            torch.distributed.all_reduce(out, group=self.process_group)
         return out
 
 
@@ -245,7 +275,7 @@ class TensorParallelEmbedding(nn.Module):
             input - self.min_id,
         )
         out = torch.nn.functional.embedding(input, self.weight)
-        if self.reduce:
+        if self.reduce and self.process_group.size() > 1:
             torch.distributed.all_reduce(out, group=self.process_group)
         return out
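
For context on the layers.py change above: the rewritten TensorParallelHead.forward skips collectives entirely when the process group has a single rank, and otherwise writes each rank's local logits directly into a slice of a pre-allocated gather buffer completed with all_gather_into_tensor, instead of the old list-based all_gather followed by torch.cat. The standalone sketch below mirrors that strategy under simplified assumptions; the function name gather_sharded_logits, its arguments, and the __main__ demo are illustrative only and not part of the patch.

import torch
import torch.distributed as dist


def gather_sharded_logits(
    input: torch.Tensor, shard_weight: torch.Tensor, process_group
) -> torch.Tensor:
    # input: [tokens, hidden]; shard_weight: this rank's slice of the lm_head, [out_dim, hidden].
    world_size = process_group.size() if process_group is not None else 1
    if world_size == 1:
        # Single shard: no collective needed at all.
        return input @ shard_weight.T

    out_dim = shard_weight.shape[0]
    if input.shape[0] == 1:
        # Decode step (one token): gather directly along the vocab dimension.
        world_out = input.new_empty(1, out_dim * world_size)
        local_out = world_out[:, :out_dim]
    else:
        # Prefill: gather along dim 0 of a transposed buffer, transpose back at the end.
        world_out = input.new_empty(out_dim * world_size, input.shape[0])
        local_out = world_out[:out_dim].T

    # Write the local matmul straight into its slice of the gather buffer,
    # so no torch.cat / extra copy is needed afterwards.
    torch.mm(input, shard_weight.T, out=local_out)

    if input.shape[0] == 1:
        dist.all_gather_into_tensor(world_out, local_out, group=process_group)
        return world_out
    dist.all_gather_into_tensor(world_out, world_out[:out_dim], group=process_group)
    return world_out.T


if __name__ == "__main__":
    # Single-process check of the world_size == 1 fast path; the multi-rank path
    # requires an initialized process group (e.g. launched via torchrun).
    x = torch.randn(4, 16)
    w = torch.randn(32, 16)
    print(gather_sharded_logits(x, w, None).shape)  # torch.Size([4, 32])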