formatting

OlivierDehaene 2023-02-01 11:48:18 +01:00
parent 1d0fa38cb8
commit 3149317fa1


@@ -239,10 +239,14 @@ class GPTNeoxSharded(GPTNeox):
             # Logits are sharded, so we need to gather them
             logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
-            torch.distributed.all_gather(logits, outputs.logits, group=self.process_group)
+            torch.distributed.all_gather(
+                logits, outputs.logits, group=self.process_group
+            )
             logits = torch.cat(logits, dim=2)

             return logits, outputs.past_key_values
         # While the model itself is sharded, the embeddings might not as they might not be dividable by num-shard
         else:
-            return super(GPTNeoxSharded, self).forward(input_ids, attention_mask, position_ids, past_key_values)
+            return super(GPTNeoxSharded, self).forward(
+                input_ids, attention_mask, position_ids, past_key_values
+            )