Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-23 07:52:06 +00:00)

feat(server): allow gpt-neox models with odd vocab sizes to be sharded (#48)

parent: 404ed7a1f6
commit: 2ad895a6cc
@@ -26,7 +26,7 @@ to power Bloom, BloomZ and MT0-XXL api-inference widgets.
 - [MT0-XXL](https://huggingface.co/bigscience/mt0-xxl)
 - ~~[Galactica](https://huggingface.co/facebook/galactica-120b)~~ (deactivated)
 - [SantaCoder](https://huggingface.co/bigcode/santacoder)
-- [GPT-Neox 20B](https://huggingface.co/EleutherAI/gpt-neox-20b): use `--revision refs/pr/13`
+- [GPT-Neox 20B](https://huggingface.co/EleutherAI/gpt-neox-20b): use `--revision pr/13`
 
 Other models are supported on a best effort basis using:
 
@@ -145,7 +145,7 @@ class GPTNeoxSharded(GPTNeox):
                     start = rank * block_size
                     stop = (rank + 1) * block_size
                     tensor = slice_[start:stop]
-                elif name == "embed_out.weight":
+                elif name == "embed_out.weight" and model.gpt_neox.tp_embeddings:
                     size = slice_.get_shape()[0]
                     block_size = size // world_size
                     start = rank * block_size
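The arithmetic in this hunk is why odd vocab sizes need the new `tp_embeddings` guard: each rank takes a contiguous `block_size` slice of the `embed_out` rows, and `size // world_size` only tiles the tensor exactly when the vocab size divides evenly by the shard count. A minimal sketch of that block-sharding arithmetic (the vocab size and shard count below are illustrative, not taken from the model):

```python
# Sketch of the block-sharding arithmetic from the hunk above.
vocab_size = 50257   # example vocab size that is NOT divisible by the shard count
world_size = 4

block_size = vocab_size // world_size  # 12564
for rank in range(world_size):
    start = rank * block_size
    stop = (rank + 1) * block_size
    print(f"rank {rank}: rows [{start}, {stop})")

# 4 * 12564 = 50256, so the last vocab row is never assigned to any rank.
# This is the case the commit handles: with an odd vocab size, tensor-parallel
# embeddings are disabled (tp_embeddings = False) and embed_out stays whole on
# every shard.
print(block_size * world_size == vocab_size)  # False
```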
@@ -176,9 +176,9 @@ class GPTNeoxSharded(GPTNeox):
                 )
 
                 if (
                     type(module)
                     in [TensorParallelRowLinear, TensorParallelColumnLinear]
                     and param_name == "weight"
                 ):
                     tensor = Int8Params(
                         tensor,
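For context on the `Int8Params` lines above: when quantization is enabled, the sharded `weight` tensors of the tensor-parallel linear layers are wrapped in bitsandbytes' `Int8Params`, which replaces the fp16 data with int8 data (plus scales) once the parameter reaches the GPU. A minimal sketch, assuming bitsandbytes is installed and a CUDA device is available; the shape is illustrative:

```python
# Sketch of wrapping a weight in bitsandbytes' Int8Params, as the diff above
# does for TensorParallelRowLinear / TensorParallelColumnLinear weights.
import torch
from bitsandbytes.nn import Int8Params

weight = torch.randn(1024, 1024, dtype=torch.float16)

# has_fp16_weights=False tells bitsandbytes to quantize the fp16 data to int8
# (keeping per-row scales) the moment the parameter is moved to the GPU.
param = Int8Params(weight, requires_grad=False, has_fp16_weights=False).cuda()
print(param.dtype)  # torch.int8 after quantization
```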
@@ -229,16 +229,24 @@ class GPTNeoxSharded(GPTNeox):
     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
     ):
-        outputs = self.model.forward(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            past_key_values=past_key_values,
-            use_cache=True,
-        )
-
-        # Logits are sharded, so we need to gather them
-        logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
-        torch.distributed.all_gather(logits, outputs.logits, group=self.process_group)
-        logits = torch.cat(logits, dim=2)
-
-        return logits, outputs.past_key_values
+        if self.model.gpt_neox.tp_embeddings:
+            outputs = self.model.forward(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                past_key_values=past_key_values,
+                use_cache=True,
+            )
+
+            # Logits are sharded, so we need to gather them
+            logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
+            torch.distributed.all_gather(
+                logits, outputs.logits, group=self.process_group
+            )
+            logits = torch.cat(logits, dim=2)
+
+            return logits, outputs.past_key_values
+        # While the model itself is sharded, the embeddings might not as they might not be dividable by num-shard
+        else:
+            return super(GPTNeoxSharded, self).forward(
+                input_ids, attention_mask, position_ids, past_key_values
+            )
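The gather in the `tp_embeddings` branch works because each rank's `embed_out` shard produces logits of shape `[batch, seq, vocab // world_size]`; `all_gather` collects every rank's slice and `torch.cat(..., dim=2)` stitches the vocab dimension back together. A runnable single-process sketch of that shape logic (the gloo process-group setup is an assumption for demonstration only; the server itself launches one process per shard):

```python
# Single-process demonstration of the gather-and-concat pattern above.
import torch
import torch.distributed as dist

dist.init_process_group(
    "gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
)

# This rank's shard of the logits: [batch, seq, vocab // world_size]
shard = torch.randn(1, 4, 128)

logits = [torch.empty_like(shard) for _ in range(dist.get_world_size())]
dist.all_gather(logits, shard)

# Concatenating along dim=2 rebuilds the full vocab dimension.
full = torch.cat(logits, dim=2)
print(full.shape)  # [1, 4, 128] here; [1, 4, 128 * world_size] in general

dist.destroy_process_group()
```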
@@ -91,7 +91,7 @@ class NextTokenChooser:
             top_p=pb.top_p,
             do_sample=pb.do_sample,
             seed=pb.seed,
-            device=str(device),
+            device=device,
         )
 
 
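The `device=str(device)` to `device=device` change passes the `torch.device` object straight through instead of its string form. Most torch APIs accept either spelling, so forwarding the object simply avoids a needless string round-trip; the sketch below illustrates the idea with a seeded generator, which is an assumption about how `NextTokenChooser` uses the device, not the repo's exact code:

```python
# Sketch: torch accepts a torch.device wherever it accepts a device string,
# so there is no need to stringify the device on the way in.
import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

generator = torch.Generator(device=device)  # no str(device) needed
generator.manual_seed(42)

# Seeded sampling then stays on the chosen device.
probs = torch.softmax(torch.randn(1, 8, device=device), dim=-1)
next_token = torch.multinomial(probs, num_samples=1, generator=generator)
print(next_token)
```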