mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 11:24:53 +00:00
wip
This commit is contained in:
parent
e8a3ec36c3
commit
2378529c15
@ -62,6 +62,7 @@ class FlashSantacoder(FlashCausalLM):
|
||||
filenames,
|
||||
device,
|
||||
dtype,
|
||||
config.architectures[0].startswith("GPT2")
|
||||
)
|
||||
self.model = model.eval().to(device).to(dtype)
|
||||
|
||||
@ -76,6 +77,7 @@ class FlashSantacoder(FlashCausalLM):
|
||||
filenames: List[Path],
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
transpose: bool
|
||||
):
|
||||
for filename in filenames:
|
||||
state_dict = torch.load(filename, map_location="cpu")
|
||||
@ -102,7 +104,7 @@ class FlashSantacoder(FlashCausalLM):
|
||||
current_parameter_tensor = None
|
||||
|
||||
if current_parameter_tensor is not None:
|
||||
if (
|
||||
if transpose and (
|
||||
"c_fc.weight" in key
|
||||
or "c_proj.weight" in key
|
||||
or "q_attn.weight" in key
|
||||
@ -202,6 +204,7 @@ class FlashSantacoderSharded(FlashSantacoder):
|
||||
device=device,
|
||||
rank=self.rank,
|
||||
world_size=self.world_size,
|
||||
transpose=config.architectures[0].startswith("GPT2")
|
||||
)
|
||||
self.model = model.eval().to(dtype)
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
@ -217,6 +220,7 @@ class FlashSantacoderSharded(FlashSantacoder):
|
||||
device: torch.device,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
transpose: bool
|
||||
):
|
||||
for file in filenames:
|
||||
with safe_open(file, framework="pt", device=str(device)) as f:
|
||||
@ -262,7 +266,7 @@ class FlashSantacoderSharded(FlashSantacoder):
|
||||
tensor = slice_[start:stop]
|
||||
elif "c_attn" in name:
|
||||
size = slice_.get_shape()[0]
|
||||
block_size =
|
||||
raise ValueError
|
||||
elif name == "lm_head.weight" and model.transformer.tp_embeddings:
|
||||
size = slice_.get_shape()[0]
|
||||
block_size = size // world_size
|
||||
@ -283,6 +287,16 @@ class FlashSantacoderSharded(FlashSantacoder):
|
||||
current_parameter_tensor = None
|
||||
|
||||
if current_parameter_tensor is not None:
|
||||
if transpose and (
|
||||
"c_fc.weight" in name
|
||||
or "c_proj.weight" in name
|
||||
or "q_attn.weight" in name
|
||||
or "kv_attn.weight" in name
|
||||
or "c_attn.weight" in name
|
||||
):
|
||||
# Tranpose as we use nn.Linear instead of Conv1D
|
||||
value = value.T
|
||||
|
||||
if current_parameter_tensor.device == torch.device("meta"):
|
||||
# Init qkv
|
||||
if "c_attn.weight" in final_name:
|
||||
|
Loading…
Reference in New Issue
Block a user