Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-06-19 15:52:08 +00:00
Reworked following https://github.com/huggingface/transformers_bloom_parallel/pull/7
This commit is contained in:
parent 457c9038ff
commit 604b18bec2
@@ -9,6 +9,11 @@ from typing import List, Tuple, Optional, Dict
 from huggingface_hub import hf_hub_download, HfApi
 from safetensors import safe_open
 from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
+from transformers.models.bloom.parallel_layers import (
+    TensorParallelColumnLinear,
+    TensorParallelEmbedding,
+    TensorParallelRowLinear,
+)
 
 
 from bloom_inference.pb import generate_pb2
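These new imports are what let load() below dispatch on the layer's type instead of matching parameter-name suffixes. The real classes live in the transformers_bloom_parallel fork, not in this repo; the sketch below only illustrates the shapes and forward passes they plausibly have. The constructor signatures and the addmm formulation are assumptions chosen to match how load() slices and transposes the checkpoint tensors.

# Illustrative sketch only -- not the fork's actual code.
import torch
import torch.nn as nn
import torch.distributed as dist


class TensorParallelColumnLinear(nn.Module):
    # Shards the output dimension: each rank holds out_features // world_size
    # columns of the full weight. The weight is stored transposed, (in, out_shard),
    # so the forward pass is a single addmm with no per-call transpose -- hence
    # the tensor.transpose(1, 0) when load() installs the (out_shard, in) slice.
    def __init__(self, in_features, out_features, world_size):
        super().__init__()
        assert out_features % world_size == 0
        self.weight = nn.Parameter(torch.empty(in_features, out_features // world_size))
        self.bias = nn.Parameter(torch.empty(out_features // world_size))

    def forward(self, x):
        out = torch.addmm(self.bias, x.reshape(-1, x.shape[-1]), self.weight)
        return out.reshape(*x.shape[:-1], -1)  # output stays sharded; no comms


class TensorParallelRowLinear(nn.Module):
    # Shards the input dimension: each rank multiplies its input shard by a
    # (in_shard, out) weight, and the partial results are summed with all_reduce.
    # The bias would otherwise be summed world_size times, which is why load()
    # zeroes it on every rank except 0 (the "Hack for Rowlinear" in the diff).
    def __init__(self, in_features, out_features, world_size, group=None):
        super().__init__()
        assert in_features % world_size == 0
        self.weight = nn.Parameter(torch.empty(in_features // world_size, out_features))
        self.bias = nn.Parameter(torch.empty(out_features))
        self.group = group

    def forward(self, x):
        out = torch.addmm(self.bias, x.reshape(-1, x.shape[-1]), self.weight)
        dist.all_reduce(out, group=self.group)
        return out.reshape(*x.shape[:-1], -1)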
@@ -370,87 +375,91 @@ class BLOOM:
         return generated_texts, next_batch
 
 
-def dl_weights(rank, model_id):
+def dl_weights(group, model_id):
+    rank = group.rank()
     api = HfApi()
     info = api.model_info(model_id)
-    filenames = set(
+    filenames = [
         s.rfilename for s in info.siblings if s.rfilename.endswith(".safetensors")
-    )
-    return [hf_hub_download(model_id, filename=filename) for filename in filenames]
+    ]
+    # Download the files only on rank 0
+    if rank == 0:
+        # XXX: You might want to try and launch these in a multiprocessing.Pool to download the files faster.
+        [
+            hf_hub_download(model_id, filename=filename, local_files_only=True)
+            for filename in filenames
+        ]
+    else:
+        pass
+    torch.distributed.barrier(group=group)
+    # At this point the files should be in cache
+    return [
+        hf_hub_download(model_id, filename=filename, local_files_only=True)
+        for filename in filenames
+    ]
 
 
-def set_tensor(model, full_name, tensor):
-    splits = full_name.split(".")
-    for split in splits[:-1]:
-        model = getattr(model, split)
-    tensor_name = splits[-1]
-
-    with torch.no_grad():
-        model._parameters[tensor_name] = tensor
-
-
-def load(model, filenames, tp_rank, tp_world_size):
+def load(model, filenames, group):
+    tp_rank = group.rank()
+    tp_world_size = group.size()
     parameters = dict(model.named_parameters())
     for filename in filenames:
         with safe_open(filename, framework="pt", device=f"cuda:{tp_rank}") as f:
             for name in f.keys():
                 full_name = f"transformer.{name}"
 
+                module_name, param_name = full_name.rsplit(".", 1)
+                module = model.get_submodule(module_name)
                 current_tensor = parameters[full_name]
-                handled = False
-                for suffix in [
-                    "self_attention.dense.weight",
-                    "mlp.dense_4h_to_h.weight",
-                    "self_attention.query_key_value.weight",
-                    "mlp.dense_h_to_4h.weight",
-                    "self_attention.query_key_value.bias",
-                    "mlp.dense_h_to_4h.bias",
-                    "word_embeddings.weight",
-                ]:
-                    if name.endswith(suffix):
-                        slice_ = f.get_slice(name)
-                        if suffix in {
-                            "mlp.dense_4h_to_h.weight",
-                            "self_attention.dense.weight",
-                        }:
-                            size = slice_.get_shape()[1]
-                            block_size = size // tp_world_size
-                            start = tp_rank * block_size
-                            stop = (tp_rank + 1) * block_size
-                            tensor = slice_[:, start:stop]
-                        else:
-                            size = slice_.get_shape()[0]
-                            block_size = size // tp_world_size
-                            start = tp_rank * block_size
-                            stop = (tp_rank + 1) * block_size
-                            tensor = slice_[start:stop]
-                        if name.endswith(".weight") and not name.endswith(
-                            "word_embeddings.weight"
-                        ):
-                            tensor = tensor.transpose(1, 0)
-                        handled = True
-                        break
-                if not handled:
-                    tensor = f.get_tensor(name)
-
-                tensor = tensor.contiguous()
-
-                if tp_rank != 0 and (
-                    name.endswith("self_attention.dense.bias")
-                    or name.endswith("mlp.dense_4h_to_h.bias")
-                ):
-                    # XXX: Hack for Rowlinear to add the bias only once.
-                    set_tensor(model, full_name, torch.zeros_like(tensor))
-                else:
-                    set_tensor(model, full_name, tensor)
-                if name == "word_embeddings.weight":
-                    set_tensor(model, "lm_head.weight", tensor)
+                slice_ = f.get_slice(name)
+                if isinstance(module, TensorParallelColumnLinear):
+                    if param_name == "weight":
+                        size = slice_.get_shape()[0]
+                        block_size = size // tp_world_size
+                        start = tp_rank * block_size
+                        stop = (tp_rank + 1) * block_size
+                        tensor = slice_[start:stop]
+                        tensor = tensor.transpose(1, 0)
+                    else:
+                        size = slice_.get_shape()[0]
+                        block_size = size // tp_world_size
+                        start = tp_rank * block_size
+                        stop = (tp_rank + 1) * block_size
+                        tensor = slice_[start:stop]
+                elif isinstance(module, TensorParallelRowLinear):
+                    if param_name == "weight":
+                        size = slice_.get_shape()[1]
+                        block_size = size // tp_world_size
+                        start = tp_rank * block_size
+                        stop = (tp_rank + 1) * block_size
+                        tensor = slice_[:, start:stop]
+                        tensor = tensor.transpose(1, 0)
+                    else:
+                        tensor = slice_[:]
+                        # XXX: Hack for Rowlinear to add the bias only once.
+                        if tp_rank != 0:
+                            tensor = torch.zeros_like(tensor)
+                elif isinstance(module, TensorParallelEmbedding):
+                    size = slice_.get_shape()[0]
+                    block_size = size // tp_world_size
+                    start = tp_rank * block_size
+                    stop = (tp_rank + 1) * block_size
+                    tensor = slice_[start:stop]
+                else:
+                    tensor = slice_[:]
+
                 if current_tensor.shape != tensor.shape:
                     raise ValueError(
                         f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}"
                     )
 
+                tensor = tensor.contiguous()
+                module._parameters[param_name] = tensor
+                if name == "word_embeddings.weight":
+                    model.lm_head._parameters["weight"] = tensor
 
 
 @contextmanager
 def init_empty_weights(include_buffers: bool = False):
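The heart of the rework is that each rank now reads only its own shard straight from the .safetensors file via get_slice, instead of materializing the full tensor and then cutting it up. A minimal standalone illustration of that arithmetic follows; the filename and tensor name are hypothetical, and this assumes 2-way tensor parallelism.

from safetensors import safe_open

tp_rank, tp_world_size = 0, 2  # pretend we are rank 0 of a 2-way group
with safe_open("model.safetensors", framework="pt") as f:  # hypothetical file
    slice_ = f.get_slice("h.0.mlp.dense_h_to_4h.weight")   # hypothetical tensor name
    size = slice_.get_shape()[0]               # dim 0 = output features
    block_size = size // tp_world_size
    start = tp_rank * block_size
    stop = (tp_rank + 1) * block_size
    shard = slice_[start:stop]                 # only these rows are read from disk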
@@ -527,7 +536,7 @@ class BLOOMSharded(BLOOM):
 
         self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
 
-        filenames = dl_weights(self.rank, model_name)
+        filenames = dl_weights(self.process_group, model_name)
 
         config = AutoConfig.from_pretrained(
             model_name, slow_but_exact=False, tp_parallel=True
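The call site now hands dl_weights the whole process group rather than a bare rank, so the function can coordinate the ranks itself: rank 0 populates the Hugging Face cache while the others block on a barrier, after which every rank resolves the same files locally. A minimal sketch of that idiom, with download_on_rank0 a hypothetical name; note that the rank-0 call must be allowed to hit the network for the cache to be populated.

import torch.distributed as dist
from huggingface_hub import hf_hub_download

def download_on_rank0(group, model_id, filenames):
    if group.rank() == 0:
        # Only rank 0 touches the network and fills the shared cache.
        for filename in filenames:
            hf_hub_download(model_id, filename=filename)
    dist.barrier(group=group)
    # All ranks (including 0) now resolve the cached files offline.
    return [
        hf_hub_download(model_id, filename=filename, local_files_only=True)
        for filename in filenames
    ]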
@@ -548,7 +557,7 @@ class BLOOMSharded(BLOOM):
 
         torch.distributed.barrier(group=self.process_group)
         # print_rank_0(f"Initialized model")
-        load(model, filenames, self.rank, self.process_group.size())
+        load(model, filenames, self.process_group)
         self.model = model.to(self.device).eval()
         self.num_heads = config.n_head // self.process_group.size()
         torch.distributed.barrier(group=self.process_group)
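Taken together, BLOOMSharded's startup on each rank now looks roughly like the following. This is a hedged sketch condensing __init__, not code from this commit: the process-group setup, the import path, and the use of AutoModelForCausalLM.from_config are assumptions.

import torch
import torch.distributed as dist
from transformers import AutoConfig, AutoModelForCausalLM

from bloom_inference.model import dl_weights, load, init_empty_weights  # assumed path

model_name = "bigscience/bloom"
dist.init_process_group(backend="nccl")  # one process per GPU, e.g. via torchrun
group = dist.group.WORLD

filenames = dl_weights(group, model_name)  # rank 0 downloads; others wait at the barrier
config = AutoConfig.from_pretrained(model_name, slow_but_exact=False, tp_parallel=True)
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config)  # skeleton, no weights materialized
load(model, filenames, group)  # each rank reads only its shards straight from disk
model = model.to(torch.device("cuda", group.rank())).eval()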