From 604b18bec29e381205d3e0debb7a624def4b3119 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Fri, 21 Oct 2022 20:47:57 +0200
Subject: [PATCH] Reworked following
 https://github.com/huggingface/transformers_bloom_parallel/pull/7

---
 server/bloom_inference/model.py | 135 +++++++++++++++++---------------
 1 file changed, 72 insertions(+), 63 deletions(-)

diff --git a/server/bloom_inference/model.py b/server/bloom_inference/model.py
index f1e8f167..c58ad394 100644
--- a/server/bloom_inference/model.py
+++ b/server/bloom_inference/model.py
@@ -9,6 +9,11 @@ from typing import List, Tuple, Optional, Dict
 from huggingface_hub import hf_hub_download, HfApi
 from safetensors import safe_open
 from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
+from transformers.models.bloom.parallel_layers import (
+    TensorParallelColumnLinear,
+    TensorParallelEmbedding,
+    TensorParallelRowLinear,
+)
 
 from bloom_inference.pb import generate_pb2
 
@@ -370,87 +375,91 @@ class BLOOM:
         return generated_texts, next_batch
 
 
-def dl_weights(rank, model_id):
+def dl_weights(group, model_id):
+    rank = group.rank()
     api = HfApi()
     info = api.model_info(model_id)
-    filenames = set(
+    filenames = [
         s.rfilename for s in info.siblings if s.rfilename.endswith(".safetensors")
-    )
-    return [hf_hub_download(model_id, filename=filename) for filename in filenames]
+    ]
+    # Download the files only on rank 0
+    if rank == 0:
+        # XXX: You might want to try and launch these in a multiprocessing.Pool to download the files faster.
+        [
+            hf_hub_download(model_id, filename=filename, local_files_only=True)
+            for filename in filenames
+        ]
+    else:
+        pass
+    torch.distributed.barrier(group=group)
+    # At this point the files should be in cache
+    return [
+        hf_hub_download(model_id, filename=filename, local_files_only=True)
+        for filename in filenames
+    ]
 
 
-def set_tensor(model, full_name, tensor):
-    splits = full_name.split(".")
-    for split in splits[:-1]:
-        model = getattr(model, split)
-    tensor_name = splits[-1]
-
-    with torch.no_grad():
-        model._parameters[tensor_name] = tensor
-
-
-def load(model, filenames, tp_rank, tp_world_size):
+def load(model, filenames, group):
+    tp_rank = group.rank()
+    tp_world_size = group.size()
     parameters = dict(model.named_parameters())
     for filename in filenames:
         with safe_open(filename, framework="pt", device=f"cuda:{tp_rank}") as f:
             for name in f.keys():
                 full_name = f"transformer.{name}"
+
+                module_name, param_name = full_name.rsplit(".", 1)
+                module = model.get_submodule(module_name)
                 current_tensor = parameters[full_name]
-                handled = False
-                for suffix in [
-                    "self_attention.dense.weight",
-                    "mlp.dense_4h_to_h.weight",
-                    "self_attention.query_key_value.weight",
-                    "mlp.dense_h_to_4h.weight",
-                    "self_attention.query_key_value.bias",
-                    "mlp.dense_h_to_4h.bias",
-                    "word_embeddings.weight",
-                ]:
-                    if name.endswith(suffix):
-                        slice_ = f.get_slice(name)
-                        if suffix in {
-                            "mlp.dense_4h_to_h.weight",
-                            "self_attention.dense.weight",
-                        }:
-                            size = slice_.get_shape()[1]
-                            block_size = size // tp_world_size
-                            start = tp_rank * block_size
-                            stop = (tp_rank + 1) * block_size
-                            tensor = slice_[:, start:stop]
-                        else:
-                            size = slice_.get_shape()[0]
-                            block_size = size // tp_world_size
-                            start = tp_rank * block_size
-                            stop = (tp_rank + 1) * block_size
-                            tensor = slice_[start:stop]
-                        if name.endswith(".weight") and not name.endswith(
-                            "word_embeddings.weight"
-                        ):
-                            tensor = tensor.transpose(1, 0)
-                        handled = True
-                        break
-                if not handled:
-                    tensor = f.get_tensor(name)
+                slice_ = f.get_slice(name)
-                tensor = tensor.contiguous()
-
-                if tp_rank != 0 and (
-                    name.endswith("self_attention.dense.bias")
-                    or name.endswith("mlp.dense_4h_to_h.bias")
-                ):
-                    # XXX: Hack for Rowlinear to add the bias only once.
-                    set_tensor(model, full_name, torch.zeros_like(tensor))
+                if isinstance(module, TensorParallelColumnLinear):
+                    if param_name == "weight":
+                        size = slice_.get_shape()[0]
+                        block_size = size // tp_world_size
+                        start = tp_rank * block_size
+                        stop = (tp_rank + 1) * block_size
+                        tensor = slice_[start:stop]
+                        tensor = tensor.transpose(1, 0)
+                    else:
+                        size = slice_.get_shape()[0]
+                        block_size = size // tp_world_size
+                        start = tp_rank * block_size
+                        stop = (tp_rank + 1) * block_size
+                        tensor = slice_[start:stop]
+                elif isinstance(module, TensorParallelRowLinear):
+                    if param_name == "weight":
+                        size = slice_.get_shape()[1]
+                        block_size = size // tp_world_size
+                        start = tp_rank * block_size
+                        stop = (tp_rank + 1) * block_size
+                        tensor = slice_[:, start:stop]
+                        tensor = tensor.transpose(1, 0)
+                    else:
+                        tensor = slice_[:]
+                        # XXX: Hack for Rowlinear to add the bias only once.
+                        if tp_rank != 0:
+                            tensor = torch.zeros_like(tensor)
+                elif isinstance(module, TensorParallelEmbedding):
+                    size = slice_.get_shape()[0]
+                    block_size = size // tp_world_size
+                    start = tp_rank * block_size
+                    stop = (tp_rank + 1) * block_size
+                    tensor = slice_[start:stop]
                 else:
-                    set_tensor(model, full_name, tensor)
-                    if name == "word_embeddings.weight":
-                        set_tensor(model, "lm_head.weight", tensor)
+                    tensor = slice_[:]
 
                 if current_tensor.shape != tensor.shape:
                     raise ValueError(
                         f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}"
                     )
+                tensor = tensor.contiguous()
+                module._parameters[param_name] = tensor
+                if name == "word_embeddings.weight":
+                    model.lm_head._parameters["weight"] = tensor
+
 
 @contextmanager
 def init_empty_weights(include_buffers: bool = False):
@@ -527,7 +536,7 @@ class BLOOMSharded(BLOOM):
 
         self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
 
-        filenames = dl_weights(self.rank, model_name)
+        filenames = dl_weights(self.process_group, model_name)
 
         config = AutoConfig.from_pretrained(
             model_name, slow_but_exact=False, tp_parallel=True
@@ -548,7 +557,7 @@ class BLOOMSharded(BLOOM):
         torch.distributed.barrier(group=self.process_group)
 
         # print_rank_0(f"Initialized model")
-        load(model, filenames, self.rank, self.process_group.size())
+        load(model, filenames, self.process_group)
         self.model = model.to(self.device).eval()
         self.num_heads = config.n_head // self.process_group.size()
         torch.distributed.barrier(group=self.process_group)
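
Note (illustrative, not part of the patch): the reworked load() above gives each tensor-parallel rank one contiguous block of every checkpoint tensor along the sharded dimension (dim 0 for TensorParallelColumnLinear and TensorParallelEmbedding, dim 1 for TensorParallelRowLinear weights). Below is a minimal sketch of that arithmetic, using a plain torch tensor in place of a safetensors slice; the helper name shard_for_rank is hypothetical.

import torch


def shard_for_rank(weight, tp_rank, tp_world_size, dim):
    # Each rank keeps one contiguous block of size // tp_world_size along `dim`,
    # mirroring slice_[start:stop] and slice_[:, start:stop] in the diff above.
    size = weight.shape[dim]
    block_size = size // tp_world_size
    start = tp_rank * block_size
    return weight.narrow(dim, start, block_size).contiguous()


full = torch.randn(8, 4)
print(shard_for_rank(full, tp_rank=1, tp_world_size=2, dim=0).shape)  # torch.Size([4, 4])
print(shard_for_rank(full, tp_rank=1, tp_world_size=2, dim=1).shape)  # torch.Size([8, 2])

The RowLinear bias is the one exception: it is loaded whole on rank 0 and zeroed on the other ranks (the "XXX: Hack for Rowlinear" branch), so the bias contributes only once when the partial outputs are summed across ranks.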