mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 11:54:52 +00:00
wip
This commit is contained in:
parent
e8a3ec36c3
commit
2378529c15
@ -62,6 +62,7 @@ class FlashSantacoder(FlashCausalLM):
|
|||||||
filenames,
|
filenames,
|
||||||
device,
|
device,
|
||||||
dtype,
|
dtype,
|
||||||
|
config.architectures[0].startswith("GPT2")
|
||||||
)
|
)
|
||||||
self.model = model.eval().to(device).to(dtype)
|
self.model = model.eval().to(device).to(dtype)
|
||||||
|
|
||||||
@ -76,6 +77,7 @@ class FlashSantacoder(FlashCausalLM):
|
|||||||
filenames: List[Path],
|
filenames: List[Path],
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
|
transpose: bool
|
||||||
):
|
):
|
||||||
for filename in filenames:
|
for filename in filenames:
|
||||||
state_dict = torch.load(filename, map_location="cpu")
|
state_dict = torch.load(filename, map_location="cpu")
|
||||||
@ -102,7 +104,7 @@ class FlashSantacoder(FlashCausalLM):
|
|||||||
current_parameter_tensor = None
|
current_parameter_tensor = None
|
||||||
|
|
||||||
if current_parameter_tensor is not None:
|
if current_parameter_tensor is not None:
|
||||||
if (
|
if transpose and (
|
||||||
"c_fc.weight" in key
|
"c_fc.weight" in key
|
||||||
or "c_proj.weight" in key
|
or "c_proj.weight" in key
|
||||||
or "q_attn.weight" in key
|
or "q_attn.weight" in key
|
||||||
@ -202,6 +204,7 @@ class FlashSantacoderSharded(FlashSantacoder):
|
|||||||
device=device,
|
device=device,
|
||||||
rank=self.rank,
|
rank=self.rank,
|
||||||
world_size=self.world_size,
|
world_size=self.world_size,
|
||||||
|
transpose=config.architectures[0].startswith("GPT2")
|
||||||
)
|
)
|
||||||
self.model = model.eval().to(dtype)
|
self.model = model.eval().to(dtype)
|
||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
@ -217,6 +220,7 @@ class FlashSantacoderSharded(FlashSantacoder):
|
|||||||
device: torch.device,
|
device: torch.device,
|
||||||
rank: int,
|
rank: int,
|
||||||
world_size: int,
|
world_size: int,
|
||||||
|
transpose: bool
|
||||||
):
|
):
|
||||||
for file in filenames:
|
for file in filenames:
|
||||||
with safe_open(file, framework="pt", device=str(device)) as f:
|
with safe_open(file, framework="pt", device=str(device)) as f:
|
||||||
@ -262,7 +266,7 @@ class FlashSantacoderSharded(FlashSantacoder):
|
|||||||
tensor = slice_[start:stop]
|
tensor = slice_[start:stop]
|
||||||
elif "c_attn" in name:
|
elif "c_attn" in name:
|
||||||
size = slice_.get_shape()[0]
|
size = slice_.get_shape()[0]
|
||||||
block_size =
|
raise ValueError
|
||||||
elif name == "lm_head.weight" and model.transformer.tp_embeddings:
|
elif name == "lm_head.weight" and model.transformer.tp_embeddings:
|
||||||
size = slice_.get_shape()[0]
|
size = slice_.get_shape()[0]
|
||||||
block_size = size // world_size
|
block_size = size // world_size
|
||||||
@ -283,6 +287,16 @@ class FlashSantacoderSharded(FlashSantacoder):
|
|||||||
current_parameter_tensor = None
|
current_parameter_tensor = None
|
||||||
|
|
||||||
if current_parameter_tensor is not None:
|
if current_parameter_tensor is not None:
|
||||||
|
if transpose and (
|
||||||
|
"c_fc.weight" in name
|
||||||
|
or "c_proj.weight" in name
|
||||||
|
or "q_attn.weight" in name
|
||||||
|
or "kv_attn.weight" in name
|
||||||
|
or "c_attn.weight" in name
|
||||||
|
):
|
||||||
|
# Transpose as we use nn.Linear instead of Conv1D
|
||||||
|
value = value.T
|
||||||
|
|
||||||
if current_parameter_tensor.device == torch.device("meta"):
|
if current_parameter_tensor.device == torch.device("meta"):
|
||||||
# Init qkv
|
# Init qkv
|
||||||
if "c_attn.weight" in final_name:
|
if "c_attn.weight" in final_name:
|
||||||
|
Loading…
Reference in New Issue
Block a user