mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 19:34:53 +00:00
formatting
This commit is contained in:
parent
1d0fa38cb8
commit
3149317fa1
@ -239,10 +239,14 @@ class GPTNeoxSharded(GPTNeox):
|
|||||||
|
|
||||||
# Logits are sharded, so we need to gather them
|
# Logits are sharded, so we need to gather them
|
||||||
logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
|
logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
|
||||||
torch.distributed.all_gather(logits, outputs.logits, group=self.process_group)
|
torch.distributed.all_gather(
|
||||||
|
logits, outputs.logits, group=self.process_group
|
||||||
|
)
|
||||||
logits = torch.cat(logits, dim=2)
|
logits = torch.cat(logits, dim=2)
|
||||||
|
|
||||||
return logits, outputs.past_key_values
|
return logits, outputs.past_key_values
|
||||||
# While the model itself is sharded, the embeddings might not be, as their size might not be divisible by num-shard
|
# While the model itself is sharded, the embeddings might not be, as their size might not be divisible by num-shard
|
||||||
else:
|
else:
|
||||||
return super(GPTNeoxSharded, self).forward(input_ids, attention_mask, position_ids, past_key_values)
|
return super(GPTNeoxSharded, self).forward(
|
||||||
|
input_ids, attention_mask, position_ids, past_key_values
|
||||||
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user