Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 11:24:53 +00:00)
feat(server): allow gpt-neox models with odd vocab sizes to be sharded

parent 404ed7a1f6
commit 1d0fa38cb8
@@ -26,7 +26,7 @@ to power Bloom, BloomZ and MT0-XXL api-inference widgets.
 - [MT0-XXL](https://huggingface.co/bigscience/mt0-xxl)
 - ~~[Galactica](https://huggingface.co/facebook/galactica-120b)~~ (deactivated)
 - [SantaCoder](https://huggingface.co/bigcode/santacoder)
-- [GPT-Neox 20B](https://huggingface.co/EleutherAI/gpt-neox-20b): use `--revision refs/pr/13`
+- [GPT-Neox 20B](https://huggingface.co/EleutherAI/gpt-neox-20b): use `--revision pr/13`

 Other models are supported on a best effort basis using:
@@ -33,7 +33,7 @@ except Exception as e:

 class GPTNeox(CausalLM):
     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
     ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
         """Overwrite forward to ignore position_ids"""
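The hunk above only shows the method header; as a hedged, self-contained sketch of the pattern the docstring names (the body is an assumption based on the sharded subclass further down, and `model` is assumed to be a Hugging Face causal LM):

from typing import Optional

# Hypothetical sketch: position_ids is accepted for interface parity but
# deliberately not forwarded, matching "Overwrite forward to ignore position_ids".
def forward(model, input_ids, attention_mask, position_ids, past_key_values: Optional = None):
    outputs = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        use_cache=True,
    )
    return outputs.logits, outputs.past_key_values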
@@ -49,7 +49,7 @@ class GPTNeox(CausalLM):

 class GPTNeoxSharded(GPTNeox):
     def __init__(
         self, model_name: str, revision: Optional[str] = None, quantize: bool = False
     ):
         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
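`initialize_torch_distributed` comes from the server's utilities. A minimal sketch of what such a helper conventionally does; the environment variable names and backend choice here are assumptions, not the project's actual code:

import os

import torch

def initialize_torch_distributed():
    # Assumes the launcher exports RANK/WORLD_SIZE and the rendezvous
    # variables MASTER_ADDR/MASTER_PORT, as torchrun-style launchers do.
    rank = int(os.getenv("RANK", "0"))
    world_size = int(os.getenv("WORLD_SIZE", "1"))
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    torch.distributed.init_process_group(
        backend=backend, world_size=world_size, rank=rank
    )
    return torch.distributed.group.WORLD, rank, world_size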
@@ -101,17 +101,17 @@ class GPTNeoxSharded(GPTNeox):

     @staticmethod
     def load_weights(
         model,
         filenames: List[str],
         quantize: bool,
         device: torch.device,
         rank: int,
         world_size: int,
     ):
         parameters = dict(model.named_parameters())
         for file in filenames:
             with safe_open(
                 file, framework="pt", device=str(device) if not quantize else "cpu"
             ) as f:
                 for name in f.keys():
                     module_name, param_name = name.rsplit(".", 1)
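The loader keys everything off the safetensors tensor names. A condensed, hypothetical sketch of the mechanics (the shape guard mirrors the check shown two hunks below; buffer handling and quantization are omitted):

import torch
from safetensors import safe_open

def load_into_module(model: torch.nn.Module, filename: str, device: torch.device):
    parameters = dict(model.named_parameters())
    with safe_open(filename, framework="pt", device=str(device)) as f:
        for name in f.keys():
            # "gpt_neox.layers.0.attention.dense.weight" ->
            # module path "gpt_neox.layers.0.attention.dense", param "weight"
            module_name, param_name = name.rsplit(".", 1)
            module = model.get_submodule(module_name)
            tensor = f.get_tensor(name)

            current = parameters.get(name)
            if current is not None and current.shape != tensor.shape:
                raise ValueError(
                    f"Name {name} -- Current {current.shape} and got {tensor.shape}"
                )
            # Assign the loaded tensor directly onto the live module, in the
            # same spirit as the diff's module._buffers assignment.
            module._parameters[param_name] = torch.nn.Parameter(
                tensor, requires_grad=False
            )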
@@ -145,7 +145,7 @@ class GPTNeoxSharded(GPTNeox):
                         start = rank * block_size
                         stop = (rank + 1) * block_size
                         tensor = slice_[start:stop]
-                    elif name == "embed_out.weight":
+                    elif name == "embed_out.weight" and model.gpt_neox.tp_embeddings:
                         size = slice_.get_shape()[0]
                         block_size = size // world_size
                         start = rank * block_size
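This condition is the heart of the fix: `embed_out.weight` is only sliced when the model reports that its embeddings are tensor-parallel, i.e. when the vocab dimension divides evenly across the shards. A toy illustration with hypothetical numbers:

# Hypothetical sizes, for illustration only.
vocab_size = 50257  # an "odd" vocab size (e.g. a fine-tuned checkpoint)
world_size = 4

if vocab_size % world_size == 0:
    tp_embeddings = True
    block_size = vocab_size // world_size
    # rank r owns rows [r * block_size, (r + 1) * block_size)
    print(f"shard embeddings: {block_size} rows per rank")
else:
    # 50257 % 4 == 1: the rows cannot be split evenly, so the embeddings
    # stay whole on every rank and the gather-free forward path is used.
    tp_embeddings = False
    print("keep embeddings unsharded")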
@@ -158,8 +158,8 @@ class GPTNeoxSharded(GPTNeox):
                         tensor = f.get_tensor(name)

                     if (
                         current_parameter_tensor is not None
                         and current_parameter_tensor.shape != tensor.shape
                     ):
                         raise ValueError(
                             f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
@@ -227,18 +227,22 @@ class GPTNeoxSharded(GPTNeox):
             module._buffers[param_name] = tensor

     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
     ):
-        outputs = self.model.forward(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            past_key_values=past_key_values,
-            use_cache=True,
-        )
+        if self.model.gpt_neox.tp_embeddings:
+            outputs = self.model.forward(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                past_key_values=past_key_values,
+                use_cache=True,
+            )

-        # Logits are sharded, so we need to gather them
-        logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
-        torch.distributed.all_gather(logits, outputs.logits, group=self.process_group)
-        logits = torch.cat(logits, dim=2)
+            # Logits are sharded, so we need to gather them
+            logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
+            torch.distributed.all_gather(logits, outputs.logits, group=self.process_group)
+            logits = torch.cat(logits, dim=2)

-        return logits, outputs.past_key_values
+            return logits, outputs.past_key_values
+        # While the model itself is sharded, the embeddings might not as they might not be dividable by num-shard
+        else:
+            return super(GPTNeoxSharded, self).forward(input_ids, attention_mask, position_ids, past_key_values)
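When the unembedding matrix is sharded, each rank computes logits only for its slice of the vocabulary, so the slices must be stitched back together before sampling. A self-contained sketch of the gather performed above (the shapes in the comments are assumptions for illustration):

import torch
import torch.distributed as dist

def gather_sharded_logits(
    local_logits: torch.Tensor, world_size: int, group=None
) -> torch.Tensor:
    # local_logits: [batch, seq_len, vocab_size // world_size] on each rank.
    shards = [torch.empty_like(local_logits) for _ in range(world_size)]
    dist.all_gather(shards, local_logits, group=group)
    # Reassemble the full vocabulary along dim=2, as the diff does.
    return torch.cat(shards, dim=2)

Note that `all_gather` orders the output list by rank, so concatenating along the vocab dimension reproduces the original column order only because each rank holds a contiguous block, which the `block_size` slicing above guarantees.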
@@ -91,7 +91,7 @@ class NextTokenChooser:
             top_p=pb.top_p,
             do_sample=pb.do_sample,
             seed=pb.seed,
-            device=str(device),
+            device=device,
         )