From 1a3aa08fa0caa26c2c97747e23571b51f4bca44c Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Tue, 9 May 2023 18:09:42 +0200
Subject: [PATCH] revert change on all gather

---
 .../models/custom_modeling/flash_llama_modeling.py      | 9 +++------
 .../models/custom_modeling/flash_neox_modeling.py       | 9 +++------
 .../models/custom_modeling/flash_santacoder_modeling.py | 9 +++++----
 3 files changed, 11 insertions(+), 16 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 637c95df..1293124a 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -665,12 +665,9 @@ class FlashLlamaForCausalLM(torch.nn.Module):
 
         if self.model.tp_embeddings:
             # Logits are sharded, so we need to gather them
-            world_logits = logits.new_empty(
-                (logits.shape[0], logits.shape[1] * self.world_size)
-            )
-            torch.distributed.all_gather_into_tensor(
-                world_logits, logits, group=self.process_group
-            )
+            world_logits = [torch.empty_like(logits) for _ in range(self.world_size)]
+            torch.distributed.all_gather(world_logits, logits, group=self.process_group)
+            world_logits = torch.cat(world_logits, dim=1)
 
             return world_logits, present
         return logits, present
diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
index 8d93301e..ae1465ab 100644
--- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
@@ -741,12 +741,9 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
 
         if self.gpt_neox.tp_embeddings:
             # Logits are sharded, so we need to gather them
-            world_logits = logits.new_empty(
-                (logits.shape[0], logits.shape[1] * self.world_size)
-            )
-            torch.distributed.all_gather_into_tensor(
-                world_logits, logits, group=self.process_group
-            )
+            world_logits = [torch.empty_like(logits) for _ in range(self.world_size)]
+            torch.distributed.all_gather(world_logits, logits, group=self.process_group)
+            world_logits = torch.cat(world_logits, dim=1)
 
             return world_logits, present
         return logits, present
diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
index 597eaef1..20ad8385 100644
--- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -581,12 +581,13 @@ class FlashSantacoderForCausalLM(nn.Module):
 
         if self.transformer.tp_embeddings:
             # Logits are sharded, so we need to gather them
-            world_logits = logits.new_empty(
-                (logits.shape[0], logits.shape[1] * self.transformer.tp_world_size)
-            )
-            torch.distributed.all_gather_into_tensor(
+            world_logits = [
+                torch.empty_like(logits) for _ in range(self.transformer.tp_world_size)
+            ]
+            torch.distributed.all_gather(
                 world_logits, logits, group=self.transformer.process_group
             )
+            world_logits = torch.cat(world_logits, dim=1)
 
             return world_logits, present
         return logits, present
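
Note: `torch.distributed.all_gather_into_tensor` writes the gathered shards
contiguously, as if the ranks' tensors were stacked along dim 0, so gathering
into a pre-allocated `(batch, vocab * world_size)` buffer does not produce the
column-wise layout these models need; the list-based `all_gather` followed by
`torch.cat(..., dim=1)` does. Below is a minimal, self-contained sketch of the
pattern the patch reverts to. It is not TGI code: the function name, the gloo
backend, the port, and the tensor sizes are assumptions for illustration.

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def gather_sharded_logits(rank: int, world_size: int) -> None:
    # Hypothetical repro: each rank owns a [batch, vocab // world_size]
    # shard of the logits, as produced by a tensor-parallel lm_head.
    dist.init_process_group(
        "gloo",
        init_method="tcp://127.0.0.1:29500",
        rank=rank,
        world_size=world_size,
    )
    batch, vocab_shard = 2, 4
    # Make each rank's shard recognizable: rank 0 holds all zeros, rank 1 ones.
    logits = torch.full((batch, vocab_shard), float(rank))

    # The reverted-to pattern: gather every rank's shard into a list,
    # then concatenate along the vocabulary dimension (dim=1).
    world_logits = [torch.empty_like(logits) for _ in range(world_size)]
    dist.all_gather(world_logits, logits)
    world_logits = torch.cat(world_logits, dim=1)

    assert world_logits.shape == (batch, vocab_shard * world_size)
    if rank == 0:
        print(world_logits)  # columns 0..3 are 0.0, columns 4..7 are 1.0
    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    mp.spawn(gather_sharded_logits, args=(world_size,), nprocs=world_size)

The trade-off: the list form allocates world_size temporaries plus an extra
copy in `torch.cat`, but it is layout-correct for dim-1 sharding, whereas the
fused `all_gather_into_tensor` variant only concatenates along dim 0.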