Nicolas Patry 2022-10-21 20:47:57 +02:00
parent 457c9038ff
commit 604b18bec2


@@ -9,6 +9,11 @@ from typing import List, Tuple, Optional, Dict
 from huggingface_hub import hf_hub_download, HfApi
 from safetensors import safe_open
 from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
+from transformers.models.bloom.parallel_layers import (
+    TensorParallelColumnLinear,
+    TensorParallelEmbedding,
+    TensorParallelRowLinear,
+)
 from bloom_inference.pb import generate_pb2
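
The three classes imported here implement BLOOM's tensor-parallel layers: a column-parallel linear shards its weight along the output dimension and the shard outputs are concatenated, while a row-parallel linear shards along the input dimension and the partial outputs are summed (an all-reduce in the real layers). A minimal single-process PyTorch sketch of that convention (illustrative only, not the transformers implementation):

import torch

torch.manual_seed(0)
tp_world_size = 2
x = torch.randn(3, 8)   # (batch, in_features)
w = torch.randn(6, 8)   # (out_features, in_features), nn.Linear layout
full = x @ w.T          # reference output, shape (3, 6)

# Column parallel: shard out_features; each rank computes a slice of the
# output, and the slices are concatenated.
col_out = torch.cat([x @ s.T for s in w.chunk(tp_world_size, dim=0)], dim=1)

# Row parallel: shard in_features; each rank computes a partial output,
# and the partials are summed (the all-reduce in the real layers).
row_out = sum(
    xs @ ws.T
    for xs, ws in zip(x.chunk(tp_world_size, dim=1), w.chunk(tp_world_size, dim=1))
)

assert torch.allclose(full, col_out, atol=1e-5)
assert torch.allclose(full, row_out, atol=1e-5)

The summation in the row-parallel case is also why load() below zeroes the bias on every rank but rank 0: adding it on every rank would add it tp_world_size times.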
@@ -370,87 +375,91 @@ class BLOOM:
         return generated_texts, next_batch


-def dl_weights(rank, model_id):
+def dl_weights(group, model_id):
+    rank = group.rank()
     api = HfApi()
     info = api.model_info(model_id)
-    filenames = set(
+    filenames = [
         s.rfilename for s in info.siblings if s.rfilename.endswith(".safetensors")
-    )
-    return [hf_hub_download(model_id, filename=filename) for filename in filenames]
+    ]
+    # Download the files only on rank 0; the other ranks wait at the barrier.
+    if rank == 0:
+        # XXX: You might want to try and launch these in a multiprocessing.Pool
+        # to download the files faster.
+        [
+            hf_hub_download(model_id, filename=filename)
+            for filename in filenames
+        ]
+    torch.distributed.barrier(group=group)
+    # At this point the files should be in cache
+    return [
+        hf_hub_download(model_id, filename=filename, local_files_only=True)
+        for filename in filenames
+    ]


-def set_tensor(model, full_name, tensor):
-    splits = full_name.split(".")
-    for split in splits[:-1]:
-        model = getattr(model, split)
-    tensor_name = splits[-1]
-    with torch.no_grad():
-        model._parameters[tensor_name] = tensor
-
-
-def load(model, filenames, tp_rank, tp_world_size):
+def load(model, filenames, group):
+    tp_rank = group.rank()
+    tp_world_size = group.size()
     parameters = dict(model.named_parameters())
     for filename in filenames:
         with safe_open(filename, framework="pt", device=f"cuda:{tp_rank}") as f:
             for name in f.keys():
                 full_name = f"transformer.{name}"
+                module_name, param_name = full_name.rsplit(".", 1)
+                module = model.get_submodule(module_name)
                 current_tensor = parameters[full_name]
-                handled = False
-                for suffix in [
-                    "self_attention.dense.weight",
-                    "mlp.dense_4h_to_h.weight",
-                    "self_attention.query_key_value.weight",
-                    "mlp.dense_h_to_4h.weight",
-                    "self_attention.query_key_value.bias",
-                    "mlp.dense_h_to_4h.bias",
-                    "word_embeddings.weight",
-                ]:
-                    if name.endswith(suffix):
-                        slice_ = f.get_slice(name)
-                        if suffix in {
-                            "mlp.dense_4h_to_h.weight",
-                            "self_attention.dense.weight",
-                        }:
-                            size = slice_.get_shape()[1]
-                            block_size = size // tp_world_size
-                            start = tp_rank * block_size
-                            stop = (tp_rank + 1) * block_size
-                            tensor = slice_[:, start:stop]
-                            tensor = tensor.transpose(1, 0)
-                        else:
-                            size = slice_.get_shape()[0]
-                            block_size = size // tp_world_size
-                            start = tp_rank * block_size
-                            stop = (tp_rank + 1) * block_size
-                            tensor = slice_[start:stop]
-                        if name.endswith(".weight") and not name.endswith(
-                            "word_embeddings.weight"
-                        ):
-                            tensor = tensor.transpose(1, 0)
-                        handled = True
-                        break
-                if not handled:
-                    tensor = f.get_tensor(name)
-                tensor = tensor.contiguous()
-                if tp_rank != 0 and (
-                    name.endswith("self_attention.dense.bias")
-                    or name.endswith("mlp.dense_4h_to_h.bias")
-                ):
-                    # XXX: Hack for Rowlinear to add the bias only once.
-                    set_tensor(model, full_name, torch.zeros_like(tensor))
-                else:
-                    set_tensor(model, full_name, tensor)
-                if name == "word_embeddings.weight":
-                    set_tensor(model, "lm_head.weight", tensor)
+                slice_ = f.get_slice(name)
+                if isinstance(module, TensorParallelColumnLinear):
+                    if param_name == "weight":
+                        size = slice_.get_shape()[0]
+                        block_size = size // tp_world_size
+                        start = tp_rank * block_size
+                        stop = (tp_rank + 1) * block_size
+                        tensor = slice_[start:stop]
+                    else:
+                        size = slice_.get_shape()[0]
+                        block_size = size // tp_world_size
+                        start = tp_rank * block_size
+                        stop = (tp_rank + 1) * block_size
+                        tensor = slice_[start:stop]
+                elif isinstance(module, TensorParallelRowLinear):
+                    if param_name == "weight":
+                        size = slice_.get_shape()[1]
+                        block_size = size // tp_world_size
+                        start = tp_rank * block_size
+                        stop = (tp_rank + 1) * block_size
+                        tensor = slice_[:, start:stop]
+                        tensor = tensor.transpose(1, 0)
+                    else:
+                        tensor = slice_[:]
+                        # XXX: Hack for Rowlinear to add the bias only once.
+                        if tp_rank != 0:
+                            tensor = torch.zeros_like(tensor)
+                elif isinstance(module, TensorParallelEmbedding):
+                    size = slice_.get_shape()[0]
+                    block_size = size // tp_world_size
+                    start = tp_rank * block_size
+                    stop = (tp_rank + 1) * block_size
+                    tensor = slice_[start:stop]
+                else:
+                    tensor = slice_[:]
                 if current_tensor.shape != tensor.shape:
                     raise ValueError(
                         f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}"
                     )
+                tensor = tensor.contiguous()
+                module._parameters[param_name] = tensor
+                if name == "word_embeddings.weight":
+                    model.lm_head._parameters["weight"] = tensor
 @contextmanager
 def init_empty_weights(include_buffers: bool = False):
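
The rewritten load() relies on safetensors' lazy slicing: f.get_slice(name) exposes the shape of an on-disk tensor and materializes only the requested shard, so each rank reads just its block instead of the full checkpoint. A self-contained sketch of that API (the file name and shapes are invented for the demo):

import torch
from safetensors import safe_open
from safetensors.torch import save_file

# Write a toy checkpoint (hypothetical file, just for the demo).
save_file({"dense.weight": torch.randn(8, 4)}, "toy.safetensors")

tp_rank, tp_world_size = 0, 2
with safe_open("toy.safetensors", framework="pt") as f:
    slice_ = f.get_slice("dense.weight")   # lazy: nothing is loaded yet
    size = slice_.get_shape()[0]           # full shape is known up front
    block_size = size // tp_world_size
    start = tp_rank * block_size
    stop = (tp_rank + 1) * block_size
    shard = slice_[start:stop]             # only this shard is read from disk
    print(shard.shape)                     # torch.Size([4, 4])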
@@ -527,7 +536,7 @@ class BLOOMSharded(BLOOM):
         self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")

-        filenames = dl_weights(self.rank, model_name)
+        filenames = dl_weights(self.process_group, model_name)

         config = AutoConfig.from_pretrained(
             model_name, slow_but_exact=False, tp_parallel=True
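
dl_weights() now receives the whole process group rather than a bare rank so it can coordinate the download: rank 0 fetches from the Hub while the other ranks block on a barrier, after which every rank resolves the same files from the local cache. A minimal sketch of that coordination pattern (download_once is a hypothetical helper; run with torchrun --nproc_per_node=2):

import torch.distributed as dist

def download_once(group=None):
    rank = dist.get_rank(group)
    if rank == 0:
        # ... fetch files over the network here ...
        pass
    # Everyone waits until rank 0 is done.
    dist.barrier(group=group)
    # ... every rank now reads from the shared local cache ...

if __name__ == "__main__":
    dist.init_process_group("gloo")  # reads RANK/WORLD_SIZE from the env
    download_once()
    dist.destroy_process_group()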
@@ -548,7 +557,7 @@ class BLOOMSharded(BLOOM):
             torch.distributed.barrier(group=self.process_group)
             # print_rank_0(f"Initialized model")
-            load(model, filenames, self.rank, self.process_group.size())
+            load(model, filenames, self.process_group)
             self.model = model.to(self.device).eval()
             self.num_heads = config.n_head // self.process_group.size()
             torch.distributed.barrier(group=self.process_group)
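
A constraint implied by the num_heads line in the last hunk: config.n_head must divide evenly by the tensor-parallel world size, since each rank keeps n_head // world_size attention heads, and the dimensions sliced in load() carry the same divisibility assumption. A tiny illustration (the numbers are BLOOM-176B's; the explicit guard is not in the commit):

n_head = 112          # BLOOM-176B attention heads
world_size = 8        # 8-way tensor parallelism
if n_head % world_size != 0:
    raise ValueError(f"n_head={n_head} not divisible by world_size={world_size}")
num_heads_per_rank = n_head // world_size  # 14 heads on each GPU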