revert change on all gather

OlivierDehaene 2023-05-09 18:09:42 +02:00
parent bf5990ee9e
commit 1a3aa08fa0
3 changed files with 11 additions and 16 deletions
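
In all three files, the logits gather for tensor-parallel embeddings is reverted from torch.distributed.all_gather_into_tensor back to torch.distributed.all_gather into a per-rank list, followed by torch.cat along the vocabulary dimension (dim=1).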


@@ -665,12 +665,9 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         if self.model.tp_embeddings:
             # Logits are sharded, so we need to gather them
-            world_logits = logits.new_empty(
-                (logits.shape[0], logits.shape[1] * self.world_size)
-            )
-            torch.distributed.all_gather_into_tensor(
-                world_logits, logits, group=self.process_group
-            )
+            world_logits = [torch.empty_like(logits) for _ in range(self.world_size)]
+            torch.distributed.all_gather(world_logits, logits, group=self.process_group)
+            world_logits = torch.cat(world_logits, dim=1)
             return world_logits, present
         return logits, present
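
For context, torch.distributed.all_gather fills a Python list with one tensor per rank, so the vocabulary shards still need to be concatenated along dim=1 afterwards. Below is a minimal, runnable sketch of the restored pattern; it is not TGI code, and the two-process gloo setup, the port number, and the shapes are illustrative assumptions.

# Sketch of the reverted gather pattern: each rank holds a vocab-sharded
# logits tensor of shape (batch, vocab // world_size), and the full logits
# are rebuilt with all_gather + torch.cat along dim=1.
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def run(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    batch, shard = 2, 3
    # Fill each rank's shard with its rank id so the final layout is easy to check.
    logits = torch.full((batch, shard), float(rank))

    world_logits = [torch.empty_like(logits) for _ in range(world_size)]
    dist.all_gather(world_logits, logits)
    world_logits = torch.cat(world_logits, dim=1)  # (batch, shard * world_size)

    if rank == 0:
        # Each row reads [0, 0, 0, 1, 1, 1]: the shards sit side by side along dim 1.
        print(world_logits)
    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    mp.spawn(run, args=(world_size,), nprocs=world_size)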


@@ -741,12 +741,9 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
         if self.gpt_neox.tp_embeddings:
             # Logits are sharded, so we need to gather them
-            world_logits = logits.new_empty(
-                (logits.shape[0], logits.shape[1] * self.world_size)
-            )
-            torch.distributed.all_gather_into_tensor(
-                world_logits, logits, group=self.process_group
-            )
+            world_logits = [torch.empty_like(logits) for _ in range(self.world_size)]
+            torch.distributed.all_gather(world_logits, logits, group=self.process_group)
+            world_logits = torch.cat(world_logits, dim=1)
             return world_logits, present
         return logits, present
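
The GPT-NeoX head receives the identical revert. A plausible reason, though the commit message does not state one: all_gather_into_tensor writes each rank's flattened buffer into one contiguous slice of the output, i.e. it concatenates along dim 0 (or the flat buffer), never along dim 1. Gathering directly into a tensor of shape (batch, vocab_shard * world_size) therefore produces a flat layout that does not match the (batch, vocab) layout the caller expects, while all_gather plus torch.cat(dim=1) places each rank's columns where they belong.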


@@ -581,12 +581,13 @@ class FlashSantacoderForCausalLM(nn.Module):
         if self.transformer.tp_embeddings:
             # Logits are sharded, so we need to gather them
-            world_logits = logits.new_empty(
-                (logits.shape[0], logits.shape[1] * self.transformer.tp_world_size)
-            )
-            torch.distributed.all_gather_into_tensor(
+            world_logits = [
+                torch.empty_like(logits) for _ in range(self.transformer.tp_world_size)
+            ]
+            torch.distributed.all_gather(
                 world_logits, logits, group=self.transformer.process_group
             )
+            world_logits = torch.cat(world_logits, dim=1)
             return world_logits, present
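
If the in-place collective were ever brought back, the gather would have to run along dim 0 and be rearranged afterwards. A hedged sketch of such an equivalent follows; gather_sharded_logits is a hypothetical helper, not part of this commit, and it assumes PyTorch >= 1.13 (where all_gather_into_tensor exists) and an already-initialized process group.

import torch
import torch.distributed as dist


def gather_sharded_logits(logits: torch.Tensor, group=None) -> torch.Tensor:
    # Hypothetical helper, not TGI code: a dim-1 gather built on all_gather_into_tensor.
    world_size = dist.get_world_size(group)
    batch, shard = logits.shape
    # all_gather_into_tensor stacks rank buffers along dim 0: (world_size * batch, shard).
    out = logits.new_empty((world_size * batch, shard))
    dist.all_gather_into_tensor(out, logits, group=group)
    # Reorder so each row holds its own shards side by side: (batch, world_size * shard).
    return out.view(world_size, batch, shard).permute(1, 0, 2).reshape(batch, -1)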