Changing convert logic.

Should be more robust to shared tensors (OK when using
  `from_pretrained`), but forces us to add new checks in our loading
  code (since the chosen key to keep might be different from the one
  `transformers` keeps).
Nicolas Patry 2023-06-22 14:52:47 +02:00
parent c9c65ab323
commit 34eadb54e9


@@ -1,76 +1,45 @@
 import datetime
 import torch
+import os
 
-from collections import defaultdict
 from loguru import logger
 from pathlib import Path
-from safetensors.torch import save_file
-from safetensors import safe_open
-from typing import Dict, List
-
-
-def check_file_size(source_file: Path, target_file: Path):
-    """
-    Check that two files are close in size
-    """
-    source_file_size = source_file.stat().st_size
-    target_file_size = target_file.stat().st_size
-
-    if (source_file_size - target_file_size) / source_file_size > 0.05:
-        raise RuntimeError(
-            f"""The file size different is more than 5%:
-         - {source_file}: {source_file_size}
-         - {target_file}: {target_file_size}
-         """
-        )
-
-
-def remove_shared_pointers(tensors: Dict[str, torch.Tensor]):
-    """
-    For a Dict of tensors, check if two or more tensors point to the same underlying memory and
-    remove them
-    """
-    ptrs = defaultdict(list)
-    for k, v in tensors.items():
-        ptrs[v.data_ptr()].append(k)
-
-    # Iterate over all found memory addresses
-    for ptr, names in ptrs.items():
-        if len(names) > 1:
-            # Multiple tensors are point to the same memory
-            # Only keep the first tensor
-            for name in names[1:]:
-                tensors.pop(name)
+from safetensors.torch import save_file, _remove_duplicate_names, load_file
+from typing import List
 
 
 def convert_file(pt_file: Path, sf_file: Path):
     """
     Convert a pytorch file to a safetensors file
+    This will remove duplicate tensors from the file.
+
+    Unfortunately, this might not respect *transformers* convention.
+    Forcing us to check for potentially different keys during load when looking
+    for specific tensors (making tensor sharing explicit).
     """
-    logger.info(f"Convert {pt_file} to {sf_file}.")
-
-    pt_state = torch.load(pt_file, map_location="cpu")
-    if "state_dict" in pt_state:
-        pt_state = pt_state["state_dict"]
-
-    remove_shared_pointers(pt_state)
-
-    # Tensors need to be contiguous
-    pt_state = {k: v.contiguous() for k, v in pt_state.items()}
-
-    sf_file.parent.mkdir(parents=True, exist_ok=True)
-    save_file(pt_state, str(sf_file), metadata={"format": "pt"})
-
-    # Check that both files are close in size
-    check_file_size(pt_file, sf_file)
-
-    # Load safetensors state
-    for k in pt_state:
-        pt_tensor = pt_state[k]
-        with safe_open(sf_file, framework="pt") as f:
-            sf_tensor = f.get_tensor(k)
-            if not torch.equal(pt_tensor, sf_tensor):
-                raise RuntimeError(f"The output tensors do not match for key {k}")
+    loaded = torch.load(pt_file, map_location="cpu")
+    if "state_dict" in loaded:
+        loaded = loaded["state_dict"]
+    to_removes = _remove_duplicate_names(loaded)
+
+    metadata = {"format": "pt"}
+    for kept_name, to_remove_group in to_removes.items():
+        for to_remove in to_remove_group:
+            if to_remove not in metadata:
+                metadata[to_remove] = kept_name
+            del loaded[to_remove]
+    # Force tensors to be contiguous
+    loaded = {k: v.contiguous() for k, v in loaded.items()}
+
+    dirname = os.path.dirname(sf_file)
+    os.makedirs(dirname, exist_ok=True)
+    save_file(loaded, sf_file, metadata=metadata)
+    reloaded = load_file(sf_file)
+    for k in loaded:
+        pt_tensor = loaded[k]
+        sf_tensor = reloaded[k]
+        if not torch.equal(pt_tensor, sf_tensor):
+            raise RuntimeError(f"The output tensors do not match for key {k}")
 
 
 def convert_files(pt_files: List[Path], sf_files: List[Path]):
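
The commit message mentions that the loading code now needs extra checks, because the key kept by `_remove_duplicate_names` may differ from the one `transformers` would have kept. A minimal sketch of what such a loader-side check could look like is below; it only assumes the metadata layout written by `convert_file` above ({removed_name: kept_name} entries plus "format"), and the helper name `get_tensor_with_aliases` is hypothetical, not part of this commit.

import torch
from safetensors import safe_open


def get_tensor_with_aliases(sf_file: str, name: str) -> torch.Tensor:
    # Hypothetical loader helper: if `name` was dropped as a duplicate during
    # conversion, resolve it to the kept key via the file's metadata.
    with safe_open(sf_file, framework="pt") as f:
        if name not in f.keys():
            metadata = f.metadata() or {}
            # convert_file stores {removed_name: kept_name} alias entries.
            name = metadata.get(name, name)
        return f.get_tensor(name)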