Making bloom loadable with safetensors.

This commit is contained in:
Nicolas Patry 2022-10-21 18:02:04 +02:00
parent c837893370
commit 457c9038ff
No known key found for this signature in database
GPG Key ID: 798FF72A96CC526E

View File

@ -1,15 +1,17 @@
import torch
import torch.distributed
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import List, Tuple, Optional, Dict
from huggingface_hub import hf_hub_download, HfApi
from safetensors import safe_open
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from transformers.modeling_utils import no_init_weights
from bloom_inference.pb import generate_pb2
from bloom_inference.prepare_weights import prepare_weights, match_suffix
from bloom_inference.utils import (
StoppingCriteria,
NextTokenChooser,
@ -368,6 +370,149 @@ class BLOOM:
return generated_texts, next_batch
def dl_weights(rank, model_id):
api = HfApi()
info = api.model_info(model_id)
filenames = set(
s.rfilename for s in info.siblings if s.rfilename.endswith(".safetensors")
)
return [hf_hub_download(model_id, filename=filename) for filename in filenames]
def set_tensor(model, full_name, tensor):
splits = full_name.split(".")
for split in splits[:-1]:
model = getattr(model, split)
tensor_name = splits[-1]
with torch.no_grad():
model._parameters[tensor_name] = tensor
def load(model, filenames, tp_rank, tp_world_size):
parameters = dict(model.named_parameters())
for filename in filenames:
with safe_open(filename, framework="pt", device=f"cuda:{tp_rank}") as f:
for name in f.keys():
full_name = f"transformer.{name}"
current_tensor = parameters[full_name]
handled = False
for suffix in [
"self_attention.dense.weight",
"mlp.dense_4h_to_h.weight",
"self_attention.query_key_value.weight",
"mlp.dense_h_to_4h.weight",
"self_attention.query_key_value.bias",
"mlp.dense_h_to_4h.bias",
"word_embeddings.weight",
]:
if name.endswith(suffix):
slice_ = f.get_slice(name)
if suffix in {
"mlp.dense_4h_to_h.weight",
"self_attention.dense.weight",
}:
size = slice_.get_shape()[1]
block_size = size // tp_world_size
start = tp_rank * block_size
stop = (tp_rank + 1) * block_size
tensor = slice_[:, start:stop]
else:
size = slice_.get_shape()[0]
block_size = size // tp_world_size
start = tp_rank * block_size
stop = (tp_rank + 1) * block_size
tensor = slice_[start:stop]
if name.endswith(".weight") and not name.endswith(
"word_embeddings.weight"
):
tensor = tensor.transpose(1, 0)
handled = True
break
if not handled:
tensor = f.get_tensor(name)
tensor = tensor.contiguous()
if tp_rank != 0 and (
name.endswith("self_attention.dense.bias")
or name.endswith("mlp.dense_4h_to_h.bias")
):
# XXX: Hack for Rowlinear to add the bias only once.
set_tensor(model, full_name, torch.zeros_like(tensor))
else:
set_tensor(model, full_name, tensor)
if name == "word_embeddings.weight":
set_tensor(model, "lm_head.weight", tensor)
if current_tensor.shape != tensor.shape:
raise ValueError(
f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}"
)
@contextmanager
def init_empty_weights(include_buffers: bool = False):
"""
imported from `accelerate` to not depend on it.
"""
old_register_parameter = torch.nn.Module.register_parameter
if include_buffers:
old_register_buffer = torch.nn.Module.register_buffer
def register_empty_parameter(module, name, param):
old_register_parameter(module, name, param)
if param is not None:
param_cls = type(module._parameters[name])
kwargs = module._parameters[name].__dict__
module._parameters[name] = param_cls(
module._parameters[name].to(torch.device("meta")), **kwargs
)
def register_empty_buffer(module, name, buffer):
old_register_buffer(module, name, buffer)
if buffer is not None:
module._buffers[name] = module._buffers[name].to(torch.device("meta"))
# Patch tensor creation
if include_buffers:
tensor_constructors_to_patch = {
torch_function_name: getattr(torch, torch_function_name)
for torch_function_name in ["empty", "zeros", "ones", "full"]
}
else:
tensor_constructors_to_patch = {}
def patch_tensor_constructor(fn):
def wrapper(*args, **kwargs):
kwargs["device"] = torch.device("meta")
return fn(*args, **kwargs)
return wrapper
try:
torch.nn.Module.register_parameter = register_empty_parameter
if include_buffers:
torch.nn.Module.register_buffer = register_empty_buffer
for torch_function_name in tensor_constructors_to_patch.keys():
setattr(
torch,
torch_function_name,
patch_tensor_constructor(getattr(torch, torch_function_name)),
)
yield
finally:
torch.nn.Module.register_parameter = old_register_parameter
if include_buffers:
torch.nn.Module.register_buffer = old_register_buffer
for (
torch_function_name,
old_torch_function,
) in tensor_constructors_to_patch.items():
setattr(torch, torch_function_name, old_torch_function)
class BLOOMSharded(BLOOM):
def __init__(self, model_name: str, shard_directory: Path):
super(BLOOM, self).__init__()
@ -382,25 +527,7 @@ class BLOOMSharded(BLOOM):
self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
# shard state_dict
if self.master:
# TODO @thomasw21 do some caching
shard_state_dict_paths = prepare_weights(
model_name,
shard_directory / "cache",
shard_directory,
tp_world_size=self.world_size,
)
shard_state_dict_paths = [
str(path.absolute()) for path in shard_state_dict_paths
]
else:
shard_state_dict_paths = [None] * self.world_size
torch.distributed.broadcast_object_list(
shard_state_dict_paths, src=0, group=self.process_group
)
shard_state_dict_path = shard_state_dict_paths[self.rank]
filenames = dl_weights(self.rank, model_name)
config = AutoConfig.from_pretrained(
model_name, slow_but_exact=False, tp_parallel=True
@ -415,33 +542,13 @@ class BLOOMSharded(BLOOM):
torch.backends.cudnn.allow_tf32 = True
with set_default_dtype(dtype):
with no_init_weights():
with init_empty_weights():
# we can probably set the device to `meta` here?
model = AutoModelForCausalLM.from_config(config).to(dtype)
torch.distributed.barrier(group=self.process_group)
# print_rank_0(f"Initialized model")
state_dict = torch.load(shard_state_dict_path)
# TODO @thomasw21: HACK in order to transpose all weight prior
for key in state_dict.keys():
do_transpose = False
if not match_suffix(key, "weight"):
continue
for potential_suffix in [
"self_attention.query_key_value.weight",
"self_attention.dense.weight",
"dense_h_to_4h.weight",
"dense_4h_to_h.weight",
]:
if match_suffix(key, potential_suffix):
do_transpose = True
if do_transpose:
state_dict[key] = state_dict[key].transpose(1, 0).contiguous()
model.load_state_dict(state_dict, strict=False)
model.tie_weights()
load(model, filenames, self.rank, self.process_group.size())
self.model = model.to(self.device).eval()
self.num_heads = config.n_head // self.process_group.size()
torch.distributed.barrier(group=self.process_group)