Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-06-19 15:52:08 +00:00
Reworked following https://github.com/huggingface/transformers_bloom_parallel/pull/7
This commit is contained in:
parent 457c9038ff
commit 604b18bec2
@@ -9,6 +9,11 @@ from typing import List, Tuple, Optional, Dict
 from huggingface_hub import hf_hub_download, HfApi
 from safetensors import safe_open
 from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
+from transformers.models.bloom.parallel_layers import (
+    TensorParallelColumnLinear,
+    TensorParallelEmbedding,
+    TensorParallelRowLinear,
+)


 from bloom_inference.pb import generate_pb2
@@ -370,87 +375,91 @@ class BLOOM:
         return generated_texts, next_batch


-def dl_weights(rank, model_id):
-    api = HfApi()
-    info = api.model_info(model_id)
-    filenames = [
-        s.rfilename for s in info.siblings if s.rfilename.endswith(".safetensors")
-    ]
-    return [hf_hub_download(model_id, filename=filename) for filename in filenames]
+def dl_weights(group, model_id):
+    rank = group.rank()
+    api = HfApi()
+    info = api.model_info(model_id)
+    filenames = set(
+        s.rfilename for s in info.siblings if s.rfilename.endswith(".safetensors")
+    )
+    # Download the files only on rank 0
+    if rank == 0:
+        # XXX: You might want to try and launch these in a multiprocessing.Pool to download the files faster.
+        [hf_hub_download(model_id, filename=filename) for filename in filenames]
+    torch.distributed.barrier(group=group)
+    # At this point the files should be in cache
+    return [
+        hf_hub_download(model_id, filename=filename, local_files_only=True)
+        for filename in filenames
+    ]
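Note: the reworked dl_weights has rank 0 populate the local Hugging Face cache while every other rank waits on a barrier, after which all ranks resolve the same files offline. A minimal standalone sketch of that pattern (the helper name and the shared-cache assumption are ours, not from this commit):

import torch.distributed as dist
from huggingface_hub import hf_hub_download

def cached_download(group, model_id, filenames):
    # Only rank 0 touches the network; the barrier keeps the other
    # ranks from racing ahead before the cache is populated.
    if group.rank() == 0:
        for filename in filenames:
            hf_hub_download(model_id, filename=filename)
    dist.barrier(group=group)
    # Every rank (including rank 0) now resolves the files from the
    # local cache; this assumes all ranks share that cache directory.
    return [
        hf_hub_download(model_id, filename=filename, local_files_only=True)
        for filename in filenames
    ]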


-def set_tensor(model, full_name, tensor):
-    splits = full_name.split(".")
-    for split in splits[:-1]:
-        model = getattr(model, split)
-    tensor_name = splits[-1]
-
-    with torch.no_grad():
-        model._parameters[tensor_name] = tensor
-
-
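Note: the removed set_tensor helper walked the dotted parameter path by hand; the reworked load below gets the same effect with Module.get_submodule. A sketch of the equivalence (not code from this commit):

import torch

def set_tensor(model, full_name, tensor):
    # Walk e.g. "transformer.word_embeddings.weight" down to the owning
    # module, then overwrite the raw parameter slot in place.
    module_name, param_name = full_name.rsplit(".", 1)
    module = model.get_submodule(module_name)
    with torch.no_grad():
        module._parameters[param_name] = tensor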
-def load(model, filenames, tp_rank, tp_world_size):
+def load(model, filenames, group):
+    tp_rank = group.rank()
+    tp_world_size = group.size()
+    parameters = dict(model.named_parameters())
     for filename in filenames:
         with safe_open(filename, framework="pt", device=f"cuda:{tp_rank}") as f:
             for name in f.keys():
                 full_name = f"transformer.{name}"
-                handled = False
-                for suffix in [
-                    "self_attention.dense.weight",
-                    "mlp.dense_4h_to_h.weight",
-                    "self_attention.query_key_value.weight",
-                    "mlp.dense_h_to_4h.weight",
-                    "self_attention.query_key_value.bias",
-                    "mlp.dense_h_to_4h.bias",
-                    "word_embeddings.weight",
-                ]:
-                    if name.endswith(suffix):
-                        slice_ = f.get_slice(name)
-                        if suffix in {
-                            "mlp.dense_4h_to_h.weight",
-                            "self_attention.dense.weight",
-                        }:
-                            size = slice_.get_shape()[1]
-                            block_size = size // tp_world_size
-                            start = tp_rank * block_size
-                            stop = (tp_rank + 1) * block_size
-                            tensor = slice_[:, start:stop]
-                        else:
-                            size = slice_.get_shape()[0]
-                            block_size = size // tp_world_size
-                            start = tp_rank * block_size
-                            stop = (tp_rank + 1) * block_size
-                            tensor = slice_[start:stop]
-                        if name.endswith(".weight") and not name.endswith(
-                            "word_embeddings.weight"
-                        ):
-                            tensor = tensor.transpose(1, 0)
-                        handled = True
-                        break
-                if not handled:
-                    tensor = f.get_tensor(name)
-                tensor = tensor.contiguous()
-                if tp_rank != 0 and (
-                    name.endswith("self_attention.dense.bias")
-                    or name.endswith("mlp.dense_4h_to_h.bias")
-                ):
-                    # XXX: Hack for Rowlinear to add the bias only once.
-                    set_tensor(model, full_name, torch.zeros_like(tensor))
-                else:
-                    set_tensor(model, full_name, tensor)
-                if name == "word_embeddings.weight":
-                    set_tensor(model, "lm_head.weight", tensor)
+
+                module_name, param_name = full_name.rsplit(".", 1)
+                module = model.get_submodule(module_name)
+                current_tensor = parameters[full_name]
+
+                slice_ = f.get_slice(name)
+
+                if isinstance(module, TensorParallelColumnLinear):
+                    if param_name == "weight":
+                        size = slice_.get_shape()[0]
+                        block_size = size // tp_world_size
+                        start = tp_rank * block_size
+                        stop = (tp_rank + 1) * block_size
+                        tensor = slice_[start:stop]
+                        tensor = tensor.transpose(1, 0)
+                    else:
+                        size = slice_.get_shape()[0]
+                        block_size = size // tp_world_size
+                        start = tp_rank * block_size
+                        stop = (tp_rank + 1) * block_size
+                        tensor = slice_[start:stop]
+                elif isinstance(module, TensorParallelRowLinear):
+                    if param_name == "weight":
+                        size = slice_.get_shape()[1]
+                        block_size = size // tp_world_size
+                        start = tp_rank * block_size
+                        stop = (tp_rank + 1) * block_size
+                        tensor = slice_[:, start:stop]
+                        tensor = tensor.transpose(1, 0)
+                    else:
+                        tensor = slice_[:]
+                        # XXX: Hack for Rowlinear to add the bias only once.
+                        if tp_rank != 0:
+                            tensor = torch.zeros_like(tensor)
+                elif isinstance(module, TensorParallelEmbedding):
+                    size = slice_.get_shape()[0]
+                    block_size = size // tp_world_size
+                    start = tp_rank * block_size
+                    stop = (tp_rank + 1) * block_size
+                    tensor = slice_[start:stop]
+                else:
+                    tensor = slice_[:]
+
+                if current_tensor.shape != tensor.shape:
+                    raise ValueError(
+                        f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}"
+                    )
+
+                tensor = tensor.contiguous()
+                module._parameters[param_name] = tensor
+                if name == "word_embeddings.weight":
+                    model.lm_head._parameters["weight"] = tensor


 @contextmanager
 def init_empty_weights(include_buffers: bool = False):
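Note: the reworked load dispatches on the layer class instead of on parameter-name suffixes: TensorParallelColumnLinear weights are sliced along dim 0, TensorParallelRowLinear weights along dim 1, and safetensors' get_slice reads only the selected block from disk. A standalone sketch of that slicing (the helper name is ours):

from safetensors import safe_open

def load_shard(filename, name, dim, tp_rank, tp_world_size):
    # Read only this rank's block of the tensor, sliced along `dim`,
    # without materializing the full tensor in memory.
    with safe_open(filename, framework="pt") as f:
        slice_ = f.get_slice(name)
        size = slice_.get_shape()[dim]
        block_size = size // tp_world_size
        start = tp_rank * block_size
        stop = (tp_rank + 1) * block_size
        # Column parallel shards output rows (dim 0); row parallel
        # shards input columns (dim 1).
        return slice_[start:stop] if dim == 0 else slice_[:, start:stop]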
@@ -527,7 +536,7 @@ class BLOOMSharded(BLOOM):

         self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")

-        filenames = dl_weights(self.rank, model_name)
+        filenames = dl_weights(self.process_group, model_name)

         config = AutoConfig.from_pretrained(
             model_name, slow_but_exact=False, tp_parallel=True
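Note: the commit does not show how self.process_group is created. A typical setup for this kind of tensor-parallel group (assumed here, not from the diff): one process per GPU, launched e.g. with torchrun.

import os
import torch
import torch.distributed as dist

def init_tp_group():
    # NCCL backend for GPU collectives; torchrun supplies RANK,
    # WORLD_SIZE and LOCAL_RANK in the environment.
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    return dist.group.WORLD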
@@ -548,7 +557,7 @@ class BLOOMSharded(BLOOM):

         torch.distributed.barrier(group=self.process_group)
         # print_rank_0(f"Initialized model")
-        load(model, filenames, self.rank, self.process_group.size())
+        load(model, filenames, self.process_group)
         self.model = model.to(self.device).eval()
         self.num_heads = config.n_head // self.process_group.size()
         torch.distributed.barrier(group=self.process_group)
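Note: num_heads is divided by the group size because the column-parallel query_key_value projection leaves each rank with an even share of the attention heads. A worked example using BLOOM-176B's numbers (n_head = 112, assumed here for illustration):

n_head = 112       # config.n_head for the 176B BLOOM checkpoint
tp_world_size = 8  # one shard per GPU
assert n_head % tp_world_size == 0
num_heads = n_head // tp_world_size  # each rank computes 14 heads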