mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 11:24:53 +00:00
formatting
This commit is contained in:
parent
1d0fa38cb8
commit
3149317fa1
@@ -239,10 +239,14 @@ class GPTNeoxSharded(GPTNeox):
             # Logits are sharded, so we need to gather them
             logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
-            torch.distributed.all_gather(logits, outputs.logits, group=self.process_group)
+            torch.distributed.all_gather(
+                logits, outputs.logits, group=self.process_group
+            )
             logits = torch.cat(logits, dim=2)
 
             return logits, outputs.past_key_values
         # While the model itself is sharded, the embeddings might not as they might not be dividable by num-shard
         else:
-            return super(GPTNeoxSharded, self).forward(input_ids, attention_mask, position_ids, past_key_values)
+            return super(GPTNeoxSharded, self).forward(
+                input_ids, attention_mask, position_ids, past_key_values
+            )
 
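The lines touched by this formatting change implement the usual tensor-parallel logits gather: each rank holds a slice of the vocabulary dimension, all_gather collects the slices, and torch.cat(dim=2) rebuilds the full logits tensor. The following standalone sketch is not part of the commit; the function name gather_sharded_logits, the gloo/CPU process group, and the tensor shapes are illustrative assumptions used only to show the pattern.

# Sketch (not from the repository): gather logits sharded along the vocab
# dimension across tensor-parallel ranks, mirroring the all_gather +
# torch.cat(dim=2) pattern in GPTNeoxSharded.forward.
import os
import torch
import torch.distributed as dist


def gather_sharded_logits(shard: torch.Tensor, world_size: int, group=None) -> torch.Tensor:
    # Each rank holds logits of shape [batch, seq_len, vocab // world_size].
    logits = [torch.empty_like(shard) for _ in range(world_size)]
    dist.all_gather(logits, shard, group=group)
    # Concatenate the vocab shards back into the full vocabulary dimension.
    return torch.cat(logits, dim=2)


if __name__ == "__main__":
    # Single-process demo: a 1-rank gloo group makes all_gather a simple copy.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)
    shard = torch.randn(2, 3, 8)  # assumed [batch, seq_len, vocab_shard]
    full = gather_sharded_logits(shard, world_size=1)
    print(full.shape)  # torch.Size([2, 3, 8])
    dist.destroy_process_group()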