This commit is contained in:
OlivierDehaene 2023-04-06 15:03:13 +02:00
parent e8a3ec36c3
commit 2378529c15

@@ -62,6 +62,7 @@ class FlashSantacoder(FlashCausalLM):
             filenames,
             device,
             dtype,
+            config.architectures[0].startswith("GPT2")
         )
         self.model = model.eval().to(device).to(dtype)
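
The new positional argument keys off the architecture name recorded in the checkpoint's config. A minimal sketch of that check, assuming a santacoder-style model id and architecture value for illustration (neither is taken from this diff):

from transformers import AutoConfig

# Hypothetical example: GPT2-derived checkpoints record a "GPT2..." entry
# in config.json's "architectures" list (exact value assumed here).
config = AutoConfig.from_pretrained("bigcode/santacoder")

# True for GPT2-style checkpoints, whose weights use the Conv1D layout and
# therefore need transposing before loading into nn.Linear.
transpose = config.architectures[0].startswith("GPT2")
print(transpose)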
@@ -76,6 +77,7 @@ class FlashSantacoder(FlashCausalLM):
         filenames: List[Path],
         device: torch.device,
         dtype: torch.dtype,
+        transpose: bool
     ):
         for filename in filenames:
             state_dict = torch.load(filename, map_location="cpu")
@@ -102,7 +104,7 @@ class FlashSantacoder(FlashCausalLM):
                 current_parameter_tensor = None

             if current_parameter_tensor is not None:
-                if (
+                if transpose and (
                     "c_fc.weight" in key
                     or "c_proj.weight" in key
                     or "q_attn.weight" in key
@@ -202,6 +204,7 @@ class FlashSantacoderSharded(FlashSantacoder):
             device=device,
             rank=self.rank,
             world_size=self.world_size,
+            transpose=config.architectures[0].startswith("GPT2")
         )
         self.model = model.eval().to(dtype)
         torch.distributed.barrier(group=self.process_group)
@@ -217,6 +220,7 @@ class FlashSantacoderSharded(FlashSantacoder):
         device: torch.device,
         rank: int,
         world_size: int,
+        transpose: bool
     ):
         for file in filenames:
             with safe_open(file, framework="pt", device=str(device)) as f:
@@ -262,7 +266,7 @@ class FlashSantacoderSharded(FlashSantacoder):
                     tensor = slice_[start:stop]
                 elif "c_attn" in name:
                     size = slice_.get_shape()[0]
-                    block_size = size // world_size
+                    raise ValueError
                 elif name == "lm_head.weight" and model.transformer.tp_embeddings:
                     size = slice_.get_shape()[0]
                     block_size = size // world_size
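
The lm_head branch row-shards the tensor across ranks with simple block arithmetic. An illustrative stand-alone version of that calculation (the function name is made up for this sketch):

def shard_bounds(size: int, rank: int, world_size: int) -> tuple[int, int]:
    # Each rank takes one contiguous block of rows, as in the diff above.
    block_size = size // world_size
    start = rank * block_size
    stop = (rank + 1) * block_size
    return start, stop

# e.g. a 1024-row lm_head split over 4 ranks:
assert shard_bounds(1024, 0, 4) == (0, 256)
assert shard_bounds(1024, 3, 4) == (768, 1024)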
@@ -283,6 +287,16 @@ class FlashSantacoderSharded(FlashSantacoder):
                 current_parameter_tensor = None

             if current_parameter_tensor is not None:
+                if transpose and (
+                    "c_fc.weight" in name
+                    or "c_proj.weight" in name
+                    or "q_attn.weight" in name
+                    or "kv_attn.weight" in name
+                    or "c_attn.weight" in name
+                ):
+                    # Transpose as we use nn.Linear instead of Conv1D
+                    value = value.T
+
                 if current_parameter_tensor.device == torch.device("meta"):
                     # Init qkv
                     if "c_attn.weight" in final_name: