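"""Shard a BLOOM checkpoint for tensor parallelism.

Loads the full model once, splits each parameter across `tp_world_size`
tensor-parallel ranks (column-wise, row-wise, or by duplication, depending
on the layer), and saves one state dict per rank.
"""
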
from pathlib import Path

import torch
from torch import nn
from transformers import AutoModelForCausalLM
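

# Check whether `text` ends with `suffix`.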
def match_suffix(text, suffix):
    return text[-len(suffix) :] == suffix


def shard_model(model_name: str, path: Path, tp_world_size: int, dtype: torch.dtype):
    """BLOOM-specific sharding mechanism."""
    save_paths = [
        path / f"{model_name}_tp-rank-{tp_rank}-of-{tp_world_size}.pty"
        for tp_rank in range(tp_world_size)
    ]
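    # Skip resharding if every shard file already exists.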
    if all(save_path.exists() for save_path in save_paths):
        print("Loading already cached values")
        return save_paths

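    # Load the full (unsharded) model once; `local_files_only=True` assumes
    # the weights are already in the local Hugging Face cache.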
    model: nn.Module = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=dtype, local_files_only=True
    )

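    # Accumulate one partial state dict per tensor-parallel rank.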
    shards_state_dicts = [{} for _ in range(tp_world_size)]
    state_dict = model.state_dict()
    keys = list(state_dict.keys())
    for state_name in keys:
        print(state_name)
        state = state_dict[state_name]
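        # Column-parallel parameters (Megatron-style): the fused QKV
        # projection, the first MLP projection, and the (tied)
        # embedding/LM-head weights are split along dim 0, the output
        # dimension.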
        if any(
            match_suffix(state_name, candidate)
            for candidate in [
                "self_attention.query_key_value.weight",
                "self_attention.query_key_value.bias",
                "mlp.dense_h_to_4h.weight",
                "mlp.dense_h_to_4h.bias",
                "transformer.word_embeddings.weight",
                "lm_head.weight",
            ]
        ):
            output_size = state.shape[0]
            assert output_size % tp_world_size == 0
            block_size = output_size // tp_world_size
            sharded_weights = torch.split(state, block_size, dim=0)
            assert len(sharded_weights) == tp_world_size
            for tp_rank, shard in enumerate(sharded_weights):
                assert shard.shape[0] == block_size
                shards_state_dicts[tp_rank][state_name] = shard.detach().clone()
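        # Row-parallel weights: the attention output projection and the
        # second MLP projection are split along dim 1, the input dimension;
        # each rank computes a partial product that is all-reduced at runtime.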
        elif any(
            match_suffix(state_name, candidate)
            for candidate in [
                "self_attention.dense.weight",
                "mlp.dense_4h_to_h.weight",
                # NOTE: "lm_head.weight" also appears in the dim-0 list above,
                # which matches first, so this entry is never reached.
                "lm_head.weight",
            ]
        ):
            input_size = state.shape[1]
            assert input_size % tp_world_size == 0
            block_size = input_size // tp_world_size
            sharded_weights = torch.split(state, block_size, dim=1)
            assert len(sharded_weights) == tp_world_size
            for tp_rank, shard in enumerate(sharded_weights):
                assert shard.shape[1] == block_size
                shards_state_dicts[tp_rank][state_name] = shard.detach().clone()
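        # Biases of row-parallel layers: kept on rank 0 and zeroed elsewhere,
        # so the all-reduce that sums the partial products adds each bias
        # exactly once.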
        elif any(
            match_suffix(state_name, candidate)
            for candidate in [
                "self_attention.dense.bias",
                "mlp.dense_4h_to_h.bias",
            ]
        ):
            shards_state_dicts[0][state_name] = state.detach().clone()
            for tp_rank in range(1, tp_world_size):
                shards_state_dicts[tp_rank][state_name] = torch.zeros_like(state)
        else:
            # We duplicate parameters across tp ranks
            for tp_rank in range(tp_world_size):
                shards_state_dicts[tp_rank][state_name] = state.detach().clone()

        # Drop both references so the full-model tensor can be freed as we go,
        # keeping peak memory low.
        del state_dict[state_name]
        del state

    # Save one shard state dict per rank. (The original also re-appended each
    # path to `save_paths` here, which doubled the returned list; that stray
    # append is a bug and has been removed.)
    for save_path, shard_state_dict in zip(save_paths, shards_state_dicts):
        save_path.parent.mkdir(parents=True, exist_ok=True)
        torch.save(shard_state_dict, save_path)

    return save_paths


if __name__ == "__main__":
    model_name = "bigscience/bloom"
    save_path = Path("/data/shards")
    tp_world_size = 8
    dtype = torch.bfloat16

    shard_model(model_name, save_path, tp_world_size=tp_world_size, dtype=dtype)
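    # A minimal sketch (illustrative only, not part of the serving code) of
    # how one rank could load its shard back afterwards:
    #
    #     shard_state_dict = torch.load(
    #         save_path / f"{model_name}_tp-rank-0-of-{tp_world_size}.pty"
    #     )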