diff --git a/server/bloom_inference/model.py b/server/bloom_inference/model.py
index 0ba90cee..f1e8f167 100644
--- a/server/bloom_inference/model.py
+++ b/server/bloom_inference/model.py
@@ -1,15 +1,17 @@
 import torch
 import torch.distributed
 
+from contextlib import contextmanager
 from dataclasses import dataclass
 from pathlib import Path
 from typing import List, Tuple, Optional, Dict
 
+from huggingface_hub import hf_hub_download, HfApi
+from safetensors import safe_open
 from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
-from transformers.modeling_utils import no_init_weights
+
 
 from bloom_inference.pb import generate_pb2
-from bloom_inference.prepare_weights import prepare_weights, match_suffix
 from bloom_inference.utils import (
     StoppingCriteria,
     NextTokenChooser,
@@ -368,6 +370,149 @@ class BLOOM:
         return generated_texts, next_batch
 
 
+def dl_weights(rank, model_id):
+    api = HfApi()
+    info = api.model_info(model_id)
+    filenames = set(
+        s.rfilename for s in info.siblings if s.rfilename.endswith(".safetensors")
+    )
+    return [hf_hub_download(model_id, filename=filename) for filename in filenames]
+
+
+def set_tensor(model, full_name, tensor):
+    splits = full_name.split(".")
+    for split in splits[:-1]:
+        model = getattr(model, split)
+    tensor_name = splits[-1]
+
+    with torch.no_grad():
+        model._parameters[tensor_name] = tensor
+
+
+def load(model, filenames, tp_rank, tp_world_size):
+    parameters = dict(model.named_parameters())
+    for filename in filenames:
+        with safe_open(filename, framework="pt", device=f"cuda:{tp_rank}") as f:
+            for name in f.keys():
+                full_name = f"transformer.{name}"
+                current_tensor = parameters[full_name]
+                handled = False
+                for suffix in [
+                    "self_attention.dense.weight",
+                    "mlp.dense_4h_to_h.weight",
+                    "self_attention.query_key_value.weight",
+                    "mlp.dense_h_to_4h.weight",
+                    "self_attention.query_key_value.bias",
+                    "mlp.dense_h_to_4h.bias",
+                    "word_embeddings.weight",
+                ]:
+                    if name.endswith(suffix):
+                        slice_ = f.get_slice(name)
+                        if suffix in {
+                            "mlp.dense_4h_to_h.weight",
+                            "self_attention.dense.weight",
+                        }:
+                            size = slice_.get_shape()[1]
+                            block_size = size // tp_world_size
+                            start = tp_rank * block_size
+                            stop = (tp_rank + 1) * block_size
+                            tensor = slice_[:, start:stop]
+                        else:
+                            size = slice_.get_shape()[0]
+                            block_size = size // tp_world_size
+                            start = tp_rank * block_size
+                            stop = (tp_rank + 1) * block_size
+                            tensor = slice_[start:stop]
+
+                        if name.endswith(".weight") and not name.endswith(
+                            "word_embeddings.weight"
+                        ):
+                            tensor = tensor.transpose(1, 0)
+                        handled = True
+                        break
+                if not handled:
+                    tensor = f.get_tensor(name)
+
+                tensor = tensor.contiguous()
+
+                if tp_rank != 0 and (
+                    name.endswith("self_attention.dense.bias")
+                    or name.endswith("mlp.dense_4h_to_h.bias")
+                ):
+                    # XXX: Hack for Rowlinear to add the bias only once.
+                    set_tensor(model, full_name, torch.zeros_like(tensor))
+                else:
+                    set_tensor(model, full_name, tensor)
+                if name == "word_embeddings.weight":
+                    set_tensor(model, "lm_head.weight", tensor)
+
+                if current_tensor.shape != tensor.shape:
+                    raise ValueError(
+                        f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}"
+                    )
+
+
+@contextmanager
+def init_empty_weights(include_buffers: bool = False):
+    """
+    imported from `accelerate` to not depend on it.
+ """ + old_register_parameter = torch.nn.Module.register_parameter + if include_buffers: + old_register_buffer = torch.nn.Module.register_buffer + + def register_empty_parameter(module, name, param): + old_register_parameter(module, name, param) + if param is not None: + param_cls = type(module._parameters[name]) + kwargs = module._parameters[name].__dict__ + module._parameters[name] = param_cls( + module._parameters[name].to(torch.device("meta")), **kwargs + ) + + def register_empty_buffer(module, name, buffer): + old_register_buffer(module, name, buffer) + if buffer is not None: + module._buffers[name] = module._buffers[name].to(torch.device("meta")) + + # Patch tensor creation + if include_buffers: + tensor_constructors_to_patch = { + torch_function_name: getattr(torch, torch_function_name) + for torch_function_name in ["empty", "zeros", "ones", "full"] + } + else: + tensor_constructors_to_patch = {} + + def patch_tensor_constructor(fn): + def wrapper(*args, **kwargs): + kwargs["device"] = torch.device("meta") + return fn(*args, **kwargs) + + return wrapper + + try: + torch.nn.Module.register_parameter = register_empty_parameter + if include_buffers: + torch.nn.Module.register_buffer = register_empty_buffer + for torch_function_name in tensor_constructors_to_patch.keys(): + setattr( + torch, + torch_function_name, + patch_tensor_constructor(getattr(torch, torch_function_name)), + ) + yield + finally: + torch.nn.Module.register_parameter = old_register_parameter + if include_buffers: + torch.nn.Module.register_buffer = old_register_buffer + for ( + torch_function_name, + old_torch_function, + ) in tensor_constructors_to_patch.items(): + setattr(torch, torch_function_name, old_torch_function) + + class BLOOMSharded(BLOOM): def __init__(self, model_name: str, shard_directory: Path): super(BLOOM, self).__init__() @@ -382,25 +527,7 @@ class BLOOMSharded(BLOOM): self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left") - # shard state_dict - if self.master: - # TODO @thomasw21 do some caching - shard_state_dict_paths = prepare_weights( - model_name, - shard_directory / "cache", - shard_directory, - tp_world_size=self.world_size, - ) - shard_state_dict_paths = [ - str(path.absolute()) for path in shard_state_dict_paths - ] - else: - shard_state_dict_paths = [None] * self.world_size - - torch.distributed.broadcast_object_list( - shard_state_dict_paths, src=0, group=self.process_group - ) - shard_state_dict_path = shard_state_dict_paths[self.rank] + filenames = dl_weights(self.rank, model_name) config = AutoConfig.from_pretrained( model_name, slow_but_exact=False, tp_parallel=True @@ -415,33 +542,13 @@ class BLOOMSharded(BLOOM): torch.backends.cudnn.allow_tf32 = True with set_default_dtype(dtype): - with no_init_weights(): + with init_empty_weights(): # we can probably set the device to `meta` here? 
                 model = AutoModelForCausalLM.from_config(config).to(dtype)
 
         torch.distributed.barrier(group=self.process_group)
         # print_rank_0(f"Initialized model")
-        state_dict = torch.load(shard_state_dict_path)
-        # TODO @thomasw21: HACK in order to transpose all weight prior
-        for key in state_dict.keys():
-            do_transpose = False
-            if not match_suffix(key, "weight"):
-                continue
-
-            for potential_suffix in [
-                "self_attention.query_key_value.weight",
-                "self_attention.dense.weight",
-                "dense_h_to_4h.weight",
-                "dense_4h_to_h.weight",
-            ]:
-                if match_suffix(key, potential_suffix):
-                    do_transpose = True
-
-            if do_transpose:
-                state_dict[key] = state_dict[key].transpose(1, 0).contiguous()
-
-        model.load_state_dict(state_dict, strict=False)
-        model.tie_weights()
+        load(model, filenames, self.rank, self.process_group.size())
 
         self.model = model.to(self.device).eval()
         self.num_heads = config.n_head // self.process_group.size()
         torch.distributed.barrier(group=self.process_group)
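
A note on `dl_weights`: it lists the repo contents through `HfApi.model_info` and pulls every `*.safetensors` file into the local Hub cache (the `rank` argument is currently unused, so every rank resolves the same cached files). A minimal sketch of the same two `huggingface_hub` calls, where the model id is only an example:

```python
from huggingface_hub import HfApi, hf_hub_download

model_id = "bigscience/bloom"  # example only; any repo shipping .safetensors works

info = HfApi().model_info(model_id)
filenames = sorted(
    s.rfilename for s in info.siblings if s.rfilename.endswith(".safetensors")
)

# hf_hub_download is cache-backed: it returns a local path and only hits the
# network when the file is not already in the cache.
local_paths = [hf_hub_download(model_id, filename=f) for f in filenames]
print(local_paths)
```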
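
The piece that makes `load` memory-friendly is lazy slicing in `safetensors`: `f.get_slice(name)` returns a handle, and indexing that handle reads only the requested block from disk, so each tensor-parallel rank materializes just its shard rather than the full tensor. A self-contained sketch against a toy file (the file name and tensor name are made up):

```python
import torch
from safetensors import safe_open
from safetensors.torch import save_file

# Stand-in for one BLOOM checkpoint shard.
save_file(
    {"mlp.dense_h_to_4h.weight": torch.arange(32.0).reshape(8, 4)},
    "toy.safetensors",
)

tp_rank, tp_world_size = 0, 2

with safe_open("toy.safetensors", framework="pt") as f:
    slice_ = f.get_slice("mlp.dense_h_to_4h.weight")
    size = slice_.get_shape()[0]  # column-parallel layers shard dim 0
    block_size = size // tp_world_size
    start, stop = tp_rank * block_size, (tp_rank + 1) * block_size
    shard = slice_[start:stop]  # only this block is read from disk

print(shard.shape)  # torch.Size([4, 4]): half of the 8x4 weight
```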
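
On the `XXX: Hack for Rowlinear` branch: a row-parallel linear computes a partial output on every rank and then sums them with an all-reduce, so a bias kept on all ranks would be added `tp_world_size` times. Zeroing it everywhere except rank 0 leaves exactly one copy in the sum. A single-process sketch of the arithmetic, with the all-reduce simulated by a plain Python `sum`:

```python
import torch

torch.manual_seed(0)
x = torch.randn(1, 4)  # full input
w = torch.randn(4, 2)  # full weight of a row-parallel linear
b = torch.randn(2)

# Each "rank" holds half the input features and the matching weight rows.
partials = [x[:, :2] @ w[:2], x[:, 2:] @ w[2:]]

reference = x @ w + b
wrong = sum(p + b for p in partials)  # bias counted on both ranks
right = sum(
    p + (b if rank == 0 else torch.zeros_like(b)) for rank, p in enumerate(partials)
)

print(torch.allclose(right, reference))  # True
print(torch.allclose(wrong, reference))  # False: off by exactly one extra bias
```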
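
Finally, `init_empty_weights` is vendored from `accelerate`: while it is active, `nn.Module.register_parameter` is patched so every freshly registered parameter is recreated on the `meta` device, which stores only shape and dtype. That lets each rank build the full model skeleton without holding on to weight memory, then graft in just its shards through `set_tensor`. A hedged usage sketch with a plain `nn.Linear` (it assumes the context manager from this patch is importable from `bloom_inference.model`):

```python
import torch
from torch import nn

from bloom_inference.model import init_empty_weights  # as added in this patch

with init_empty_weights():
    # The weight is registered, then immediately recreated on `meta`, so no
    # real storage is kept after construction.
    layer = nn.Linear(4096, 4096)

print(layer.weight.device)  # meta
print(layer.weight.shape)   # torch.Size([4096, 4096])

# Meta tensors carry no data, so a real tensor must be swapped in before use;
# that is exactly what `set_tensor` / `load` do for the sharded weights.
layer._parameters["weight"] = torch.empty(4096, 4096)
```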