feat(server): allow gpt-neox models with odd vocab sizes to be sharded

Author: OlivierDehaene
Date:   2023-02-01 11:47:32 +01:00
parent 404ed7a1f6
commit 1d0fa38cb8
3 changed files with 30 additions and 26 deletions
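Context for the change: the sharded GPT-NeoX implementation splits the token-embedding and `embed_out` matrices row-wise across shards, which only works when the vocabulary size is divisible by the number of shards. This commit gates that sharding on `gpt_neox.tp_embeddings` and falls back to the unsharded path otherwise. A minimal, repository-independent sketch of the divisibility constraint (the helper name `shard_bounds` is made up for illustration):

    def shard_bounds(vocab_size: int, world_size: int):
        # Rows of a (vocab_size, hidden) matrix each rank would own,
        # or None when the vocabulary cannot be split evenly.
        if vocab_size % world_size != 0:
            return None  # uneven vocab size: keep the embeddings unsharded instead
        block_size = vocab_size // world_size
        return [(rank * block_size, (rank + 1) * block_size) for rank in range(world_size)]

    print(shard_bounds(50432, 8))  # divides evenly: every rank owns 6304 rows
    print(shard_bounds(50257, 8))  # does not divide: None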

README.md

@@ -26,7 +26,7 @@ to power Bloom, BloomZ and MT0-XXL api-inference widgets.
 - [MT0-XXL](https://huggingface.co/bigscience/mt0-xxl)
 - ~~[Galactica](https://huggingface.co/facebook/galactica-120b)~~ (deactivated)
 - [SantaCoder](https://huggingface.co/bigcode/santacoder)
-- [GPT-Neox 20B](https://huggingface.co/EleutherAI/gpt-neox-20b): use `--revision refs/pr/13`
+- [GPT-Neox 20B](https://huggingface.co/EleutherAI/gpt-neox-20b): use `--revision pr/13`

 Other models are supported on a best effort basis using:

server/text_generation/models/gpt_neox.py

@@ -33,7 +33,7 @@ except Exception as e:
 class GPTNeox(CausalLM):
     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
     ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
         """Overwrite forward to ignore position_ids"""
@@ -49,7 +49,7 @@ class GPTNeox(CausalLM):
 class GPTNeoxSharded(GPTNeox):
     def __init__(
         self, model_name: str, revision: Optional[str] = None, quantize: bool = False
     ):
         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
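`initialize_torch_distributed()` above is the server's own helper and its implementation is not part of this diff. A rough sketch of what such a helper typically does, assuming an NCCL/gloo backend and RANK/WORLD_SIZE environment variables (assumptions, not the actual code):

    import os

    import torch
    import torch.distributed

    def initialize_torch_distributed():
        # Hypothetical stand-in: join the default process group so every shard
        # can later take part in collectives such as all_gather.
        rank = int(os.getenv("RANK", "0"))
        world_size = int(os.getenv("WORLD_SIZE", "1"))
        backend = "nccl" if torch.cuda.is_available() else "gloo"
        torch.distributed.init_process_group(backend=backend, rank=rank, world_size=world_size)
        return torch.distributed.group.WORLD, rank, world_size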
@@ -101,17 +101,17 @@ class GPTNeoxSharded(GPTNeox):
     @staticmethod
     def load_weights(
         model,
         filenames: List[str],
         quantize: bool,
         device: torch.device,
         rank: int,
         world_size: int,
     ):
         parameters = dict(model.named_parameters())
         for file in filenames:
             with safe_open(
                 file, framework="pt", device=str(device) if not quantize else "cpu"
             ) as f:
                 for name in f.keys():
                     module_name, param_name = name.rsplit(".", 1)
@@ -145,7 +145,7 @@ class GPTNeoxSharded(GPTNeox):
                         start = rank * block_size
                         stop = (rank + 1) * block_size
                         tensor = slice_[start:stop]
-                    elif name == "embed_out.weight":
+                    elif name == "embed_out.weight" and model.gpt_neox.tp_embeddings:
                         size = slice_.get_shape()[0]
                         block_size = size // world_size
                         start = rank * block_size
@@ -158,8 +158,8 @@ class GPTNeoxSharded(GPTNeox):
                         tensor = f.get_tensor(name)

                     if (
                         current_parameter_tensor is not None
                         and current_parameter_tensor.shape != tensor.shape
                     ):
                         raise ValueError(
                             f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
@@ -227,18 +227,22 @@ class GPTNeoxSharded(GPTNeox):
                         module._buffers[param_name] = tensor

     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
     ):
-        outputs = self.model.forward(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            past_key_values=past_key_values,
-            use_cache=True,
-        )
-
-        # Logits are sharded, so we need to gather them
-        logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
-        torch.distributed.all_gather(logits, outputs.logits, group=self.process_group)
-        logits = torch.cat(logits, dim=2)
-
-        return logits, outputs.past_key_values
+        if self.model.gpt_neox.tp_embeddings:
+            outputs = self.model.forward(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                past_key_values=past_key_values,
+                use_cache=True,
+            )
+
+            # Logits are sharded, so we need to gather them
+            logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
+            torch.distributed.all_gather(logits, outputs.logits, group=self.process_group)
+            logits = torch.cat(logits, dim=2)
+
+            return logits, outputs.past_key_values
+        # While the model itself is sharded, the embeddings might not as they might not be dividable by num-shard
+        else:
+            return super(GPTNeoxSharded, self).forward(input_ids, attention_mask, position_ids, past_key_values)
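When the embeddings are sharded, every rank only computes logits for its slice of the vocabulary, so the full distribution has to be reassembled before sampling; that is what the `all_gather` in the new `forward` does. A self-contained sketch of the reassembly (shapes and the process-group argument are illustrative):

    import torch
    import torch.distributed as dist

    def gather_sharded_logits(local_logits: torch.Tensor, world_size: int, group=None) -> torch.Tensor:
        # local_logits: [batch, seq_len, vocab_size // world_size] on each rank.
        # Returns the full [batch, seq_len, vocab_size] logits on every rank.
        shards = [torch.empty_like(local_logits) for _ in range(world_size)]
        dist.all_gather(shards, local_logits, group=group)
        return torch.cat(shards, dim=2)  # concatenate along the vocabulary dimension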

server/text_generation/utils.py

@@ -91,7 +91,7 @@ class NextTokenChooser:
             top_p=pb.top_p,
             do_sample=pb.do_sample,
             seed=pb.seed,
-            device=str(device),
+            device=device,
         )
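The last hunk passes the `torch.device` object itself to `NextTokenChooser` instead of its string form. Presumably the chooser uses it to build a device-local generator for seeded sampling; a hedged sketch of that pattern (not the server's actual `NextTokenChooser`):

    import torch

    def sample_next_token(logits: torch.Tensor, seed: int, device: torch.device) -> torch.Tensor:
        # Seeded, reproducible sampling with a generator that lives on the
        # same device as the logits (logits: [batch, vocab_size]).
        generator = torch.Generator(device=device)
        generator.manual_seed(seed)
        probs = torch.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1, generator=generator)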