feat(server): allow gpt-neox models with odd vocab sizes to be sharded (#48)

Authored by OlivierDehaene on 2023-02-01 14:43:59 +01:00, committed by GitHub
parent 404ed7a1f6, commit 2ad895a6cc
3 changed files with 25 additions and 17 deletions


@@ -26,7 +26,7 @@ to power Bloom, BloomZ and MT0-XXL api-inference widgets.
 - [MT0-XXL](https://huggingface.co/bigscience/mt0-xxl)
 - ~~[Galactica](https://huggingface.co/facebook/galactica-120b)~~ (deactivated)
 - [SantaCoder](https://huggingface.co/bigcode/santacoder)
-- [GPT-Neox 20B](https://huggingface.co/EleutherAI/gpt-neox-20b): use `--revision refs/pr/13`
+- [GPT-Neox 20B](https://huggingface.co/EleutherAI/gpt-neox-20b): use `--revision pr/13`
 
 Other models are supported on a best effort basis using:


@@ -145,7 +145,7 @@ class GPTNeoxSharded(GPTNeox):
                 start = rank * block_size
                 stop = (rank + 1) * block_size
                 tensor = slice_[start:stop]
-            elif name == "embed_out.weight":
+            elif name == "embed_out.weight" and model.gpt_neox.tp_embeddings:
                 size = slice_.get_shape()[0]
                 block_size = size // world_size
                 start = rank * block_size
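
For context on the guard added above (a sketch, not part of the diff): the slice arithmetic splits the rows of `embed_out.weight` evenly across ranks, which only covers the whole vocabulary when the vocab size is divisible by the world size. A minimal illustration of the same block-size arithmetic, with `shard_rows` as a hypothetical helper:

# Minimal sketch, not the repository's code: row-wise sharding of an output
# embedding with the same block_size / start / stop arithmetic as above.
import torch

def shard_rows(weight: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    size = weight.shape[0]               # vocab size (number of output rows)
    block_size = size // world_size      # rows assigned to each rank
    start = rank * block_size
    stop = (rank + 1) * block_size
    return weight[start:stop]

vocab_size, hidden_size, world_size = 50257, 64, 2   # 50257 % 2 != 0
embed_out = torch.randn(vocab_size, hidden_size)
shards = [shard_rows(embed_out, r, world_size) for r in range(world_size)]

# Only world_size * (vocab_size // world_size) rows are covered; the last row
# is silently dropped when the vocab size does not divide evenly.
print(sum(s.shape[0] for s in shards), "of", vocab_size, "rows covered")

With a vocab size such as 50257 and two shards, one row is left over, which is presumably why the output embedding is only sharded when `tp_embeddings` is set.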
@@ -229,6 +229,7 @@ class GPTNeoxSharded(GPTNeox):
     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
     ):
+        if self.model.gpt_neox.tp_embeddings:
             outputs = self.model.forward(
                 input_ids=input_ids,
                 attention_mask=attention_mask,
@@ -238,7 +239,14 @@ class GPTNeoxSharded(GPTNeox):
 
             # Logits are sharded, so we need to gather them
             logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
-            torch.distributed.all_gather(logits, outputs.logits, group=self.process_group)
+            torch.distributed.all_gather(
+                logits, outputs.logits, group=self.process_group
+            )
             logits = torch.cat(logits, dim=2)
 
             return logits, outputs.past_key_values
+        # While the model itself is sharded, the embeddings might not as they might not be dividable by num-shard
+        else:
+            return super(GPTNeoxSharded, self).forward(
+                input_ids, attention_mask, position_ids, past_key_values
+            )
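
To make the gathered-logits path above concrete (a sketch, not the repository's code): when the output embedding is sharded, each rank produces logits for only its slice of the vocabulary, so the shards are collected with `all_gather` and concatenated along dim 2 of the `[batch, sequence, vocab]` tensor. A single-process illustration with the collective replaced by a local copy:

# Minimal sketch: reassembling vocabulary-sharded logits.
# torch.distributed.all_gather(gathered, local, group=...) would fill
# `gathered` with every rank's shard; here the shards are produced and
# copied locally just to show the shapes involved.
import torch

world_size, batch, seq_len, vocab_size = 2, 1, 4, 10
full_logits = torch.randn(batch, seq_len, vocab_size)

# What each rank would compute with a row-sharded embed_out.weight.
local_shards = list(full_logits.chunk(world_size, dim=2))

# Stand-in for all_gather: every rank ends up with all shards.
gathered = [torch.empty_like(shard) for shard in local_shards]
for buf, shard in zip(gathered, local_shards):
    buf.copy_(shard)

logits = torch.cat(gathered, dim=2)        # back to [batch, seq, vocab]
assert torch.equal(logits, full_logits)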


@@ -91,7 +91,7 @@ class NextTokenChooser:
             top_p=pb.top_p,
             do_sample=pb.do_sample,
             seed=pb.seed,
-            device=str(device),
+            device=device,
         )
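
The conversion above now passes the `torch.device` through unchanged rather than stringifying it. Assuming (this is not shown in the diff) that the device ends up seeding a `torch.Generator` for sampling, either form is accepted; a minimal sketch:

# Hedged illustration, assuming `device` is ultimately used to seed a
# torch.Generator for sampling. torch.Generator accepts a torch.device
# object directly, so no str() round-trip is needed.
import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
generator = torch.Generator(device=device)
generator.manual_seed(42)

probs = torch.softmax(torch.randn(1, 10, device=device), dim=-1)
next_token = torch.multinomial(probs, num_samples=1, generator=generator)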