Making bloom loadable with safetensors.
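
This switches the sharded loader to read tensors straight out of `.safetensors` checkpoints, slicing each tensor-parallel block lazily instead of materializing per-rank `torch.save` shards first. A minimal sketch of the slicing pattern the new `load` helper relies on (the file name, tensor name, and rank values below are illustrative placeholders, not taken from this commit):

    from safetensors import safe_open

    # Hypothetical example: read only this rank's column block of a
    # row-parallel weight; only the requested slice is pulled from disk.
    tp_rank, tp_world_size = 0, 2
    with safe_open("model.safetensors", framework="pt", device=f"cuda:{tp_rank}") as f:
        slice_ = f.get_slice("h.0.mlp.dense_4h_to_h.weight")
        size = slice_.get_shape()[1]
        block_size = size // tp_world_size
        start, stop = tp_rank * block_size, (tp_rank + 1) * block_size
        shard = slice_[:, start:stop].contiguous()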

Nicolas Patry 2022-10-21 18:02:04 +02:00
parent c837893370
commit 457c9038ff


@@ -1,15 +1,17 @@
import torch
import torch.distributed

from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import List, Tuple, Optional, Dict

from huggingface_hub import hf_hub_download, HfApi
from safetensors import safe_open
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from transformers.modeling_utils import no_init_weights

from bloom_inference.pb import generate_pb2
from bloom_inference.prepare_weights import prepare_weights, match_suffix
from bloom_inference.utils import (
    StoppingCriteria,
    NextTokenChooser,
@@ -368,6 +370,149 @@ class BLOOM:
        return generated_texts, next_batch


def dl_weights(rank, model_id):
    api = HfApi()
    info = api.model_info(model_id)
    filenames = set(
        s.rfilename for s in info.siblings if s.rfilename.endswith(".safetensors")
    )
    return [hf_hub_download(model_id, filename=filename) for filename in filenames]


def set_tensor(model, full_name, tensor):
    splits = full_name.split(".")
    for split in splits[:-1]:
        model = getattr(model, split)
    tensor_name = splits[-1]
    with torch.no_grad():
        model._parameters[tensor_name] = tensor


def load(model, filenames, tp_rank, tp_world_size):
    parameters = dict(model.named_parameters())
    for filename in filenames:
        with safe_open(filename, framework="pt", device=f"cuda:{tp_rank}") as f:
            for name in f.keys():
                full_name = f"transformer.{name}"

                current_tensor = parameters[full_name]

                handled = False
                for suffix in [
                    "self_attention.dense.weight",
                    "mlp.dense_4h_to_h.weight",
                    "self_attention.query_key_value.weight",
                    "mlp.dense_h_to_4h.weight",
                    "self_attention.query_key_value.bias",
                    "mlp.dense_h_to_4h.bias",
                    "word_embeddings.weight",
                ]:
                    if name.endswith(suffix):
                        slice_ = f.get_slice(name)
                        if suffix in {
                            "mlp.dense_4h_to_h.weight",
                            "self_attention.dense.weight",
                        }:
                            size = slice_.get_shape()[1]
                            block_size = size // tp_world_size
                            start = tp_rank * block_size
                            stop = (tp_rank + 1) * block_size
                            tensor = slice_[:, start:stop]
                        else:
                            size = slice_.get_shape()[0]
                            block_size = size // tp_world_size
                            start = tp_rank * block_size
                            stop = (tp_rank + 1) * block_size
                            tensor = slice_[start:stop]
                        if name.endswith(".weight") and not name.endswith(
                            "word_embeddings.weight"
                        ):
                            tensor = tensor.transpose(1, 0)
                        handled = True
                        break
                if not handled:
                    tensor = f.get_tensor(name)

                tensor = tensor.contiguous()

                if tp_rank != 0 and (
                    name.endswith("self_attention.dense.bias")
                    or name.endswith("mlp.dense_4h_to_h.bias")
                ):
                    # XXX: Hack for Rowlinear to add the bias only once.
                    set_tensor(model, full_name, torch.zeros_like(tensor))
                else:
                    set_tensor(model, full_name, tensor)
                if name == "word_embeddings.weight":
                    set_tensor(model, "lm_head.weight", tensor)

                if current_tensor.shape != tensor.shape:
                    raise ValueError(
                        f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}"
                    )


@contextmanager
def init_empty_weights(include_buffers: bool = False):
    """
    imported from `accelerate` to not depend on it.
    """
    old_register_parameter = torch.nn.Module.register_parameter
    if include_buffers:
        old_register_buffer = torch.nn.Module.register_buffer

    def register_empty_parameter(module, name, param):
        old_register_parameter(module, name, param)
        if param is not None:
            param_cls = type(module._parameters[name])
            kwargs = module._parameters[name].__dict__
            module._parameters[name] = param_cls(
                module._parameters[name].to(torch.device("meta")), **kwargs
            )

    def register_empty_buffer(module, name, buffer):
        old_register_buffer(module, name, buffer)
        if buffer is not None:
            module._buffers[name] = module._buffers[name].to(torch.device("meta"))

    # Patch tensor creation
    if include_buffers:
        tensor_constructors_to_patch = {
            torch_function_name: getattr(torch, torch_function_name)
            for torch_function_name in ["empty", "zeros", "ones", "full"]
        }
    else:
        tensor_constructors_to_patch = {}

    def patch_tensor_constructor(fn):
        def wrapper(*args, **kwargs):
            kwargs["device"] = torch.device("meta")
            return fn(*args, **kwargs)

        return wrapper

    try:
        torch.nn.Module.register_parameter = register_empty_parameter
        if include_buffers:
            torch.nn.Module.register_buffer = register_empty_buffer
        for torch_function_name in tensor_constructors_to_patch.keys():
            setattr(
                torch,
                torch_function_name,
                patch_tensor_constructor(getattr(torch, torch_function_name)),
            )
        yield
    finally:
        torch.nn.Module.register_parameter = old_register_parameter
        if include_buffers:
            torch.nn.Module.register_buffer = old_register_buffer
        for (
            torch_function_name,
            old_torch_function,
        ) in tensor_constructors_to_patch.items():
            setattr(torch, torch_function_name, old_torch_function)


class BLOOMSharded(BLOOM):
    def __init__(self, model_name: str, shard_directory: Path):
        super(BLOOM, self).__init__()
@@ -382,25 +527,7 @@ class BLOOMSharded(BLOOM):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")

        # shard state_dict
        filenames = dl_weights(self.rank, model_name)
        if self.master:
            # TODO @thomasw21 do some caching
            shard_state_dict_paths = prepare_weights(
                model_name,
                shard_directory / "cache",
                shard_directory,
                tp_world_size=self.world_size,
            )
            shard_state_dict_paths = [
                str(path.absolute()) for path in shard_state_dict_paths
            ]
        else:
            shard_state_dict_paths = [None] * self.world_size

        torch.distributed.broadcast_object_list(
            shard_state_dict_paths, src=0, group=self.process_group
        )
        shard_state_dict_path = shard_state_dict_paths[self.rank]
        config = AutoConfig.from_pretrained(
            model_name, slow_but_exact=False, tp_parallel=True
@@ -415,33 +542,13 @@ class BLOOMSharded(BLOOM):
        torch.backends.cudnn.allow_tf32 = True

        with set_default_dtype(dtype):
            with no_init_weights():
            with init_empty_weights():
                # we can probably set the device to `meta` here?
                model = AutoModelForCausalLM.from_config(config).to(dtype)

        torch.distributed.barrier(group=self.process_group)
        # print_rank_0(f"Initialized model")
        state_dict = torch.load(shard_state_dict_path)
        load(model, filenames, self.rank, self.process_group.size())
        # TODO @thomasw21: HACK in order to transpose all weight prior
        for key in state_dict.keys():
            do_transpose = False
            if not match_suffix(key, "weight"):
                continue

            for potential_suffix in [
                "self_attention.query_key_value.weight",
                "self_attention.dense.weight",
                "dense_h_to_4h.weight",
                "dense_4h_to_h.weight",
            ]:
                if match_suffix(key, potential_suffix):
                    do_transpose = True

            if do_transpose:
                state_dict[key] = state_dict[key].transpose(1, 0).contiguous()

        model.load_state_dict(state_dict, strict=False)
        model.tie_weights()
        self.model = model.to(self.device).eval()
        self.num_heads = config.n_head // self.process_group.size()
        torch.distributed.barrier(group=self.process_group)